import argparse
import json
import os
import shutil
import urllib
import urllib.parse
from typing import List, Optional

import nltk
import pydantic
import uvicorn
from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from starlette.responses import RedirectResponse
from typing_extensions import Annotated

from agent import bing_search as agent_bing_search
from chains.local_doc_qa import LocalDocQA
from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
                                  EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
                                  VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN)
import models.shared as shared
from models.loader import LoaderCheckPoint
from models.loader.args import parser
# Put the project-bundled NLTK data directory ahead of any system-wide paths
# so tokenizer/corpus lookups resolve to the shipped data first.
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
class BaseResponse(BaseModel):
    """Minimal API envelope: a status code plus a human-readable message."""

    code: int = pydantic.Field(200, description="HTTP status code")
    msg: str = pydantic.Field("success", description="HTTP status message")

    class Config:
        # Example payload shown in the generated OpenAPI docs.
        schema_extra = {
            "example": {
                "code": 200,
                "msg": "success",
            }
        }
class ListDocsResponse(BaseResponse):
    """Response envelope carrying a list of document (or knowledge-base) names."""

    data: List[str] = pydantic.Field(..., description="List of document names")

    class Config:
        # Example payload shown in the generated OpenAPI docs.
        schema_extra = {
            "example": {
                "code": 200,
                "msg": "success",
                "data": ["doc1.docx", "doc2.pdf", "doc3.txt"],
            }
        }
class ChatMessage(BaseModel):
    """A single Q&A turn: question, answer, running history and cited sources."""

    question: str = pydantic.Field(..., description="Question text")
    response: str = pydantic.Field(..., description="Response text")
    history: List[List[str]] = pydantic.Field(..., description="History text")
    source_documents: List[str] = pydantic.Field(
        ..., description="List of source documents and their scores"
    )

    class Config:
        # Example payload for the OpenAPI docs. The original text was
        # mojibake (UTF-8 decoded as Latin-1); restored to readable Chinese.
        schema_extra = {
            "example": {
                "question": "工伤保险如何办理?",
                "response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。",
                "history": [
                    [
                        "工伤保险是什么?",
                        "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
                    ]
                ],
                "source_documents": [
                    "出处 [1] 广州市单位从业的特定人员参加工伤保险办事指引.docx:\n\n\t( 一) 从业单位 (组织) 按“自愿参保”原则, 为未建 立劳动关系的特定从业人员单项参加工伤保险 、缴纳工伤保 险费。",
                    "出处 [2] ...",
                    "出处 [3] ...",
                ],
            }
        }
def get_folder_path(local_doc_id: str):
    """Return the upload folder path for the given knowledge-base id."""
    return os.path.join(UPLOAD_ROOT_PATH, local_doc_id)
def get_vs_path(local_doc_id: str):
    """Return the vector-store (FAISS) path for the given knowledge-base id."""
    return os.path.join(VS_ROOT_PATH, local_doc_id)
def get_file_path(local_doc_id: str, doc_name: str):
    """Return the full path of one uploaded document inside a knowledge base."""
    return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name)
async def upload_file(
        file: UploadFile = File(description="A single binary file"),
        knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
):
    """Save one uploaded file into the knowledge-base folder and index it.

    A file whose name and size both match an existing upload is treated as a
    duplicate and skipped. Returns a BaseResponse describing the outcome
    (code 200 on success/duplicate, 500 when indexing fails).
    Original mojibake Chinese messages restored to readable text.
    """
    saved_path = get_folder_path(knowledge_base_id)
    # exist_ok avoids a crash when two uploads race to create the folder
    os.makedirs(saved_path, exist_ok=True)

    file_content = await file.read()  # read the uploaded file's content

    file_path = os.path.join(saved_path, file.filename)
    # Same name and same size -> assume it is the same file, already uploaded.
    if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
        file_status = f"文件 {file.filename} 已存在。"
        return BaseResponse(code=200, msg=file_status)

    with open(file_path, "wb") as f:
        f.write(file_content)

    vs_path = get_vs_path(knowledge_base_id)
    vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store([file_path], vs_path)
    if len(loaded_files) > 0:
        file_status = f"文件 {file.filename} 已上传至新的知识库,并已加载知识库,请开始提问。"
        return BaseResponse(code=200, msg=file_status)
    else:
        file_status = "文件上传失败,请重新上传"
        return BaseResponse(code=500, msg=file_status)
async def upload_files(
        files: Annotated[
            List[UploadFile], File(description="Multiple files as UploadFile")
        ],
        knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
):
    """Save a batch of uploaded files and index the new ones.

    Files whose name and size match an existing upload are skipped.
    Returns BaseResponse: 200 when at least one file was indexed, 500 otherwise.

    BUG FIX: files were previously opened with mode "ab+", so re-uploading a
    changed file *appended* the new content to the old copy, corrupting it.
    "wb" replaces the content instead.
    """
    saved_path = get_folder_path(knowledge_base_id)
    os.makedirs(saved_path, exist_ok=True)
    filelist = []
    for file in files:
        file_path = os.path.join(saved_path, file.filename)
        file_content = file.file.read()
        # Same name and same size -> duplicate, skip re-indexing.
        if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
            continue
        with open(file_path, "wb") as f:
            f.write(file_content)
        filelist.append(file_path)
    if filelist:
        vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, get_vs_path(knowledge_base_id))
        if len(loaded_files):
            file_status = f"已上传 {'、'.join([os.path.split(i)[-1] for i in loaded_files])} 至知识库,并已加载知识库,请开始提问"
            return BaseResponse(code=200, msg=file_status)
    file_status = "文件未成功加载,请重新上传文件"
    return BaseResponse(code=500, msg=file_status)
