# Langchain-Chatchat/api.py
import argparse
import json
import os
import shutil
from typing import List, Optional
import urllib.parse
import nltk
import pydantic
import uvicorn
from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing_extensions import Annotated
from starlette.responses import RedirectResponse
from chains.local_doc_qa import LocalDocQA
from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
                                  EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
                                  VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN)
from agent import bing_search as agent_bing_search
import models.shared as shared
from models.loader.args import parser
from models.loader import LoaderCheckPoint
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path


class BaseResponse(BaseModel):
    code: int = pydantic.Field(200, description="HTTP status code")
    msg: str = pydantic.Field("success", description="HTTP status message")

    class Config:
        schema_extra = {
            "example": {
                "code": 200,
                "msg": "success",
            }
        }


class ListDocsResponse(BaseResponse):
    data: List[str] = pydantic.Field(..., description="List of document names")

    class Config:
        schema_extra = {
            "example": {
                "code": 200,
                "msg": "success",
                "data": ["doc1.docx", "doc2.pdf", "doc3.txt"],
            }
        }


class ChatMessage(BaseModel):
    question: str = pydantic.Field(..., description="Question text")
    response: str = pydantic.Field(..., description="Response text")
    history: List[List[str]] = pydantic.Field(..., description="History text")
    source_documents: List[str] = pydantic.Field(
        ..., description="List of source documents and their scores"
    )

    class Config:
        schema_extra = {
            "example": {
                "question": "How do I apply for work-related injury insurance?",
                "response": "Based on the known information, this can be summarized as follows:\n\n"
                            "1. The insured employer pays work-related injury insurance premiums for its employees, so that employees can receive the corresponding benefits when a work-related injury occurs.\n"
                            "2. Contribution rules may differ between regions; consult the local social security bureau for the specific contribution standards and regulations.\n"
                            "3. Injured workers and their close relatives must apply for work-related injury determination to confirm their eligibility for benefits, and premiums must be paid on time.\n"
                            "4. Benefits include medical treatment for the injury, rehabilitation, costs of assistive devices, disability benefits, work-related death benefits, a one-time work-related death subsidy, and so on.\n"
                            "5. Eligibility certification for receiving benefits covers both long-term benefit recipients and one-time benefit recipients.\n"
                            "6. Benefit items paid from the work-related injury insurance fund include medical treatment benefits, rehabilitation benefits, costs of assistive devices, the one-time work-related death subsidy, the funeral subsidy, and so on.",
                "history": [
                    [
                        "What is work-related injury insurance?",
                        "Work-related injury insurance is a social insurance system under which an employer, in accordance with national regulations, pays work-related injury insurance premiums for its employees and other personnel, and the insurance agency grants work-related injury benefits according to nationally prescribed standards.",
                    ]
                ],
                "source_documents": [
                    "Source [1] Guide for enrolling specific personnel employed by Guangzhou work units in work-related injury insurance.docx:\n\n\t(1) The employing unit (organization), on a voluntary-enrollment basis, enrolls specific personnel with whom no labor relationship has been established in work-related injury insurance as a standalone item and pays the premiums.",
                    "Source [2] ...",
                    "Source [3] ...",
                ],
            }
        }
def get_folder_path(local_doc_id: str):
    return os.path.join(UPLOAD_ROOT_PATH, local_doc_id)


def get_vs_path(local_doc_id: str):
    return os.path.join(VS_ROOT_PATH, local_doc_id)


def get_file_path(local_doc_id: str, doc_name: str):
    return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name)
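# On-disk layout implied by the helpers above: uploaded source files live under
# UPLOAD_ROOT_PATH/<knowledge_base_id>/<doc_name>, and the FAISS vector store for a
# knowledge base lives under VS_ROOT_PATH/<knowledge_base_id>.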
async def upload_file(
        file: UploadFile = File(description="A single binary file"),
        knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
):
    saved_path = get_folder_path(knowledge_base_id)
    if not os.path.exists(saved_path):
        os.makedirs(saved_path)

    file_content = await file.read()  # read the content of the uploaded file

    file_path = os.path.join(saved_path, file.filename)
    if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
        file_status = f"File {file.filename} already exists."
        return BaseResponse(code=200, msg=file_status)

    with open(file_path, "wb") as f:
        f.write(file_content)

    vs_path = get_vs_path(knowledge_base_id)
    vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store([file_path], vs_path)
    if len(loaded_files) > 0:
        file_status = f"File {file.filename} has been uploaded to the new knowledge base, which has been loaded; you can start asking questions."
        return BaseResponse(code=200, msg=file_status)
    else:
        file_status = "File upload failed, please try again."
        return BaseResponse(code=500, msg=file_status)
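# A minimal usage sketch (an assumption, not part of this file): with the server
# started via api_start() on the defaults below (0.0.0.0:7861), a file can be
# uploaded with curl:
#   curl -X POST "http://localhost:7861/local_doc_qa/upload_file" \
#        -F "file=@./doc1.pdf" \
#        -F "knowledge_base_id=kb1"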
async def upload_files(
        files: Annotated[
            List[UploadFile], File(description="Multiple files as UploadFile")
        ],
        knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
):
    saved_path = get_folder_path(knowledge_base_id)
    if not os.path.exists(saved_path):
        os.makedirs(saved_path)
    filelist = []
    for file in files:
        file_path = os.path.join(saved_path, file.filename)
        file_content = file.file.read()
        # skip files that are already present with the same size
        if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
            continue
        # "wb" overwrites any stale partial file instead of appending to it
        with open(file_path, "wb") as f:
            f.write(file_content)
        filelist.append(file_path)
    if filelist:
        vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, get_vs_path(knowledge_base_id))
        if len(loaded_files):
            file_status = f"Uploaded {', '.join([os.path.split(i)[-1] for i in loaded_files])} to the knowledge base, which has been loaded; you can start asking questions."
            return BaseResponse(code=200, msg=file_status)
    file_status = "Files failed to load, please upload them again."
    return BaseResponse(code=500, msg=file_status)
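# A minimal usage sketch for the multi-file variant (same assumptions as above);
# repeat the "files" field once per file:
#   curl -X POST "http://localhost:7861/local_doc_qa/upload_files" \
#        -F "files=@./doc1.pdf" -F "files=@./doc2.txt" \
#        -F "knowledge_base_id=kb1"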
async def list_docs(
        knowledge_base_id: Optional[str] = Query(default=None, description="Knowledge Base Name", example="kb1")
):
    if knowledge_base_id:
        local_doc_folder = get_folder_path(knowledge_base_id)
        if not os.path.exists(local_doc_folder):
            return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
        all_doc_names = [
            doc
            for doc in os.listdir(local_doc_folder)
            if os.path.isfile(os.path.join(local_doc_folder, doc))
        ]
        return ListDocsResponse(data=all_doc_names)
    else:
        if not os.path.exists(UPLOAD_ROOT_PATH):
            all_doc_ids = []
        else:
            all_doc_ids = [
                folder
                for folder in os.listdir(UPLOAD_ROOT_PATH)
                if os.path.isdir(os.path.join(UPLOAD_ROOT_PATH, folder))
            ]
        return ListDocsResponse(data=all_doc_ids)
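# A minimal usage sketch (assuming the default host/port): omit knowledge_base_id to
# list knowledge bases, or pass it to list the documents inside one:
#   curl "http://localhost:7861/local_doc_qa/list_files?knowledge_base_id=kb1"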
async def delete_docs(
        knowledge_base_id: str = Query(...,
                                       description="Knowledge Base Name (note: this only deletes the uploaded files; it does not delete data inside the FAISS vector store)",
                                       example="kb1"),
        doc_name: Optional[str] = Query(
            None, description="doc name", example="doc_name_1.pdf"
        ),
):
    knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
    if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, knowledge_base_id)):
        return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
    if doc_name:
        doc_path = get_file_path(knowledge_base_id, doc_name)
        if not os.path.exists(doc_path):
            return BaseResponse(code=1, msg=f"document {doc_name} not found")
        os.remove(doc_path)

        # list_docs returns a ListDocsResponse here; read its .data field
        remain_docs = await list_docs(knowledge_base_id)
        if len(remain_docs.data) == 0:
            shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True)
        else:
            local_doc_qa.init_knowledge_vector_store(
                get_folder_path(knowledge_base_id), get_vs_path(knowledge_base_id)
            )
        return BaseResponse(code=200, msg=f"document {doc_name} deleted successfully")
    else:
        shutil.rmtree(get_folder_path(knowledge_base_id))
        return BaseResponse(code=200, msg=f"Knowledge Base {knowledge_base_id} deleted successfully")
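# A minimal usage sketch (assuming the default host/port): with doc_name only that
# document is deleted; without it the whole knowledge base folder is removed:
#   curl -X DELETE "http://localhost:7861/local_doc_qa/delete_file?knowledge_base_id=kb1&doc_name=doc_name_1.pdf"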
async def local_doc_chat(
        knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
        question: str = Body(..., description="Question", example="What is work-related injury insurance?"),
        history: List[List[str]] = Body(
            [],
            description="History of previous questions and answers",
            example=[
                [
                    "What is work-related injury insurance?",
                    "Work-related injury insurance is a social insurance system under which an employer, in accordance with national regulations, pays work-related injury insurance premiums for its employees and other personnel, and the insurance agency grants work-related injury benefits according to nationally prescribed standards.",
                ]
            ],
        ),
):
    vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
    if not os.path.exists(vs_path):
        # return BaseResponse(code=1, msg=f"Knowledge base {knowledge_base_id} not found")
        return ChatMessage(
            question=question,
            response=f"Knowledge base {knowledge_base_id} not found",
            history=history,
            source_documents=[],
        )
    else:
        # drain the streaming generator; resp and history hold the final values
        for resp, history in local_doc_qa.get_knowledge_based_answer(
                query=question, vs_path=vs_path, chat_history=history, streaming=True
        ):
            pass
        source_documents = [
            f"""Source [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
            f"""Relevance score: {doc.metadata['score']}\n\n"""
            for inum, doc in enumerate(resp["source_documents"])
        ]
        return ChatMessage(
            question=question,
            response=resp["result"],
            history=history,
            source_documents=source_documents,
        )
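# A minimal usage sketch (assuming the default host/port); history is a list of
# [question, answer] pairs from earlier turns:
#   curl -X POST "http://localhost:7861/local_doc_qa/local_doc_chat" \
#        -H "Content-Type: application/json" \
#        -d '{"knowledge_base_id": "kb1", "question": "What is work-related injury insurance?", "history": []}'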
async def chat(
        question: str = Body(..., description="Question", example="What is work-related injury insurance?"),
        history: List[List[str]] = Body(
            [],
            description="History of previous questions and answers",
            example=[
                [
                    "What is work-related injury insurance?",
                    "Work-related injury insurance is a social insurance system under which an employer, in accordance with national regulations, pays work-related injury insurance premiums for its employees and other personnel, and the insurance agency grants work-related injury benefits according to nationally prescribed standards.",
                ]
            ],
        ),
):
    # drain the streaming generator; resp and history hold the final values
    for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
                                                          streaming=True):
        resp = answer_result.llm_output["answer"]
        history = answer_result.history
    return ChatMessage(
        question=question,
        response=resp,
        history=history,
        source_documents=[],
    )
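# A minimal usage sketch (assuming the default host/port); this endpoint queries the
# LLM directly, without knowledge-base retrieval:
#   curl -X POST "http://localhost:7861/chat" \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is work-related injury insurance?", "history": []}'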
async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
    await websocket.accept()
    turn = 1
    while True:
        input_json = await websocket.receive_json()
        question, history, knowledge_base_id = input_json["question"], input_json["history"], input_json["knowledge_base_id"]
        vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)

        if not os.path.exists(vs_path):
            await websocket.send_json({"error": f"Knowledge base {knowledge_base_id} not found"})
            await websocket.close()
            return

        await websocket.send_json({"question": question, "turn": turn, "flag": "start"})

        last_print_len = 0
        for resp, history in local_doc_qa.get_knowledge_based_answer(
                query=question, vs_path=vs_path, chat_history=history, streaming=True
        ):
            # send only the newly generated tail of the answer
            await websocket.send_text(resp["result"][last_print_len:])
            last_print_len = len(resp["result"])

        source_documents = [
            f"""Source [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
            f"""Relevance score: {doc.metadata['score']}\n\n"""
            for inum, doc in enumerate(resp["source_documents"])
        ]

        await websocket.send_text(
            json.dumps(
                {
                    "question": question,
                    "turn": turn,
                    "flag": "end",
                    "source_documents": source_documents,
                },
                ensure_ascii=False,
            )
        )
        turn += 1
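# A minimal client sketch for the streaming endpoint, assuming the third-party
# `websockets` package (not a dependency of this file) and the default host/port.
# The server first sends a JSON frame with "flag": "start", then plain-text answer
# fragments, then a JSON frame with "flag": "end" carrying the source documents:
#
#   import asyncio, json
#   import websockets
#
#   async def demo():
#       uri = "ws://localhost:7861/local_doc_qa/stream-chat/kb1"
#       async with websockets.connect(uri) as ws:
#           await ws.send(json.dumps({"question": "What is work-related injury insurance?",
#                                     "history": [], "knowledge_base_id": "kb1"}))
#           print(await ws.recv())  # the "start" frame
#           while True:
#               frame = await ws.recv()
#               if '"flag": "end"' in frame:  # crude end-of-turn check for the sketch
#                   print(json.loads(frame)["source_documents"])
#                   break
#               print(frame, end="")
#
#   asyncio.run(demo())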
async def document():
    return RedirectResponse(url="/docs")
async def bing_search(
        search_text: str = Query(default=None, description="text you want to search", example="langchain")
):
    results = agent_bing_search(search_text)
    result_str = ''
    for result in results:
        for k, v in result.items():
            result_str += "%s: %s\n" % (k, v)
        result_str += '\n'
    return ChatMessage(
        question=search_text,
        response=result_str,
        history=[],
        source_documents=[],
    )
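# A minimal usage sketch (assuming the default host/port and a Bing search key
# configured for the agent module):
#   curl "http://localhost:7861/bing_search?search_text=langchain"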
def api_start(host, port):
    global app
    global local_doc_qa
    llm_model_ins = shared.loaderLLM()
    llm_model_ins.set_history_len(LLM_HISTORY_LEN)

    app = FastAPI()
    # Add CORS middleware to allow all origins;
    # set OPEN_CROSS_DOMAIN=True in configs/model_config.py to allow cross-origin requests
    if OPEN_CROSS_DOMAIN:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
    app.websocket("/local_doc_qa/stream-chat/{knowledge_base_id}")(stream_chat)
    app.get("/", response_model=BaseResponse)(document)
    app.post("/chat", response_model=ChatMessage)(chat)
    app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file)
    app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
    app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
    app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
    app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs)
    app.get("/bing_search", response_model=ChatMessage)(bing_search)

    local_doc_qa = LocalDocQA()
    local_doc_qa.init_cfg(
        llm_model=llm_model_ins,
        embedding_model=EMBEDDING_MODEL,
        embedding_device=EMBEDDING_DEVICE,
        top_k=VECTOR_SEARCH_TOP_K,
    )
    uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7861)
    # initialize model-loading arguments (note: they are hard-coded here)
    args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model'])
    args_dict = vars(args)
    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    api_start(args.host, args.port)