From 206261cd0c6893eee3bdf8f86a889a777086d472 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Sat, 5 Aug 2023 13:46:00 +0800 Subject: [PATCH] update class method of KnowledgeBase and KnowledgeFile --- server/knowledge_base/kb_api.py | 7 ++- server/knowledge_base/kb_doc_api.py | 20 ++++----- server/knowledge_base/knowledge_base.py | 57 +++++++++++++++++++------ server/knowledge_base/knowledge_file.py | 11 ++--- 4 files changed, 66 insertions(+), 29 deletions(-) diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py index 3eda6c8..f035838 100644 --- a/server/knowledge_base/kb_api.py +++ b/server/knowledge_base/kb_api.py @@ -10,7 +10,9 @@ async def list_kbs(): async def create_kb(knowledge_base_name: str, - vector_store_type: str = "faiss"): + vector_store_type: str = "faiss", + embed_model: str = "m3e-base", + ): # Create selected knowledge base if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -19,7 +21,8 @@ async def create_kb(knowledge_base_name: str, if KnowledgeBase.exists(knowledge_base_name): return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") kb = KnowledgeBase(knowledge_base_name=knowledge_base_name, - vector_store_type=vector_store_type) + vector_store_type=vector_store_type, + embed_model=embed_model) kb.create() return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 65957b1..20407e0 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -7,7 +7,7 @@ from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_ get_file_path, refresh_vs_cache, get_vs_path) from fastapi.responses import StreamingResponse import json -from server.knowledge_base.knowledge_file import KnowledgeFile +from server.knowledge_base import KnowledgeFile, KnowledgeBase async def list_docs(knowledge_base_name: str): @@ -59,8 +59,8 @@ async def upload_doc(file: UploadFile = File(description="上传文件"), kb_file = KnowledgeFile(filename=file.filename, knowledge_base_name=knowledge_base_name) - kb_file.file2text() - kb_file.docs2vs() + kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name) + kb.add_file(kb_file.file2text()) return BaseResponse(code=200, msg=f"成功上传文件 {file.filename}") @@ -108,20 +108,20 @@ async def recreate_vector_store(knowledge_base_name: str): recreate vector store from the content. this is usefull when user can copy files to content folder directly instead of upload through network. ''' - async def output(kb): - vs_path = get_vs_path(kb) + async def output(kb_name): + vs_path = get_vs_path(kb_name) if os.path.isdir(vs_path): shutil.rmtree(vs_path) os.mkdir(vs_path) print(f"start to recreate vectore in {vs_path}") - docs = (await list_docs(kb)).data + docs = (await list_docs(kb_name)).data for i, filename in enumerate(docs): kb_file = KnowledgeFile(filename=filename, - knowledge_base_name=kb) - print(f"processing {get_file_path(kb, filename)} to vector store.") - kb_file.file2text() - kb_file.docs2vs() + knowledge_base_name=kb_name) + print(f"processing {kb_file.filepath} to vector store.") + kb = KnowledgeBase.load(knowledge_base_name=kb_name) + kb.add_file(kb_file.file2text()) yield json.dumps({ "total": len(docs), "finished": i + 1, diff --git a/server/knowledge_base/knowledge_base.py b/server/knowledge_base/knowledge_base.py index 5f49c8d..ad0aa6a 100644 --- a/server/knowledge_base/knowledge_base.py +++ b/server/knowledge_base/knowledge_base.py @@ -1,9 +1,16 @@ -from server.knowledge_base.utils import (get_vs_path, get_kb_path, get_doc_path) import os import sqlite3 -from configs.model_config import KB_ROOT_PATH import datetime import shutil +from langchain.vectorstores import FAISS +from server.knowledge_base.utils import (get_vs_path, get_kb_path, get_doc_path, + refresh_vs_cache, load_embeddings) +from configs.model_config import (KB_ROOT_PATH, embedding_model_dict, EMBEDDING_MODEL, + EMBEDDING_DEVICE) +from server.utils import torch_gc +from typing import List +from langchain.docstore.document import Document + SUPPORTED_VS_TYPES = ["faiss", "milvus"] DB_ROOT = os.path.join(KB_ROOT_PATH, "info.db") @@ -22,7 +29,8 @@ def list_kbs_from_db(): conn.close() return kbs -def add_kb_to_db(kb_name, vs_type): + +def add_kb_to_db(kb_name, vs_type, embed_model): conn = sqlite3.connect(DB_ROOT) c = conn.cursor() # Create table @@ -30,13 +38,15 @@ def add_kb_to_db(kb_name, vs_type): (ID INTEGER PRIMARY KEY AUTOINCREMENT, KB_NAME TEXT, VS_TYPE TEXT, + EMBED_MODEL TEXT, FILE_COUNT INTEGER, CREATE_TIME DATETIME) ''') # Insert a row of data c.execute(f"""INSERT INTO KNOWLEDGE_BASE - (KB_NAME, VS_TYPE, FILE_COUNT, CREATE_TIME) + (KB_NAME, VS_TYPE, EMBED_MODEL, FILE_COUNT, CREATE_TIME) VALUES - ('{kb_name}','{vs_type}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""") + ('{kb_name}','{vs_type}','{embed_model}', + 0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""") conn.commit() conn.close() @@ -56,17 +66,17 @@ def kb_exists(kb_name): def load_kb_from_db(kb_name): conn = sqlite3.connect(DB_ROOT) c = conn.cursor() - c.execute(f'''SELECT KB_NAME, VS_TYPE + c.execute(f'''SELECT KB_NAME, VS_TYPE, EMBED_MODEL FROM KNOWLEDGE_BASE WHERE KB_NAME="{kb_name}" ''') resp = c.fetchone() if resp: - kb_name, vs_type = resp + kb_name, vs_type, embed_model = resp else: - kb_name, vs_type = None, None + kb_name, vs_type, embed_model = None, None, None conn.commit() conn.close() - return kb_name, vs_type + return kb_name, vs_type, embed_model def delete_kb_from_db(kb_name): @@ -84,11 +94,15 @@ class KnowledgeBase: def __init__(self, knowledge_base_name: str, vector_store_type: str, + embed_model: str, ): self.kb_name = knowledge_base_name if vector_store_type not in SUPPORTED_VS_TYPES: raise ValueError(f"暂未支持向量库类型 {vector_store_type}") self.vs_type = vector_store_type + if embed_model not in embedding_model_dict.keys(): + raise ValueError(f"暂未支持embedding模型 {embed_model}") + self.embed_model = embed_model self.kb_path = get_kb_path(self.kb_name) self.doc_path = get_doc_path(self.kb_name) if self.vs_type in ["faiss"]: @@ -102,12 +116,31 @@ class KnowledgeBase: if self.vs_type in ["faiss"]: if not os.path.exists(self.vs_path): os.makedirs(self.vs_path) - add_kb_to_db(self.kb_name, self.vs_type) + add_kb_to_db(self.kb_name, self.vs_type, self.embed_model) elif self.vs_type in ["milvus"]: # TODO: 创建milvus库 pass return True + def add_file(self, docs: List[Document]): + vs_path = get_vs_path(self.kb_name) + embeddings = load_embeddings(embedding_model_dict[self.embed_model], EMBEDDING_DEVICE) + if self.vs_type in ["faiss"]: + if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path): + vector_store = FAISS.load_local(vs_path, embeddings) + vector_store.add_documents(docs) + torch_gc() + else: + if not os.path.exists(vs_path): + os.makedirs(vs_path) + vector_store = FAISS.from_documents(docs, embeddings) # docs 为Document列表 + torch_gc() + vector_store.save_local(vs_path) + refresh_vs_cache(self.kb_name) + elif self.vs_type in ["milvus"]: + # TODO: 向milvus库中增加文件 + pass + @classmethod def exists(cls, knowledge_base_name: str): @@ -116,8 +149,8 @@ class KnowledgeBase: @classmethod def load(cls, knowledge_base_name: str): - kb_name, vs_type = load_kb_from_db(knowledge_base_name) - return cls(kb_name, vs_type) + kb_name, vs_type, embed_model = load_kb_from_db(knowledge_base_name) + return cls(kb_name, vs_type, embed_model) @classmethod def delete(cls, diff --git a/server/knowledge_base/knowledge_file.py b/server/knowledge_base/knowledge_file.py index 52a5699..6248a16 100644 --- a/server/knowledge_base/knowledge_file.py +++ b/server/knowledge_base/knowledge_file.py @@ -4,6 +4,7 @@ from server.knowledge_base.utils import (get_file_path, get_vs_path, from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE) from langchain.vectorstores import FAISS from server.utils import torch_gc +from server.knowledge_base import KnowledgeBase class KnowledgeFile: @@ -12,7 +13,7 @@ class KnowledgeFile: filename: str, knowledge_base_name: str ): - self.knowledge_base_name = knowledge_base_name + self.kb = KnowledgeBase.load(knowledge_base_name) self.knowledge_base_type = "faiss" self.filename = filename self.ext = os.path.splitext(filename)[-1] @@ -28,12 +29,12 @@ class KnowledgeFile: from langchain.text_splitter import CharacterTextSplitter text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200) - self.docs = loader.load_and_split(text_splitter) - return True + return loader.load_and_split(text_splitter) def docs2vs(self): - vs_path = get_vs_path(self.knowledge_base_name) + vs_path = get_vs_path(self.kb.kb_name) embeddings = load_embeddings(embedding_model_dict[EMBEDDING_MODEL], EMBEDDING_DEVICE) + if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path): vector_store = FAISS.load_local(vs_path, embeddings) vector_store.add_documents(self.docs) @@ -44,5 +45,5 @@ class KnowledgeFile: vector_store = FAISS.from_documents(self.docs, embeddings) # docs 为Document列表 torch_gc() vector_store.save_local(vs_path) - refresh_vs_cache(self.knowledge_base_name) + refresh_vs_cache(self.kb.kb_name) return True \ No newline at end of file