change kb_api functions with KnowledgeBase class method
This commit is contained in:
parent
7e1472a95b
commit
6c7adfbaeb
|
|
@ -2,6 +2,7 @@ import nltk
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||||
|
|
||||||
from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
|
from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
|
||||||
import argparse
|
import argparse
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
@ -98,13 +99,17 @@ def create_app():
|
||||||
response_model=BaseResponse,
|
response_model=BaseResponse,
|
||||||
summary="上传文件到知识库,并删除另一个文件"
|
summary="上传文件到知识库,并删除另一个文件"
|
||||||
)(update_doc)
|
)(update_doc)
|
||||||
|
|
||||||
app.post("/knowledge_base/recreate_vector_store",
|
app.post("/knowledge_base/recreate_vector_store",
|
||||||
|
tags=["Knowledge Base Management"],
|
||||||
summary="根据content中文档重建向量库,流式输出处理进度。"
|
summary="根据content中文档重建向量库,流式输出处理进度。"
|
||||||
)(recreate_vector_store)
|
)(recreate_vector_store)
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
app = create_app()
|
app = create_app()
|
||||||
|
|
||||||
|
|
||||||
def run_api(host, port, **kwargs):
|
def run_api(host, port, **kwargs):
|
||||||
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
|
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
|
||||||
uvicorn.run(app,
|
uvicorn.run(app,
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,4 @@
|
||||||
from .kb_api import list_kbs, create_kb, delete_kb
|
from .kb_api import list_kbs, create_kb, delete_kb
|
||||||
from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc
|
from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store
|
||||||
|
from .knowledge_base import KnowledgeBase
|
||||||
|
from .knowledge_file import KnowledgeFile
|
||||||
|
|
|
||||||
|
|
@ -1,38 +1,26 @@
|
||||||
import os
|
|
||||||
import urllib
|
import urllib
|
||||||
import shutil
|
|
||||||
from configs.model_config import KB_ROOT_PATH
|
|
||||||
from server.utils import BaseResponse, ListResponse
|
from server.utils import BaseResponse, ListResponse
|
||||||
from server.knowledge_base.utils import validate_kb_name, get_kb_path, get_vs_path
|
from server.knowledge_base.utils import validate_kb_name
|
||||||
|
from server.knowledge_base.knowledge_base import KnowledgeBase
|
||||||
|
|
||||||
|
|
||||||
async def list_kbs():
|
async def list_kbs():
|
||||||
# Get List of Knowledge Base
|
# Get List of Knowledge Base
|
||||||
if not os.path.exists(KB_ROOT_PATH):
|
return ListResponse(data=KnowledgeBase.list_kbs())
|
||||||
all_doc_ids = []
|
|
||||||
else:
|
|
||||||
all_doc_ids = [
|
|
||||||
folder
|
|
||||||
for folder in os.listdir(KB_ROOT_PATH)
|
|
||||||
if os.path.isdir(os.path.join(KB_ROOT_PATH, folder))
|
|
||||||
and os.path.exists(os.path.join(KB_ROOT_PATH, folder, "vector_store", "index.faiss"))
|
|
||||||
]
|
|
||||||
|
|
||||||
return ListResponse(data=all_doc_ids)
|
|
||||||
|
|
||||||
|
|
||||||
async def create_kb(knowledge_base_name: str):
|
async def create_kb(knowledge_base_name: str,
|
||||||
|
vector_store_type: str = "faiss"):
|
||||||
# Create selected knowledge base
|
# Create selected knowledge base
|
||||||
if not validate_kb_name(knowledge_base_name):
|
if not validate_kb_name(knowledge_base_name):
|
||||||
return BaseResponse(code=403, msg="Don't attack me")
|
return BaseResponse(code=403, msg="Don't attack me")
|
||||||
if knowledge_base_name is None or knowledge_base_name.strip() == "":
|
if knowledge_base_name is None or knowledge_base_name.strip() == "":
|
||||||
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
|
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
|
||||||
if os.path.exists(get_kb_path(knowledge_base_name)):
|
if KnowledgeBase.exists(knowledge_base_name):
|
||||||
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
||||||
if not os.path.exists(os.path.join(KB_ROOT_PATH, knowledge_base_name, "content")):
|
kb = KnowledgeBase(knowledge_base_name=knowledge_base_name,
|
||||||
os.makedirs(os.path.join(KB_ROOT_PATH, knowledge_base_name, "content"))
|
vector_store_type=vector_store_type)
|
||||||
if not os.path.exists(os.path.join(KB_ROOT_PATH, knowledge_base_name, "vector_store")):
|
kb.create()
|
||||||
os.makedirs(get_vs_path(knowledge_base_name))
|
|
||||||
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -41,8 +29,12 @@ async def delete_kb(knowledge_base_name: str):
|
||||||
if not validate_kb_name(knowledge_base_name):
|
if not validate_kb_name(knowledge_base_name):
|
||||||
return BaseResponse(code=403, msg="Don't attack me")
|
return BaseResponse(code=403, msg="Don't attack me")
|
||||||
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
|
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
|
||||||
kb_path = get_kb_path(knowledge_base_name)
|
|
||||||
if not os.path.exists(kb_path):
|
if not KnowledgeBase.exists(knowledge_base_name):
|
||||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||||
shutil.rmtree(kb_path)
|
|
||||||
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
|
status = KnowledgeBase.delete(knowledge_base_name)
|
||||||
|
if status:
|
||||||
|
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
|
||||||
|
else:
|
||||||
|
return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,83 @@
|
||||||
from server.knowledge_base.utils import (get_vs_path, get_kb_path, get_doc_path)
|
from server.knowledge_base.utils import (get_vs_path, get_kb_path, get_doc_path)
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
from configs.model_config import KB_ROOT_PATH
|
||||||
|
import datetime
|
||||||
|
import shutil
|
||||||
|
|
||||||
SUPPORTED_VS_TYPES = ["faiss", "milvus"]
|
SUPPORTED_VS_TYPES = ["faiss", "milvus"]
|
||||||
|
DB_ROOT = os.path.join(KB_ROOT_PATH, "info.db")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: 知识库信息入库
|
||||||
|
|
||||||
|
def list_kbs_from_db():
|
||||||
|
conn = sqlite3.connect(DB_ROOT)
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute(f'''SELECT KB_NAME
|
||||||
|
FROM KNOWLEDGE_BASE
|
||||||
|
WHERE FILE_COUNT>0 ''')
|
||||||
|
kbs = [i[0] for i in c.fetchall() if i]
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
return kbs
|
||||||
|
|
||||||
|
def add_kb_to_db(kb_name, vs_type):
|
||||||
|
conn = sqlite3.connect(DB_ROOT)
|
||||||
|
c = conn.cursor()
|
||||||
|
# Create table
|
||||||
|
c.execute('''CREATE TABLE if not exists KNOWLEDGE_BASE
|
||||||
|
(ID INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
KB_NAME TEXT,
|
||||||
|
VS_TYPE TEXT,
|
||||||
|
FILE_COUNT INTEGER,
|
||||||
|
CREATE_TIME DATETIME) ''')
|
||||||
|
# Insert a row of data
|
||||||
|
c.execute(f"""INSERT INTO KNOWLEDGE_BASE
|
||||||
|
(KB_NAME, VS_TYPE, FILE_COUNT, CREATE_TIME)
|
||||||
|
VALUES
|
||||||
|
('{kb_name}','{vs_type}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def kb_exists(kb_name):
|
||||||
|
conn = sqlite3.connect(DB_ROOT)
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute(f'''SELECT COUNT(*)
|
||||||
|
FROM KNOWLEDGE_BASE
|
||||||
|
WHERE KB_NAME="{kb_name}" ''')
|
||||||
|
status = True if c.fetchone()[0] else False
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
return status
|
||||||
|
|
||||||
|
|
||||||
|
def load_kb_from_db(kb_name):
|
||||||
|
conn = sqlite3.connect(DB_ROOT)
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute(f'''SELECT KB_NAME, VS_TYPE
|
||||||
|
FROM KNOWLEDGE_BASE
|
||||||
|
WHERE KB_NAME="{kb_name}" ''')
|
||||||
|
resp = c.fetchone()
|
||||||
|
if resp:
|
||||||
|
kb_name, vs_type = resp
|
||||||
|
else:
|
||||||
|
kb_name, vs_type = None, None
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
return kb_name, vs_type
|
||||||
|
|
||||||
|
|
||||||
|
def delete_kb_from_db(kb_name):
|
||||||
|
conn = sqlite3.connect(DB_ROOT)
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute(f'''DELETE
|
||||||
|
FROM KNOWLEDGE_BASE
|
||||||
|
WHERE KB_NAME="{kb_name}" ''')
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeBase:
|
class KnowledgeBase:
|
||||||
|
|
@ -15,4 +92,52 @@ class KnowledgeBase:
|
||||||
self.kb_path = get_kb_path(self.kb_name)
|
self.kb_path = get_kb_path(self.kb_name)
|
||||||
self.doc_path = get_doc_path(self.kb_name)
|
self.doc_path = get_doc_path(self.kb_name)
|
||||||
if self.vs_type in ["faiss"]:
|
if self.vs_type in ["faiss"]:
|
||||||
self.vs_path = get_vs_path(self.kb_name)
|
self.vs_path = get_vs_path(self.kb_name)
|
||||||
|
elif self.vs_type in ["milvus"]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def create(self):
|
||||||
|
if not os.path.exists(self.doc_path):
|
||||||
|
os.makedirs(self.doc_path)
|
||||||
|
if self.vs_type in ["faiss"]:
|
||||||
|
if not os.path.exists(self.vs_path):
|
||||||
|
os.makedirs(self.vs_path)
|
||||||
|
add_kb_to_db(self.kb_name, self.vs_type)
|
||||||
|
elif self.vs_type in ["milvus"]:
|
||||||
|
# TODO: 创建milvus库
|
||||||
|
pass
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def exists(cls,
|
||||||
|
knowledge_base_name: str):
|
||||||
|
return kb_exists(knowledge_base_name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls,
|
||||||
|
knowledge_base_name: str):
|
||||||
|
kb_name, vs_type = load_kb_from_db(knowledge_base_name)
|
||||||
|
return cls(kb_name, vs_type)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def delete(cls,
|
||||||
|
knowledge_base_name: str):
|
||||||
|
kb = cls.load(knowledge_base_name)
|
||||||
|
if kb.vs_type in ["faiss"]:
|
||||||
|
shutil.rmtree(kb.kb_path)
|
||||||
|
elif kb.vs_type in ["milvus"]:
|
||||||
|
# TODO: 删除milvus库
|
||||||
|
pass
|
||||||
|
status = delete_kb_from_db(knowledge_base_name)
|
||||||
|
return status
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_kbs(cls):
|
||||||
|
return list_kbs_from_db()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# kb = KnowledgeBase("123", "faiss")
|
||||||
|
# kb.create()
|
||||||
|
kb = KnowledgeBase.load(knowledge_base_name="123")
|
||||||
|
print()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue