1. add add_doc and list_docs to KnowledgeBase class

2. add DB_ROOT_PATH to model_config.py.example
This commit is contained in:
imClumsyPanda 2023-08-05 22:57:19 +08:00
parent 313e590961
commit 3f045cedb9
4 changed files with 124 additions and 34 deletions

View File

@ -239,6 +239,9 @@ if not os.path.exists(LOG_PATH):
# 知识库默认存储路径
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
# 数据库默认存储路径
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
# 缓存向量库数量
CACHED_VS_NUM = 1

View File

@ -7,7 +7,8 @@ from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_
get_file_path, refresh_vs_cache, get_vs_path)
from fastapi.responses import StreamingResponse
import json
from server.knowledge_base import KnowledgeFile, KnowledgeBase
from server.knowledge_base.knowledge_file import KnowledgeFile
from server.knowledge_base.knowledge_base import KnowledgeBase
async def list_docs(knowledge_base_name: str):
@ -16,17 +17,10 @@ async def list_docs(knowledge_base_name: str):
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
kb_path = get_kb_path(knowledge_base_name)
local_doc_folder = get_doc_path(knowledge_base_name)
if not os.path.exists(kb_path):
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
if not os.path.exists(local_doc_folder):
all_doc_names = []
else:
all_doc_names = [
doc
for doc in os.listdir(local_doc_folder)
if os.path.isfile(os.path.join(local_doc_folder, doc))
]
all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs()
return ListResponse(data=all_doc_names)
@ -60,7 +54,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
kb_file = KnowledgeFile(filename=file.filename,
knowledge_base_name=knowledge_base_name)
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
kb.add_file(kb_file.file2text())
kb.add_file(kb_file)
return BaseResponse(code=200, msg=f"成功上传文件 {file.filename}")
@ -121,7 +115,7 @@ async def recreate_vector_store(knowledge_base_name: str):
knowledge_base_name=kb_name)
print(f"processing {kb_file.filepath} to vector store.")
kb = KnowledgeBase.load(knowledge_base_name=kb_name)
kb.add_file(kb_file.file2text())
kb.add_file(kb_file)
yield json.dumps({
"total": len(docs),
"finished": i + 1,

View File

@ -5,23 +5,27 @@ import shutil
from langchain.vectorstores import FAISS
from server.knowledge_base.utils import (get_vs_path, get_kb_path, get_doc_path,
refresh_vs_cache, load_embeddings)
from configs.model_config import (KB_ROOT_PATH, embedding_model_dict, EMBEDDING_MODEL,
EMBEDDING_DEVICE)
from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL,
EMBEDDING_DEVICE, DB_ROOT_PATH)
from server.utils import torch_gc
from typing import List
from langchain.docstore.document import Document
from server.knowledge_base.knowledge_file import KnowledgeFile
SUPPORTED_VS_TYPES = ["faiss", "milvus"]
DB_ROOT = os.path.join(KB_ROOT_PATH, "info.db")
def list_kbs_from_db():
conn = sqlite3.connect(DB_ROOT)
conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor()
c.execute(f'''SELECT KB_NAME
FROM KNOWLEDGE_BASE
WHERE FILE_COUNT>0 ''')
c.execute('''CREATE TABLE if not exists knowledge_base
(id INTEGER PRIMARY KEY AUTOINCREMENT,
kb_name TEXT,
vs_type TEXT,
embed_model TEXT,
file_count INTEGER,
create_time DATETIME) ''')
c.execute(f'''SELECT kb_name
FROM knowledge_base
WHERE file_count>0 ''')
kbs = [i[0] for i in c.fetchall() if i]
conn.commit()
conn.close()
@ -29,7 +33,7 @@ def list_kbs_from_db():
def add_kb_to_db(kb_name, vs_type, embed_model):
conn = sqlite3.connect(DB_ROOT)
conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor()
# Create table
c.execute('''CREATE TABLE if not exists knowledge_base
@ -50,8 +54,15 @@ def add_kb_to_db(kb_name, vs_type, embed_model):
def kb_exists(kb_name):
conn = sqlite3.connect(DB_ROOT)
conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor()
c.execute('''CREATE TABLE if not exists knowledge_base
(id INTEGER PRIMARY KEY AUTOINCREMENT,
kb_name TEXT,
vs_type TEXT,
embed_model TEXT,
file_count INTEGER,
create_time DATETIME) ''')
c.execute(f'''SELECT COUNT(*)
FROM knowledge_base
WHERE kb_name="{kb_name}" ''')
@ -62,8 +73,15 @@ def kb_exists(kb_name):
def load_kb_from_db(kb_name):
conn = sqlite3.connect(DB_ROOT)
conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor()
c.execute('''CREATE TABLE if not exists knowledge_base
(id INTEGER PRIMARY KEY AUTOINCREMENT,
kb_name TEXT,
vs_type TEXT,
embed_model TEXT,
file_count INTEGER,
create_time DATETIME) ''')
c.execute(f'''SELECT kb_name, vs_type, embed_model
FROM knowledge_base
WHERE kb_name="{kb_name}" ''')
@ -78,16 +96,82 @@ def load_kb_from_db(kb_name):
def delete_kb_from_db(kb_name):
conn = sqlite3.connect(DB_ROOT)
conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor()
# delete kb from table knowledge_base
c.execute('''CREATE TABLE if not exists knowledge_base
(id INTEGER PRIMARY KEY AUTOINCREMENT,
kb_name TEXT,
vs_type TEXT,
embed_model TEXT,
file_count INTEGER,
create_time DATETIME) ''')
c.execute(f'''DELETE
FROM knowledge_base
WHERE kb_name="{kb_name}" ''')
# delete files in kb from table knowledge_files
c.execute('''CREATE TABLE if not exists knowledge_files
(id INTEGER PRIMARY KEY AUTOINCREMENT,
file_name TEXT,
file_ext TEXT,
kb_name TEXT,
document_loader_name TEXT,
text_splitter_name TEXT,
file_version INTEGER,
create_time DATETIME) ''')
# Insert a row of data
c.execute(f"""DELETE
FROM knowledge_files
WHERE kb_name="{kb_name}"
""")
conn.commit()
conn.close()
return True
def list_docs_from_db(kb_name):
    """Return the file names recorded for knowledge base *kb_name*.

    Creates the ``knowledge_files`` table on first use so the SELECT
    cannot fail against a fresh database file.

    :param kb_name: name of the knowledge base to list documents for
    :return: list of file-name strings (empty if none are registered)
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        c = conn.cursor()
        c.execute('''CREATE TABLE if not exists knowledge_files
                        (id INTEGER PRIMARY KEY AUTOINCREMENT,
                        file_name TEXT,
                        file_ext TEXT,
                        kb_name TEXT,
                        document_loader_name TEXT,
                        text_splitter_name TEXT,
                        file_version INTEGER,
                        create_time DATETIME) ''')
        # Parameterized query: interpolating kb_name into the SQL text with an
        # f-string (the previous form) is an SQL-injection vector.
        c.execute('''SELECT file_name
                     FROM knowledge_files
                     WHERE kb_name = ?''', (kb_name,))
        docs = [row[0] for row in c.fetchall() if row]
        conn.commit()  # persist the CREATE TABLE on first use
    finally:
        # Close the connection even if the query raises.
        conn.close()
    return docs
def add_file_to_db(kb_file: "KnowledgeFile"):
    """Insert a registration row for *kb_file* into ``knowledge_files``.

    Creates the table on first use. The row records the file name,
    extension, owning knowledge base, loader/splitter names, an initial
    file_version of 0, and the current local timestamp.

    :param kb_file: a KnowledgeFile with ``filename``, ``ext``, ``kb_name``,
        ``document_loader_name`` and ``text_splitter_name`` attributes
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        c = conn.cursor()
        c.execute('''CREATE TABLE if not exists knowledge_files
                        (id INTEGER PRIMARY KEY AUTOINCREMENT,
                        file_name TEXT,
                        file_ext TEXT,
                        kb_name TEXT,
                        document_loader_name TEXT,
                        text_splitter_name TEXT,
                        file_version INTEGER,
                        create_time DATETIME) ''')
        # Parameterized INSERT: the previous f-string interpolation of file
        # names into the SQL text was an SQL-injection vector (and broke on
        # filenames containing quotes).
        c.execute("""INSERT INTO knowledge_files
                     (file_name, file_ext, kb_name, document_loader_name,
                      text_splitter_name, file_version, create_time)
                     VALUES (?, ?, ?, ?, ?, 0, ?)""",
                  (kb_file.filename,
                   kb_file.ext,
                   kb_file.kb_name,
                   kb_file.document_loader_name,
                   kb_file.text_splitter_name,
                   datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
        conn.commit()
    finally:
        # Close the connection even if the insert raises.
        conn.close()
class KnowledgeBase:
def __init__(self,
knowledge_base_name: str,
@ -120,9 +204,10 @@ class KnowledgeBase:
pass
return True
def add_file(self, docs: List[Document]):
def add_doc(self, kb_file: KnowledgeFile):
docs = kb_file.file2text()
vs_path = get_vs_path(self.kb_name)
embeddings = load_embeddings(embedding_model_dict[self.embed_model], EMBEDDING_DEVICE)
embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
if self.vs_type in ["faiss"]:
if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
vector_store = FAISS.load_local(vs_path, embeddings)
@ -134,11 +219,15 @@ class KnowledgeBase:
vector_store = FAISS.from_documents(docs, embeddings) # docs 为Document列表
torch_gc()
vector_store.save_local(vs_path)
add_file_to_db(kb_file)
refresh_vs_cache(self.kb_name)
elif self.vs_type in ["milvus"]:
# TODO: 向milvus库中增加文件
pass
def list_docs(self):
    """Return the file names registered in the DB for this knowledge base."""
    return list_docs_from_db(self.kb_name)
@classmethod
def exists(cls,
knowledge_base_name: str):

View File

@ -1,8 +1,8 @@
import os.path
from server.knowledge_base.utils import (get_file_path)
from server.knowledge_base import KnowledgeBase
import sys
LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst',
'.rtf', '.txt', '.xml',
'.doc', '.docx', '.epub', '.odt', '.pdf',
@ -23,19 +23,23 @@ class KnowledgeFile:
filename: str,
knowledge_base_name: str
):
self.kb = KnowledgeBase.load(knowledge_base_name)
self.kb_name = knowledge_base_name
self.filename = filename
self.ext = os.path.splitext(filename)[-1]
if self.ext not in SUPPORTED_EXTS:
raise ValueError(f"暂未支持的文件格式 {self.ext}")
self.filepath = get_file_path(knowledge_base_name, filename)
self.docs = None
self.loader_class_name = get_LoaderClass(self.ext)
self.document_loader_name = get_LoaderClass(self.ext)
# TODO: 增加依据文件格式匹配text_splitter
self.text_splitter_name = "CharacterTextSplitter"
def file2text(self):
LoaderClass = getattr(sys.modules['langchain.document_loaders'], self.loader_class_name)
loader = LoaderClass(self.filepath)
DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name)
loader = DocumentLoader(self.filepath)
from langchain.text_splitter import CharacterTextSplitter
text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200)
# TODO: 增加依据文件格式匹配text_splitter
TextSplitter = getattr(sys.modules['langchain.text_splitter'], self.text_splitter_name)
text_splitter = TextSplitter(chunk_size=500, chunk_overlap=200)
return loader.load_and_split(text_splitter)