async def list_docs(
        knowledge_base_id: Optional[str] = Query(default=None, description="Knowledge Base Name", example="kb1")
):
    """List document names of one knowledge base, or all knowledge-base ids.

    With a knowledge_base_id: return the file names inside its upload folder
    (or a {"code": 1, ...} dict when the folder does not exist — kept for
    backward compatibility with existing clients). Without one: return the
    ids (sub-folder names) of every knowledge base.
    """
    if knowledge_base_id:
        local_doc_folder = get_folder_path(knowledge_base_id)
        if not os.path.exists(local_doc_folder):
            return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
        file_names = [
            name
            for name in os.listdir(local_doc_folder)
            if os.path.isfile(os.path.join(local_doc_folder, name))
        ]
        return ListDocsResponse(data=file_names)

    # No id given: enumerate knowledge-base folders under the upload root.
    if not os.path.exists(UPLOAD_ROOT_PATH):
        return ListDocsResponse(data=[])
    kb_ids = [
        name
        for name in os.listdir(UPLOAD_ROOT_PATH)
        if os.path.isdir(os.path.join(UPLOAD_ROOT_PATH, name))
    ]
    return ListDocsResponse(data=kb_ids)
async def delete_docs(
        knowledge_base_id: str = Query(...,
                                       description="Knowledge Base Name(注意此方法仅删除上传的文件并不会删除知识库(FAISS)内数据)",
                                       example="kb1"),
        doc_name: Optional[str] = Query(
            None, description="doc name", example="doc_name_1.pdf"
        ),
):
    """Delete one uploaded document, or a whole knowledge-base folder.

    With doc_name: remove that file, then either drop the now-empty folder or
    rebuild the vector store from the remaining files. Without doc_name:
    remove the entire knowledge-base upload folder.

    BUG FIXES: (1) `remain_docs.json()` returns a JSON *string*, so the old
    `remain_docs["data"]` raised TypeError — use the model's `.data` instead.
    (2) The early `return` right after os.remove made the rebuild code
    unreachable; the success response is now sent after the rebuild.
    """
    knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
    if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, knowledge_base_id)):
        return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
    if doc_name:
        doc_path = get_file_path(knowledge_base_id, doc_name)
        if not os.path.exists(doc_path):
            return BaseResponse(code=1, msg=f"document {doc_name} not found")
        os.remove(doc_path)

        remain_docs = await list_docs(knowledge_base_id)
        remain = remain_docs.data if isinstance(remain_docs, ListDocsResponse) else remain_docs.get("data", [])
        if len(remain) == 0:
            # Nothing left: drop the whole folder (vector store kept, see docstring).
            shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True)
        else:
            # Rebuild the vector store from the remaining documents.
            local_doc_qa.init_knowledge_vector_store(
                get_folder_path(knowledge_base_id), get_vs_path(knowledge_base_id)
            )
        return BaseResponse(code=200, msg=f"document {doc_name} delete success")
    else:
        shutil.rmtree(get_folder_path(knowledge_base_id))
        return BaseResponse(code=200, msg=f"Knowledge Base {knowledge_base_id} delete success")
async def local_doc_chat(
        knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
        question: str = Body(..., description="Question", example="工伤保险是什么?"),
        history: List[List[str]] = Body(
            [],
            description="History of previous questions and answers",
            example=[
                [
                    "工伤保险是什么?",
                    "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
                ]
            ],
        ),
):
    """Answer a question against one knowledge base's vector store.

    Returns a ChatMessage with the answer, updated history and formatted
    source citations. When the knowledge base is missing, the error is
    reported inside the ChatMessage response field so the declared
    response_model stays consistent. Mojibake example strings restored.
    """
    vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
    if not os.path.exists(vs_path):
        # return BaseResponse(code=1, msg=f"Knowledge base {knowledge_base_id} not found")
        return ChatMessage(
            question=question,
            response=f"Knowledge base {knowledge_base_id} not found",
            history=history,
            source_documents=[],
        )
    else:
        # Drain the streaming generator; the final yield carries the complete
        # answer and history. Assumes at least one yield — TODO confirm.
        for resp, history in local_doc_qa.get_knowledge_based_answer(
                query=question, vs_path=vs_path, chat_history=history, streaming=True
        ):
            pass
        source_documents = [
            f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
            f"""相关度:{doc.metadata['score']}\n\n"""
            for inum, doc in enumerate(resp["source_documents"])
        ]

        return ChatMessage(
            question=question,
            response=resp["result"],
            history=history,
            source_documents=source_documents,
        )
