diff --git a/api.py b/api.py index aaf1545..94a49d5 100644 --- a/api.py +++ b/api.py @@ -97,9 +97,9 @@ async def upload_file( files: Annotated[ List[UploadFile], File(description="Multiple files as UploadFile") ], - local_doc_id: str = Form(..., description="Local document ID", example="doc_id_1"), + knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"), ): - saved_path = get_folder_path(local_doc_id) + saved_path = get_folder_path(knowledge_base_id) if not os.path.exists(saved_path): os.makedirs(saved_path) for file in files: @@ -107,17 +107,17 @@ async def upload_file( with open(file_path, "wb") as f: f.write(file.file.read()) - local_doc_qa.init_knowledge_vector_store(saved_path, get_vs_path(local_doc_id)) + local_doc_qa.init_knowledge_vector_store(saved_path, get_vs_path(knowledge_base_id)) return BaseResponse() async def list_docs( - local_doc_id: Optional[str] = Query(description="Document ID", example="doc_id1") + knowledge_base_id: Optional[str] = Query(description="Knowledge Base Name", example="kb1") ): - if local_doc_id: - local_doc_folder = get_folder_path(local_doc_id) + if knowledge_base_id: + local_doc_folder = get_folder_path(knowledge_base_id) if not os.path.exists(local_doc_folder): - return {"code": 1, "msg": f"document {local_doc_id} not found"} + return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} all_doc_names = [ doc for doc in os.listdir(local_doc_folder) @@ -138,34 +138,34 @@ async def list_docs( async def delete_docs( - local_doc_id: str = Form(..., description="local doc id", example="doc_id_1"), + knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"), doc_name: Optional[str] = Form( None, description="doc name", example="doc_name_1.pdf" ), ): - if not os.path.exists(os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id)): - return {"code": 1, "msg": f"document {local_doc_id} not found"} + if not os.path.exists(os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id)): + return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} if doc_name: - doc_path = get_file_path(local_doc_id, doc_name) + doc_path = get_file_path(knowledge_base_id, doc_name) if os.path.exists(doc_path): os.remove(doc_path) else: return {"code": 1, "msg": f"document {doc_name} not found"} - remain_docs = await list_docs(local_doc_id) + remain_docs = await list_docs(knowledge_base_id) if remain_docs["code"] != 0 or len(remain_docs["data"]) == 0: - shutil.rmtree(get_folder_path(local_doc_id), ignore_errors=True) + shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True) else: local_doc_qa.init_knowledge_vector_store( - get_folder_path(local_doc_id), get_vs_path(local_doc_id) + get_folder_path(knowledge_base_id), get_vs_path(knowledge_base_id) ) else: - shutil.rmtree(get_folder_path(local_doc_id)) + shutil.rmtree(get_folder_path(knowledge_base_id)) return BaseResponse() async def chat( - local_doc_id: str = Body(..., description="Document ID", example="doc_id1"), + knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"), question: str = Body(..., description="Question", example="工伤保险是什么?"), history: List[List[str]] = Body( [], @@ -178,9 +178,9 @@ async def chat( ], ), ): - vs_path = os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, "vector_store") + vs_path = os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id, "vector_store") if not os.path.exists(vs_path): - raise ValueError(f"Document {local_doc_id} not found") + raise ValueError(f"Knowledge base {knowledge_base_id} not found") for resp, history in local_doc_qa.get_knowledge_based_answer( query=question, vs_path=vs_path, chat_history=history, streaming=True @@ -200,12 +200,12 @@ async def chat( ) -async def stream_chat(websocket: WebSocket, local_doc_id: str): +async def stream_chat(websocket: WebSocket, knowledge_base_id: str): await websocket.accept() - vs_path = os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, "vector_store") + vs_path = os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id, "vector_store") if not os.path.exists(vs_path): - await websocket.send_json({"error": f"document {local_doc_id} not found"}) + await websocket.send_json({"error": f"Knowledge base {knowledge_base_id} not found"}) await websocket.close() return @@ -288,7 +288,7 @@ def main(): args = parser.parse_args() app = FastAPI() - app.websocket("/chat-docs/stream-chat/{local_doc_id}")(stream_chat) + app.websocket("/chat-docs/stream-chat/{knowledge_base_id}")(stream_chat) app.post("/chat-docs/chat", response_model=ChatMessage)(chat) app.post("/chat-docs/upload", response_model=BaseResponse)(upload_file) app.get("/chat-docs/list", response_model=ListDocsResponse)(list_docs) diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index 6b3d1e2..b896607 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -184,7 +184,8 @@ class LocalDocQA: torch_gc(DEVICE) else: if not vs_path: - vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""" + vs_path = os.path.join(VS_ROOT_PATH, + f"""{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""") vector_store = FAISS.from_documents(docs, self.embeddings) torch_gc(DEVICE) diff --git a/configs/model_config.py b/configs/model_config.py index f3af5aa..bda83a2 100644 --- a/configs/model_config.py +++ b/configs/model_config.py @@ -36,9 +36,9 @@ USE_PTUNING_V2 = False # LLM running device LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" -VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_store", "") +VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_store") -UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "") +UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content") API_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "api_content") diff --git a/utils/__init__.py b/utils/__init__.py index 6cde439..4499cf3 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -7,7 +7,8 @@ def torch_gc(DEVICE): torch.cuda.ipc_collect() elif torch.backends.mps.is_available(): try: - torch.mps.empty_cache() + from torch.mps import empty_cache + empty_cache() except Exception as e: print(e) print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。") \ No newline at end of file diff --git a/webui.py b/webui.py index 6c2a29c..2531059 100644 --- a/webui.py +++ b/webui.py @@ -95,12 +95,12 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to def get_vector_store(vs_id, files, history): - vs_path = VS_ROOT_PATH + vs_id + vs_path = os.path.join(VS_ROOT_PATH, vs_id) filelist = [] for file in files: filename = os.path.split(file.name)[-1] - shutil.move(file.name, UPLOAD_ROOT_PATH + filename) - filelist.append(UPLOAD_ROOT_PATH + filename) + shutil.move(file.name, os.path.join(UPLOAD_ROOT_PATH, filename)) + filelist.append(os.path.join(UPLOAD_ROOT_PATH, filename)) if local_doc_qa.llm and local_doc_qa.embeddings: vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path) if len(loaded_files): @@ -118,7 +118,7 @@ def change_vs_name_input(vs_id): if vs_id == "新建知识库": return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None else: - return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), VS_ROOT_PATH + vs_id + return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), os.path.join(VS_ROOT_PATH, vs_id) def change_mode(mode):