From 6c7adfbaebb35eb4914b3c2e60005217844c43df Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Sat, 5 Aug 2023 03:15:41 +0800 Subject: [PATCH] change kb_api functions with KnowledgeBase class method --- server/api.py | 5 + server/knowledge_base/__init__.py | 4 +- server/knowledge_base/kb_api.py | 42 ++++---- server/knowledge_base/knowledge_base.py | 127 +++++++++++++++++++++++- 4 files changed, 151 insertions(+), 27 deletions(-) diff --git a/server/api.py b/server/api.py index d030d46..c24bcf7 100644 --- a/server/api.py +++ b/server/api.py @@ -2,6 +2,7 @@ import nltk import sys import os sys.path.append(os.path.dirname(os.path.dirname(__file__))) + from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN import argparse import uvicorn @@ -98,13 +99,17 @@ def create_app(): response_model=BaseResponse, summary="上传文件到知识库,并删除另一个文件" )(update_doc) + app.post("/knowledge_base/recreate_vector_store", + tags=["Knowledge Base Management"], summary="根据content中文档重建向量库,流式输出处理进度。" )(recreate_vector_store) return app + app = create_app() + def run_api(host, port, **kwargs): if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"): uvicorn.run(app, diff --git a/server/knowledge_base/__init__.py b/server/knowledge_base/__init__.py index 427379b..14b7898 100644 --- a/server/knowledge_base/__init__.py +++ b/server/knowledge_base/__init__.py @@ -1,2 +1,4 @@ from .kb_api import list_kbs, create_kb, delete_kb -from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc +from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store +from .knowledge_base import KnowledgeBase +from .knowledge_file import KnowledgeFile diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py index 4250640..3eda6c8 100644 --- a/server/knowledge_base/kb_api.py +++ b/server/knowledge_base/kb_api.py @@ -1,38 +1,26 @@ -import os import urllib -import shutil -from configs.model_config import KB_ROOT_PATH from server.utils import BaseResponse, ListResponse -from server.knowledge_base.utils import validate_kb_name, get_kb_path, get_vs_path +from server.knowledge_base.utils import validate_kb_name +from server.knowledge_base.knowledge_base import KnowledgeBase async def list_kbs(): # Get List of Knowledge Base - if not os.path.exists(KB_ROOT_PATH): - all_doc_ids = [] - else: - all_doc_ids = [ - folder - for folder in os.listdir(KB_ROOT_PATH) - if os.path.isdir(os.path.join(KB_ROOT_PATH, folder)) - and os.path.exists(os.path.join(KB_ROOT_PATH, folder, "vector_store", "index.faiss")) - ] - - return ListResponse(data=all_doc_ids) + return ListResponse(data=KnowledgeBase.list_kbs()) -async def create_kb(knowledge_base_name: str): +async def create_kb(knowledge_base_name: str, + vector_store_type: str = "faiss"): # Create selected knowledge base if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") if knowledge_base_name is None or knowledge_base_name.strip() == "": return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称") - if os.path.exists(get_kb_path(knowledge_base_name)): + if KnowledgeBase.exists(knowledge_base_name): return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") - if not os.path.exists(os.path.join(KB_ROOT_PATH, knowledge_base_name, "content")): - os.makedirs(os.path.join(KB_ROOT_PATH, knowledge_base_name, "content")) - if not os.path.exists(os.path.join(KB_ROOT_PATH, knowledge_base_name, "vector_store")): - os.makedirs(get_vs_path(knowledge_base_name)) + kb = KnowledgeBase(knowledge_base_name=knowledge_base_name, + vector_store_type=vector_store_type) + kb.create() return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") @@ -41,8 +29,12 @@ async def delete_kb(knowledge_base_name: str): if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") knowledge_base_name = urllib.parse.unquote(knowledge_base_name) - kb_path = get_kb_path(knowledge_base_name) - if not os.path.exists(kb_path): + + if not KnowledgeBase.exists(knowledge_base_name): return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - shutil.rmtree(kb_path) - return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}") + + status = KnowledgeBase.delete(knowledge_base_name) + if status: + return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}") + else: + return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}") diff --git a/server/knowledge_base/knowledge_base.py b/server/knowledge_base/knowledge_base.py index 65f9bad..5f49c8d 100644 --- a/server/knowledge_base/knowledge_base.py +++ b/server/knowledge_base/knowledge_base.py @@ -1,6 +1,83 @@ from server.knowledge_base.utils import (get_vs_path, get_kb_path, get_doc_path) +import os +import sqlite3 +from configs.model_config import KB_ROOT_PATH +import datetime +import shutil SUPPORTED_VS_TYPES = ["faiss", "milvus"] +DB_ROOT = os.path.join(KB_ROOT_PATH, "info.db") + + +# TODO: 知识库信息入库 + +def list_kbs_from_db(): + conn = sqlite3.connect(DB_ROOT) + c = conn.cursor() + c.execute(f'''SELECT KB_NAME + FROM KNOWLEDGE_BASE + WHERE FILE_COUNT>0 ''') + kbs = [i[0] for i in c.fetchall() if i] + conn.commit() + conn.close() + return kbs + +def add_kb_to_db(kb_name, vs_type): + conn = sqlite3.connect(DB_ROOT) + c = conn.cursor() + # Create table + c.execute('''CREATE TABLE if not exists KNOWLEDGE_BASE + (ID INTEGER PRIMARY KEY AUTOINCREMENT, + KB_NAME TEXT, + VS_TYPE TEXT, + FILE_COUNT INTEGER, + CREATE_TIME DATETIME) ''') + # Insert a row of data + c.execute(f"""INSERT INTO KNOWLEDGE_BASE + (KB_NAME, VS_TYPE, FILE_COUNT, CREATE_TIME) + VALUES + ('{kb_name}','{vs_type}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""") + conn.commit() + conn.close() + + +def kb_exists(kb_name): + conn = sqlite3.connect(DB_ROOT) + c = conn.cursor() + c.execute(f'''SELECT COUNT(*) + FROM KNOWLEDGE_BASE + WHERE KB_NAME="{kb_name}" ''') + status = True if c.fetchone()[0] else False + conn.commit() + conn.close() + return status + + +def load_kb_from_db(kb_name): + conn = sqlite3.connect(DB_ROOT) + c = conn.cursor() + c.execute(f'''SELECT KB_NAME, VS_TYPE + FROM KNOWLEDGE_BASE + WHERE KB_NAME="{kb_name}" ''') + resp = c.fetchone() + if resp: + kb_name, vs_type = resp + else: + kb_name, vs_type = None, None + conn.commit() + conn.close() + return kb_name, vs_type + + +def delete_kb_from_db(kb_name): + conn = sqlite3.connect(DB_ROOT) + c = conn.cursor() + c.execute(f'''DELETE + FROM KNOWLEDGE_BASE + WHERE KB_NAME="{kb_name}" ''') + conn.commit() + conn.close() + return True class KnowledgeBase: @@ -15,4 +92,52 @@ class KnowledgeBase: self.kb_path = get_kb_path(self.kb_name) self.doc_path = get_doc_path(self.kb_name) if self.vs_type in ["faiss"]: - self.vs_path = get_vs_path(self.kb_name) \ No newline at end of file + self.vs_path = get_vs_path(self.kb_name) + elif self.vs_type in ["milvus"]: + pass + + def create(self): + if not os.path.exists(self.doc_path): + os.makedirs(self.doc_path) + if self.vs_type in ["faiss"]: + if not os.path.exists(self.vs_path): + os.makedirs(self.vs_path) + add_kb_to_db(self.kb_name, self.vs_type) + elif self.vs_type in ["milvus"]: + # TODO: 创建milvus库 + pass + return True + + @classmethod + def exists(cls, + knowledge_base_name: str): + return kb_exists(knowledge_base_name) + + @classmethod + def load(cls, + knowledge_base_name: str): + kb_name, vs_type = load_kb_from_db(knowledge_base_name) + return cls(kb_name, vs_type) + + @classmethod + def delete(cls, + knowledge_base_name: str): + kb = cls.load(knowledge_base_name) + if kb.vs_type in ["faiss"]: + shutil.rmtree(kb.kb_path) + elif kb.vs_type in ["milvus"]: + # TODO: 删除milvus库 + pass + status = delete_kb_from_db(knowledge_base_name) + return status + + @classmethod + def list_kbs(cls): + return list_kbs_from_db() + + +if __name__ == "__main__": + # kb = KnowledgeBase("123", "faiss") + # kb.create() + kb = KnowledgeBase.load(knowledge_base_name="123") + print()