diff --git a/api.py b/api.py
index edb6754..6f09ae5 100644
--- a/api.py
+++ b/api.py
@@ -13,11 +13,10 @@ from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
 from fastapi.openapi.utils import get_openapi
 from pydantic import BaseModel
 from typing_extensions import Annotated
-
+from starlette.responses import RedirectResponse
 from chains.local_doc_qa import LocalDocQA
-from configs.model_config import (API_UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
-                                  EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
-                                  VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN)
+from configs.model_config import (VS_ROOT_PATH, EMBEDDING_DEVICE, EMBEDDING_MODEL, LLM_MODEL, UPLOAD_ROOT_PATH,
+                                  NLTK_DATA_PATH, VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN)
 
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
@@ -76,37 +75,47 @@ class ChatMessage(BaseModel):
 
 
 def get_folder_path(local_doc_id: str):
-    return os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id)
+    return os.path.join(UPLOAD_ROOT_PATH, local_doc_id)
 
 
 def get_vs_path(local_doc_id: str):
-    return os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, "vector_store")
+    return os.path.join(VS_ROOT_PATH, local_doc_id)
 
 
 def get_file_path(local_doc_id: str, doc_name: str):
-    return os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, doc_name)
+    return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name)
 
 
 async def upload_file(
-    files: Annotated[
-        List[UploadFile], File(description="Multiple files as UploadFile")
-    ],
-    knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
+        files: Annotated[
+            List[UploadFile], File(description="Multiple files as UploadFile")
+        ],
+        knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
 ):
     saved_path = get_folder_path(knowledge_base_id)
     if not os.path.exists(saved_path):
         os.makedirs(saved_path)
+    filelist = []
     for file in files:
+        file_content = ''
         file_path = os.path.join(saved_path, file.filename)
-        with open(file_path, "wb") as f:
-            f.write(file.file.read())
-
-    local_doc_qa.init_knowledge_vector_store(saved_path, get_vs_path(knowledge_base_id))
-    return BaseResponse()
+        file_content = file.file.read()
+        if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
+            continue
+        with open(file_path, "ab+") as f:
+            f.write(file_content)
+        filelist.append(file_path)
+    if filelist:
+        vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, get_vs_path(knowledge_base_id))
+        if len(loaded_files):
+            file_status = f"已上传 {'、'.join([os.path.split(i)[-1] for i in loaded_files])} 至知识库,并已加载知识库,请开始提问"
+            return BaseResponse(code=200, msg=file_status)
+    file_status = "文件未成功加载,请重新上传文件"
+    return BaseResponse(code=500, msg=file_status)
 
 
 async def list_docs(
-    knowledge_base_id: Optional[str] = Query(description="Knowledge Base Name", example="kb1")
+        knowledge_base_id: Optional[str] = Query(description="Knowledge Base Name", example="kb1")
 ):
     if knowledge_base_id:
         local_doc_folder = get_folder_path(knowledge_base_id)
@@ -119,25 +128,27 @@ async def list_docs(
         ]
         return ListDocsResponse(data=all_doc_names)
     else:
-        if not os.path.exists(API_UPLOAD_ROOT_PATH):
+        if not os.path.exists(UPLOAD_ROOT_PATH):
             all_doc_ids = []
         else:
             all_doc_ids = [
                 folder
-                for folder in os.listdir(API_UPLOAD_ROOT_PATH)
-                if os.path.isdir(os.path.join(API_UPLOAD_ROOT_PATH, folder))
+                for folder in os.listdir(UPLOAD_ROOT_PATH)
+                if os.path.isdir(os.path.join(UPLOAD_ROOT_PATH, folder))
             ]
         return ListDocsResponse(data=all_doc_ids)
 
 
 async def delete_docs(
-    knowledge_base_id: str = Form(...,
-                                  description="Knowledge Base Name", example="kb1"),
-    doc_name: Optional[str] = Form(
-        None, description="doc name", example="doc_name_1.pdf"
-    ),
+        knowledge_base_id: str = Form(...,
+                                      description="Knowledge Base Name(注意此方法仅删除上传的文件并不会删除知识库(FAISS)内数据)",
+                                      example="kb1"),
+        doc_name: Optional[str] = Form(
+            None, description="doc name", example="doc_name_1.pdf"
+        ),
 ):
-    if not os.path.exists(os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id)):
+    if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, knowledge_base_id)):
         return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
     if doc_name:
         doc_path = get_file_path(knowledge_base_id, doc_name)
@@ -159,25 +170,25 @@ async def delete_docs(
 
 
 async def chat(
-    knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
-    question: str = Body(..., description="Question", example="工伤保险是什么?"),
-    history: List[List[str]] = Body(
-        [],
-        description="History of previous questions and answers",
-        example=[
-            [
-                "工伤保险是什么?",
-                "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
-            ]
-        ],
-    ),
+        knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
+        question: str = Body(..., description="Question", example="工伤保险是什么?"),
+        history: List[List[str]] = Body(
+            [],
+            description="History of previous questions and answers",
+            example=[
+                [
+                    "工伤保险是什么?",
+                    "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
+                ]
+            ],
+        ),
 ):
-    vs_path = os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id, "vector_store")
+    vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
     if not os.path.exists(vs_path):
         raise ValueError(f"Knowledge base {knowledge_base_id} not found")
     for resp, history in local_doc_qa.get_knowledge_based_answer(
-        query=question, vs_path=vs_path, chat_history=history, streaming=True
+            query=question, vs_path=vs_path, chat_history=history, streaming=True
     ):
         pass
     source_documents = [
@@ -196,7 +207,7 @@ async def chat(
 
 async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
     await websocket.accept()
-    vs_path = os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id, "vector_store")
+    vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
 
     if not os.path.exists(vs_path):
         await websocket.send_json({"error": f"Knowledge base {knowledge_base_id} not found"})
@@ -211,7 +222,7 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
         last_print_len = 0
         for resp, history in local_doc_qa.get_knowledge_based_answer(
-            query=question, vs_path=vs_path, chat_history=history, streaming=True
+                query=question, vs_path=vs_path, chat_history=history, streaming=True
         ):
             await websocket.send_text(resp["result"][last_print_len:])
             last_print_len = len(resp["result"])
@@ -236,40 +247,8 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
         turn += 1
 
 
-def gen_docs():
-    global app
-    with tempfile.NamedTemporaryFile("w", encoding="utf-8", suffix=".json") as f:
-        json.dump(
-            get_openapi(
-                title=app.title,
-                version=app.version,
-                openapi_version=app.openapi_version,
-                description=app.description,
-                routes=app.routes,
-            ),
-            f,
-            ensure_ascii=False,
-        )
-        f.flush()
-        # test whether widdershins is available
-        try:
-            subprocess.run(
-                [
-                    "widdershins",
-                    f.name,
-                    "-o",
-                    os.path.join(
-                        os.path.dirname(os.path.abspath(__file__)),
-                        "docs",
-                        "API.md",
-                    ),
-                ],
-                check=True,
-            )
-        except Exception:
-            raise RuntimeError(
-                "Failed to generate docs. Please install widdershins first."
-            )
+async def document():
+    return RedirectResponse(url="/docs")
 
 
 def main():
@@ -278,7 +257,6 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7861)
-    parser.add_argument("--gen-docs", action="store_true")
     args = parser.parse_args()
 
     app = FastAPI()
@@ -287,10 +265,7 @@ def main():
     app.post("/chat-docs/upload", response_model=BaseResponse)(upload_file)
     app.get("/chat-docs/list", response_model=ListDocsResponse)(list_docs)
     app.delete("/chat-docs/delete", response_model=BaseResponse)(delete_docs)
-
-    if args.gen_docs:
-        gen_docs()
-        return
+    app.get("/", response_model=BaseResponse)(document)
 
     local_doc_qa = LocalDocQA()
     local_doc_qa.init_cfg(
diff --git a/configs/model_config.py b/configs/model_config.py
index cb6c5f8..859273b 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -28,7 +28,6 @@ llm_model_dict = {
 LLM_MODEL = "chatglm-6b"
 
 # LLM lora path,默认为空,如果有请直接指定文件夹路径
-# 推荐使用 chatglm-6b-belle-zh-lora
 LLM_LORA_PATH = ""
 USE_LORA = True if LLM_LORA_PATH else False
 
@@ -45,8 +44,6 @@ VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_
 
 UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content")
 
-API_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "api_content")
-
 # 基于上下文的prompt模版,请务必保留"{question}"和"{context}"
 PROMPT_TEMPLATE = """已知信息:
 {context}
@@ -62,4 +59,4 @@ LLM_HISTORY_LEN = 3
 # return top-k text chunk from vector store
 VECTOR_SEARCH_TOP_K = 5
 
-NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
\ No newline at end of file
+NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
diff --git a/webui.py b/webui.py
index 308e95c..b1236ea 100644
--- a/webui.py
+++ b/webui.py
@@ -48,12 +48,6 @@ def get_answer(query, vs_path, history, mode,
         yield history, ""
 
 
-def update_status(history, status):
-    history = history + [[None, status]]
-    print(status)
-    return history
-
-
 def init_model():
     try:
         local_doc_qa.init_cfg()
@@ -92,10 +86,12 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, us
 def get_vector_store(vs_id, files, history):
     vs_path = os.path.join(VS_ROOT_PATH, vs_id)
     filelist = []
+    if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_id)):
+        os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_id))
     for file in files:
         filename = os.path.split(file.name)[-1]
-        shutil.move(file.name, os.path.join(UPLOAD_ROOT_PATH, filename))
-        filelist.append(os.path.join(UPLOAD_ROOT_PATH, filename))
+        shutil.move(file.name, os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
+        filelist.append(os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
     if local_doc_qa.llm and local_doc_qa.embeddings:
         vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path)
         if len(loaded_files):
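
For reference, a minimal client sketch against the reworked endpoints — not part of the diff. It assumes the server is running with the defaults set in main() (reachable at localhost:7861) and that the chat handler is mounted at /chat-docs/chat; that registration falls outside the hunks shown above.

import requests  # third-party HTTP client: pip install requests

BASE = "http://localhost:7861"  # assumed from the --host/--port defaults above

# Upload a document into knowledge base "kb1". The reworked upload_file skips
# any file whose on-disk size already matches the uploaded content, then
# rebuilds the vector store and reports which files were loaded.
with open("doc1.pdf", "rb") as f:
    resp = requests.post(
        f"{BASE}/chat-docs/upload",
        files=[("files", ("doc1.pdf", f))],
        data={"knowledge_base_id": "kb1"},
    )
print(resp.json())  # {"code": 200, "msg": "已上传 doc1.pdf 至知识库,..."} on success

# List the uploaded documents for that knowledge base.
resp = requests.get(f"{BASE}/chat-docs/list", params={"knowledge_base_id": "kb1"})
print(resp.json())

# Query the knowledge base; the FAISS store is now read from VS_ROOT_PATH/kb1
# rather than API_UPLOAD_ROOT_PATH/kb1/vector_store.
resp = requests.post(
    f"{BASE}/chat-docs/chat",  # assumed route; its registration is not shown in this diff
    json={"knowledge_base_id": "kb1", "question": "工伤保险是什么?", "history": []},
)
print(resp.json())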
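The removed gen_docs()/widdershins pipeline is effectively replaced by FastAPI's built-in interactive documentation: the new root route simply redirects to the auto-generated Swagger UI. Continuing the sketch above, the redirect can be checked like so (starlette's RedirectResponse defaults to status 307):

resp = requests.get(f"{BASE}/", allow_redirects=False)
print(resp.status_code, resp.headers["location"])  # expected: 307 /docs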