From 3f045cedb93c512b2695b6dde02912e4db924d11 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Sat, 5 Aug 2023 22:57:19 +0800 Subject: [PATCH] 1. add add_doc and list_docs to KnowledgeBase class 2. add DB_ROOT_PATH to model_config.py.example --- configs/model_config.py.example | 3 + server/knowledge_base/kb_doc_api.py | 16 +--- server/knowledge_base/knowledge_base.py | 121 ++++++++++++++++++++---- server/knowledge_base/knowledge_file.py | 18 ++-- 4 files changed, 124 insertions(+), 34 deletions(-) diff --git a/configs/model_config.py.example b/configs/model_config.py.example index 4a636cc..6eb522d 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -239,6 +239,9 @@ if not os.path.exists(LOG_PATH): # 知识库默认存储路径 KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base") +# 数据库默认存储路径 +DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db") + # 缓存向量库数量 CACHED_VS_NUM = 1 diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 20407e0..71da946 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -7,7 +7,8 @@ from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_ get_file_path, refresh_vs_cache, get_vs_path) from fastapi.responses import StreamingResponse import json -from server.knowledge_base import KnowledgeFile, KnowledgeBase +from server.knowledge_base.knowledge_file import KnowledgeFile +from server.knowledge_base.knowledge_base import KnowledgeBase async def list_docs(knowledge_base_name: str): @@ -16,17 +17,10 @@ async def list_docs(knowledge_base_name: str): knowledge_base_name = urllib.parse.unquote(knowledge_base_name) kb_path = get_kb_path(knowledge_base_name) - local_doc_folder = get_doc_path(knowledge_base_name) if not os.path.exists(kb_path): return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[]) - if not os.path.exists(local_doc_folder): - all_doc_names = [] else: - 
all_doc_names = [
-            doc
-            for doc in os.listdir(local_doc_folder)
-            if os.path.isfile(os.path.join(local_doc_folder, doc))
-        ]
+        all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs()
     return ListResponse(data=all_doc_names)
 
 
@@ -60,7 +54,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
 
     kb_file = KnowledgeFile(filename=file.filename,
                             knowledge_base_name=knowledge_base_name)
     kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
-    kb.add_file(kb_file.file2text())
+    kb.add_doc(kb_file)
     return BaseResponse(code=200, msg=f"成功上传文件 {file.filename}")
 
 
@@ -121,7 +115,7 @@ async def recreate_vector_store(knowledge_base_name: str):
                                    knowledge_base_name=kb_name)
             print(f"processing {kb_file.filepath} to vector store.")
             kb = KnowledgeBase.load(knowledge_base_name=kb_name)
