diff --git a/requirements.txt b/requirements.txt index 3bef8f0..765cf9f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ uvicorn~=0.23.1 starlette~=0.27.0 numpy~=1.24.4 pydantic~=1.10.11 -unstructured[local-inference] +unstructured[all-docs] streamlit>=1.25.0 streamlit-option-menu diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index bc0ee64..749f7c2 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -1,14 +1,11 @@ import os import urllib import shutil -from langchain.embeddings.huggingface import HuggingFaceEmbeddings -from langchain.vectorstores import FAISS from fastapi import File, Form, UploadFile -from server.utils import BaseResponse, ListResponse, torch_gc +from server.utils import BaseResponse, ListResponse from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_path, - get_vs_path, get_file_path, file2text) -from configs.model_config import embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE -from server.knowledge_base.utils import load_embeddings, refresh_vs_cache + get_file_path, file2text, docs2vs, + refresh_vs_cache, ) async def list_docs(knowledge_base_name: str): @@ -39,7 +36,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"), saved_path = get_doc_path(knowledge_base_name) if not os.path.exists(saved_path): - return BaseResponse(code=404, msg="未找到知识库 {knowledge_base_name}") + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") file_content = await file.read() # 读取上传文件的内容 @@ -48,32 +45,17 @@ async def upload_doc(file: UploadFile = File(description="上传文件"), file_status = f"文件 {file.filename} 已存在。" return BaseResponse(code=404, msg=file_status) - with open(file_path, "wb") as f: - f.write(file_content) + try: + with open(file_path, "wb") as f: + f.write(file_content) + except Exception as e: + return BaseResponse(code=500, msg=f"{file.filename} 文件上传失败,报错信息为: {e}") - vs_path = get_vs_path(knowledge_base_name) - # TODO: 重写知识库生成/添加逻辑 filepath = get_file_path(knowledge_base_name, file.filename) docs = file2text(filepath) - loaded_files = [file] - embeddings = load_embeddings(embedding_model_dict[EMBEDDING_MODEL], EMBEDDING_DEVICE) - if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path): - vector_store = FAISS.load_local(vs_path, embeddings) - vector_store.add_documents(docs) - torch_gc() - else: - if not os.path.exists(vs_path): - os.makedirs(vs_path) - vector_store = FAISS.from_documents(docs, embeddings) # docs 为Document列表 - torch_gc() - vector_store.save_local(vs_path) - if len(loaded_files) > 0: - file_status = f"成功上传文件 {file.filename}" - refresh_vs_cache(knowledge_base_name) - return BaseResponse(code=200, msg=file_status) - else: - file_status = f"上传文件 {file.filename} 失败" - return BaseResponse(code=500, msg=file_status) + docs2vs(docs, knowledge_base_name) + + return BaseResponse(code=200, msg=f"成功上传文件 {file.filename}") async def delete_doc(knowledge_base_name: str, @@ -84,24 +66,24 @@ async def delete_doc(knowledge_base_name: str, knowledge_base_name = urllib.parse.unquote(knowledge_base_name) if not os.path.exists(get_kb_path(knowledge_base_name)): - return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found") + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") doc_path = get_file_path(knowledge_base_name, doc_name) if os.path.exists(doc_path): os.remove(doc_path) remain_docs = await list_docs(knowledge_base_name) if len(remain_docs.data) == 0: shutil.rmtree(get_kb_path(knowledge_base_name), ignore_errors=True) - return BaseResponse(code=200, msg=f"document {doc_name} delete success") + return BaseResponse(code=200, msg=f"{doc_name} 文件删除成功") else: # TODO: 重写从向量库中删除文件 status = "" # local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_name)) if "success" in status: refresh_vs_cache(knowledge_base_name) - return BaseResponse(code=200, msg=f"document {doc_name} delete success") + return BaseResponse(code=200, msg=f"{doc_name} 文件删除成功") else: - return BaseResponse(code=500, msg=f"document {doc_name} delete fail") + return BaseResponse(code=500, msg=f"{doc_name} 文件删除失败") else: - return BaseResponse(code=404, msg=f"document {doc_name} not found") + return BaseResponse(code=404, msg=f"未找到文件 {doc_name}") async def update_doc(): diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 0504a85..c3f4f62 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -1,4 +1,6 @@ import os +from typing import List +from server.utils import torch_gc from configs.model_config import KB_ROOT_PATH from langchain.vectorstores import FAISS from langchain.embeddings.huggingface import HuggingFaceEmbeddings @@ -43,6 +45,23 @@ def file2text(filepath): docs = loader.load_and_split(text_splitter) return docs +def docs2vs( + docs: List[Document], + knowledge_base_name: str): + vs_path = get_vs_path(knowledge_base_name) + embeddings = load_embeddings(embedding_model_dict[EMBEDDING_MODEL], EMBEDDING_DEVICE) + if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path): + vector_store = FAISS.load_local(vs_path, embeddings) + vector_store.add_documents(docs) + torch_gc() + else: + if not os.path.exists(vs_path): + os.makedirs(vs_path) + vector_store = FAISS.from_documents(docs, embeddings) # docs 为Document列表 + torch_gc() + vector_store.save_local(vs_path) + refresh_vs_cache(knowledge_base_name) + @lru_cache(1) def load_embeddings(model: str, device: str):