From a447529c2e5a63c88deb1fa6ee4e5da44bc6ce6d Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Sat, 5 Aug 2023 23:35:20 +0800 Subject: [PATCH] update kb_doc_api.py --- server/knowledge_base/kb_doc_api.py | 51 ++++++++++------------- server/knowledge_base/knowledge_base.py | 55 +++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 30 deletions(-) diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 9330e69..d116c58 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -31,32 +31,31 @@ async def upload_doc(file: UploadFile = File(description="上传文件"), if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") - saved_path = get_doc_path(knowledge_base_name) - if not os.path.exists(saved_path): + if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") + kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name) + file_content = await file.read() # 读取上传文件的内容 - file_path = os.path.join(saved_path, file.filename) - if (os.path.exists(file_path) + kb_file = KnowledgeFile(filename=file.filename, + knowledge_base_name=knowledge_base_name) + + if (os.path.exists(kb_file.filepath) and not override - and os.path.getsize(file_path) == len(file_content) + and os.path.getsize(kb_file.filepath) == len(file_content) ): - file_status = f"文件 {file.filename} 已存在。" + file_status = f"文件 {kb_file.filename} 已存在。" return BaseResponse(code=404, msg=file_status) try: - with open(file_path, "wb") as f: + with open(kb_file.filepath, "wb") as f: f.write(file_content) except Exception as e: - return BaseResponse(code=500, msg=f"{file.filename} 文件上传失败,报错信息为: {e}") + return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}") - kb_file = KnowledgeFile(filename=file.filename, - knowledge_base_name=knowledge_base_name) - kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name) kb.add_doc(kb_file) - - return BaseResponse(code=200, msg=f"成功上传文件 {file.filename}") + return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}") async def delete_doc(knowledge_base_name: str, @@ -66,25 +65,17 @@ async def delete_doc(knowledge_base_name: str, return BaseResponse(code=403, msg="Don't attack me") knowledge_base_name = urllib.parse.unquote(knowledge_base_name) - if not os.path.exists(get_kb_path(knowledge_base_name)): + if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - doc_path = get_file_path(knowledge_base_name, doc_name) - if os.path.exists(doc_path): - os.remove(doc_path) - remain_docs = await list_docs(knowledge_base_name) - if len(remain_docs.data) == 0: - shutil.rmtree(get_kb_path(knowledge_base_name), ignore_errors=True) - return BaseResponse(code=200, msg=f"{doc_name} 文件删除成功") - else: - # TODO: 重写从向量库中删除文件 - status = "" # local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_name)) - if "success" in status: - refresh_vs_cache(knowledge_base_name) - return BaseResponse(code=200, msg=f"{doc_name} 文件删除成功") - else: - return BaseResponse(code=500, msg=f"{doc_name} 文件删除失败") - else: + + kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name) + if not kb.exist_doc(doc_name): return BaseResponse(code=404, msg=f"未找到文件 {doc_name}") + kb_file = KnowledgeFile(filename=doc_name, + knowledge_base_name=knowledge_base_name) + kb.delete_doc(kb_file) + return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功") + # return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败") async def update_doc(): diff --git a/server/knowledge_base/knowledge_base.py b/server/knowledge_base/knowledge_base.py index deab893..de7f661 100644 --- a/server/knowledge_base/knowledge_base.py +++ b/server/knowledge_base/knowledge_base.py @@ -172,6 +172,50 @@ def add_doc_to_db(kb_file: KnowledgeFile): conn.commit() conn.close() +def delete_file_from_db(kb_file: KnowledgeFile): + conn = sqlite3.connect(DB_ROOT_PATH) + c = conn.cursor() + # delete files in kb from table knowledge_files + c.execute('''CREATE TABLE if not exists knowledge_files + (id INTEGER PRIMARY KEY AUTOINCREMENT, + file_name TEXT, + file_ext TEXT, + kb_name TEXT, + document_loader_name TEXT, + text_splitter_name TEXT, + file_version INTEGER, + create_time DATETIME) ''') + # Insert a row of data + c.execute(f"""DELETE + FROM knowledge_files + WHERE file_name="{kb_file.filename}" + AND kb_name="{kb_file.kb_name}" + """) + conn.commit() + conn.close() + return True + +def doc_exists(kb_file: KnowledgeFile): + conn = sqlite3.connect(DB_ROOT_PATH) + c = conn.cursor() + c.execute('''CREATE TABLE if not exists knowledge_files + (id INTEGER PRIMARY KEY AUTOINCREMENT, + file_name TEXT, + file_ext TEXT, + kb_name TEXT, + document_loader_name TEXT, + text_splitter_name TEXT, + file_version INTEGER, + create_time DATETIME) ''') + c.execute(f'''SELECT COUNT(*) + FROM knowledge_files + WHERE file_name="{kb_file.filename}" + AND kb_name="{kb_file.kb_name}" ''') + status = True if c.fetchone()[0] else False + conn.commit() + conn.close() + return status + class KnowledgeBase: def __init__(self, @@ -226,6 +270,17 @@ class KnowledgeBase: # TODO: 向milvus库中增加文件 pass + def delete_doc(self, kb_file: KnowledgeFile): + if os.path.exists(kb_file.filepath): + os.remove(kb_file.filepath) + if self.vs_type in ["faiss"]: + # TODO: 从FAISS向量库中删除文档 + delete_file_from_db(kb_file) + + def exist_doc(self, file_name: str): + return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name, + filename=file_name)) + def list_docs(self): return list_docs_from_db(self.kb_name)