-            kb.add_file(kb_file.file2text())
+            kb.add_doc(kb_file)
             yield json.dumps({
                 "total": len(docs),
                 "finished": i + 1,
diff --git a/server/knowledge_base/knowledge_base.py b/server/knowledge_base/knowledge_base.py
index 224f1ed..f9dc05e 100644
--- a/server/knowledge_base/knowledge_base.py
+++ b/server/knowledge_base/knowledge_base.py
@@ -5,23 +5,27 @@ import shutil
 from langchain.vectorstores import FAISS
 from server.knowledge_base.utils import (get_vs_path, get_kb_path, get_doc_path,
                                          refresh_vs_cache, load_embeddings)
-from configs.model_config import (KB_ROOT_PATH, embedding_model_dict, EMBEDDING_MODEL,
-                                  EMBEDDING_DEVICE)
+from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL,
+                                  EMBEDDING_DEVICE, DB_ROOT_PATH)
 from server.utils import torch_gc
-from typing import List
-from langchain.docstore.document import Document
-
+from server.knowledge_base.knowledge_file import KnowledgeFile
 
 SUPPORTED_VS_TYPES = ["faiss", "milvus"]
-DB_ROOT = os.path.join(KB_ROOT_PATH, "info.db")
 
 
 def list_kbs_from_db():
-    conn = sqlite3.connect(DB_ROOT)
+    conn = sqlite3.connect(DB_ROOT_PATH)
     c = conn.cursor()
-    c.execute(f'''SELECT KB_NAME
-                  FROM 
KNOWLEDGE_BASE - WHERE FILE_COUNT>0 ''') + c.execute('''CREATE TABLE if not exists knowledge_base + (id INTEGER PRIMARY KEY AUTOINCREMENT, + kb_name TEXT, + vs_type TEXT, + embed_model TEXT, + file_count INTEGER, + create_time DATETIME) ''') + c.execute(f'''SELECT kb_name + FROM knowledge_base + WHERE file_count>0 ''') kbs = [i[0] for i in c.fetchall() if i] conn.commit() conn.close() @@ -29,7 +33,7 @@ def list_kbs_from_db(): def add_kb_to_db(kb_name, vs_type, embed_model): - conn = sqlite3.connect(DB_ROOT) + conn = sqlite3.connect(DB_ROOT_PATH) c = conn.cursor() # Create table c.execute('''CREATE TABLE if not exists knowledge_base @@ -50,8 +54,15 @@ def add_kb_to_db(kb_name, vs_type, embed_model): def kb_exists(kb_name): - conn = sqlite3.connect(DB_ROOT) + conn = sqlite3.connect(DB_ROOT_PATH) c = conn.cursor() + c.execute('''CREATE TABLE if not exists knowledge_base + (id INTEGER PRIMARY KEY AUTOINCREMENT, + kb_name TEXT, + vs_type TEXT, + embed_model TEXT, + file_count INTEGER, + create_time DATETIME) ''') c.execute(f'''SELECT COUNT(*) FROM knowledge_base WHERE kb_name="{kb_name}" ''') @@ -62,8 +73,15 @@ def kb_exists(kb_name): def load_kb_from_db(kb_name): - conn = sqlite3.connect(DB_ROOT) + conn = sqlite3.connect(DB_ROOT_PATH) c = conn.cursor() + c.execute('''CREATE TABLE if not exists knowledge_base + (id INTEGER PRIMARY KEY AUTOINCREMENT, + kb_name TEXT, + vs_type TEXT, + embed_model TEXT, + file_count INTEGER, + create_time DATETIME) ''') c.execute(f'''SELECT kb_name, vs_type, embed_model FROM knowledge_base WHERE kb_name="{kb_name}" ''') @@ -78,16 +96,82 @@ def load_kb_from_db(kb_name): def delete_kb_from_db(kb_name): - conn = sqlite3.connect(DB_ROOT) + conn = sqlite3.connect(DB_ROOT_PATH) c = conn.cursor() + # delete kb from table knowledge_base + c.execute('''CREATE TABLE if not exists knowledge_base + (id INTEGER PRIMARY KEY AUTOINCREMENT, + kb_name TEXT, + vs_type TEXT, + embed_model TEXT, + file_count INTEGER, + create_time DATETIME) ''') 
c.execute(f'''DELETE FROM knowledge_base WHERE kb_name="{kb_name}"  ''')
+    # delete files in kb from table knowledge_files
+    c.execute('''CREATE TABLE if not exists knowledge_files
+                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
+                     file_name TEXT,
+                     file_ext TEXT,
+                     kb_name TEXT,
+                     document_loader_name TEXT,
+                     text_splitter_name TEXT,
+                     file_version INTEGER,
+                     create_time DATETIME) ''')
+    # Delete rows of data
+    c.execute("""DELETE
+                  FROM knowledge_files
+                  WHERE kb_name=?
+                  """, (kb_name,))
     conn.commit()
     conn.close()
     return True
 
 
+def list_docs_from_db(kb_name):
+    conn = sqlite3.connect(DB_ROOT_PATH)
+    c = conn.cursor()
+    c.execute('''CREATE TABLE if not exists knowledge_files
+                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
+                     file_name TEXT,
+                     file_ext TEXT,
+                     kb_name TEXT,
+                     document_loader_name TEXT,
+                     text_splitter_name TEXT,
+                     file_version INTEGER,
+                     create_time DATETIME) ''')
+    c.execute('''SELECT file_name
+                  FROM knowledge_files
+                  WHERE kb_name=? ''', (kb_name,))
+    kbs = [i[0] for i in c.fetchall() if i]
+    conn.commit()
+    conn.close()
+    return kbs
+
+
+def add_file_to_db(kb_file: KnowledgeFile):
+    conn = sqlite3.connect(DB_ROOT_PATH)
+    c = conn.cursor()
+    # Create table
+    c.execute('''CREATE TABLE if not exists knowledge_files
+                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
+                     file_name TEXT,
+                     file_ext TEXT,
+                     kb_name TEXT,
+                     document_loader_name TEXT,
+                     text_splitter_name TEXT,
+                     file_version INTEGER,
+                     create_time DATETIME) ''')
+    # Insert a row of data
+    c.execute("""INSERT INTO knowledge_files
+                  (file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time)
+                  VALUES
+                  (?, ?, ?, ?, ?, 0, ?)""",
+              (kb_file.filename, kb_file.ext, kb_file.kb_name, kb_file.document_loader_name, kb_file.text_splitter_name, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
+    conn.commit()
+    conn.close()
+
+
 class KnowledgeBase:
     def __init__(self,
                  knowledge_base_name: str,
@@ -120,9 +204,10 @@ class KnowledgeBase:
             pass
         return True
 
-    def add_file(self, docs: 
List[Document]): + def add_doc(self, kb_file: KnowledgeFile): + docs = kb_file.file2text() vs_path = get_vs_path(self.kb_name) - embeddings = load_embeddings(embedding_model_dict[self.embed_model], EMBEDDING_DEVICE) + embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE) if self.vs_type in ["faiss"]: if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path): vector_store = FAISS.load_local(vs_path, embeddings) @@ -134,11 +219,15 @@ class KnowledgeBase: vector_store = FAISS.from_documents(docs, embeddings) # docs 为Document列表 torch_gc() vector_store.save_local(vs_path) + add_file_to_db(kb_file) refresh_vs_cache(self.kb_name) elif self.vs_type in ["milvus"]: # TODO: 向milvus库中增加文件 pass + def list_docs(self): + return list_docs_from_db(self.kb_name) + @classmethod def exists(cls, knowledge_base_name: str): diff --git a/server/knowledge_base/knowledge_file.py b/server/knowledge_base/knowledge_file.py index 0ce66a4..a342d7d 100644 --- a/server/knowledge_base/knowledge_file.py +++ b/server/knowledge_base/knowledge_file.py @@ -1,8 +1,8 @@ import os.path from server.knowledge_base.utils import (get_file_path) -from server.knowledge_base import KnowledgeBase import sys + LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst', '.rtf', '.txt', '.xml', '.doc', '.docx', '.epub', '.odt', '.pdf', @@ -23,19 +23,23 @@ class KnowledgeFile: filename: str, knowledge_base_name: str ): - self.kb = KnowledgeBase.load(knowledge_base_name) + self.kb_name = knowledge_base_name self.filename = filename self.ext = os.path.splitext(filename)[-1] if self.ext not in SUPPORTED_EXTS: raise ValueError(f"暂未支持的文件格式 {self.ext}") self.filepath = get_file_path(knowledge_base_name, filename) self.docs = None - self.loader_class_name = get_LoaderClass(self.ext) + self.document_loader_name = get_LoaderClass(self.ext) + + # TODO: 增加依据文件格式匹配text_splitter + self.text_splitter_name = "CharacterTextSplitter" def file2text(self): - LoaderClass = 
getattr(sys.modules['langchain.document_loaders'], self.loader_class_name) - loader = LoaderClass(self.filepath) + DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name) + loader = DocumentLoader(self.filepath) - from langchain.text_splitter import CharacterTextSplitter - text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200) + # TODO: 增加依据文件格式匹配text_splitter + TextSplitter = getattr(sys.modules['langchain.text_splitter'], self.text_splitter_name) + text_splitter = TextSplitter(chunk_size=500, chunk_overlap=200) return loader.load_and_split(text_splitter)