async def chat(
        question: str = Body(..., description="Question", example="工伤保险是什么?"),
        history: List[List[str]] = Body(
            [],
            description="History of previous questions and answers",
            example=[
                [
                    "工伤保险是什么?",
                    "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
                ]
            ],
        ),
):
    """Plain LLM chat (no retrieval); returns the final streamed answer.

    FIX: `resp` is now initialized before the loop, so an empty generator no
    longer raises UnboundLocalError; a leftover `pass` was removed.
    """
    resp = ""
    # Drain the streaming generator, keeping the last answer and history.
    for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
                                                          streaming=True):
        resp = answer_result.llm_output["answer"]
        history = answer_result.history

    return ChatMessage(
        question=question,
        response=resp,
        history=history,
        source_documents=[],
    )
async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
    """WebSocket endpoint: stream knowledge-base answers turn by turn.

    Protocol per turn: receive {"question", "history", "knowledge_base_id"},
    send a "start" JSON marker, stream answer deltas as text frames, then send
    an "end" JSON frame with the source citations.

    BUG FIX: the question was read from input_json[""] (empty-string key),
    which raised KeyError on every message; it is now read from "question".
    """
    await websocket.accept()
    turn = 1
    while True:
        input_json = await websocket.receive_json()
        question = input_json["question"]
        history = input_json["history"]
        knowledge_base_id = input_json["knowledge_base_id"]
        vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)

        if not os.path.exists(vs_path):
            await websocket.send_json({"error": f"Knowledge base {knowledge_base_id} not found"})
            await websocket.close()
            return

        await websocket.send_json({"question": question, "turn": turn, "flag": "start"})

        last_print_len = 0
        for resp, history in local_doc_qa.get_knowledge_based_answer(
                query=question, vs_path=vs_path, chat_history=history, streaming=True
        ):
            # Each yield carries the full answer so far; send only the delta.
            await websocket.send_text(resp["result"][last_print_len:])
            last_print_len = len(resp["result"])

        source_documents = [
            f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
            f"""相关度:{doc.metadata['score']}\n\n"""
            for inum, doc in enumerate(resp["source_documents"])
        ]

        await websocket.send_text(
            json.dumps(
                {
                    "question": question,
                    "turn": turn,
                    "flag": "end",
                    # NOTE(review): "sources_documents" looks like a typo for
                    # "source_documents", but it is the wire contract existing
                    # clients parse — kept unchanged.
                    "sources_documents": source_documents,
                },
                ensure_ascii=False,
            )
        )
        turn += 1
async def document():
    """Redirect the bare root URL to the interactive Swagger docs at /docs."""
    return RedirectResponse(url="/docs")
async def bing_search(
        search_text: str = Query(default=None, description="text you want to search", example="langchain")
):
    """Run a Bing search and pack the results into a ChatMessage.

    Each result dict becomes "key: value" lines, results separated by a
    blank line; the concatenation is returned as the response text.
    """
    results = agent_bing_search(search_text)
    parts = []
    for result in results:
        for k, v in result.items():
            parts.append("%s: %s\n" % (k, v))
        parts.append('\n')
    result_str = ''.join(parts)

    return ChatMessage(
        question=search_text,
        response=result_str,
        history=[],
        source_documents=[],
    )
def api_start(host, port):
    """Build the FastAPI app, register all routes, init the QA chain, serve.

    Loads the LLM via the shared loader, optionally enables permissive CORS,
    wires every endpoint, initializes the global LocalDocQA instance, then
    blocks in uvicorn.run().
    """
    global app
    global local_doc_qa

    llm_model_ins = shared.loaderLLM()
    llm_model_ins.set_history_len(LLM_HISTORY_LEN)

    app = FastAPI()
    # Add CORS middleware to allow all origins
    # set OPEN_CROSS_DOMAIN=True in the model config to allow cross-domain requests
    if OPEN_CROSS_DOMAIN:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
    app.websocket("/local_doc_qa/stream-chat/{knowledge_base_id}")(stream_chat)

    app.get("/", response_model=BaseResponse)(document)

    app.post("/chat", response_model=ChatMessage)(chat)

    app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file)
    app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
    app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
    app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
    app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs)

    app.get("/bing_search", response_model=ChatMessage)(bing_search)

    local_doc_qa = LocalDocQA()
    local_doc_qa.init_cfg(
        llm_model=llm_model_ins,
        embedding_model=EMBEDDING_MODEL,
        embedding_device=EMBEDDING_DEVICE,
        top_k=VECTOR_SEARCH_TOP_K,
    )
    uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7861)
    # Parse loader/model arguments. FIX: removed the dead `args = None`
    # assignment that was immediately overwritten.
    # NOTE(review): argv is hard-coded here (model dir/name, --no-remote-model),
    # so real command-line flags are ignored — confirm this is intentional.
    args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model'])
    args_dict = vars(args)
    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    api_start(args.host, args.port)