update class methods of KnowledgeBase and KnowledgeFile

imClumsyPanda 2023-08-05 13:46:00 +08:00
parent 6c7adfbaeb
commit 206261cd0c
4 changed files with 66 additions and 29 deletions

View File

@@ -10,7 +10,9 @@ async def list_kbs():
 async def create_kb(knowledge_base_name: str,
-                    vector_store_type: str = "faiss"):
+                    vector_store_type: str = "faiss",
+                    embed_model: str = "m3e-base",
+                    ):
     # Create selected knowledge base
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")
@@ -19,7 +21,8 @@ async def create_kb(knowledge_base_name: str,
     if KnowledgeBase.exists(knowledge_base_name):
         return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
     kb = KnowledgeBase(knowledge_base_name=knowledge_base_name,
-                       vector_store_type=vector_store_type)
+                       vector_store_type=vector_store_type,
+                       embed_model=embed_model)
     kb.create()
     return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
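The API change mirrors the new KnowledgeBase constructor: callers can now choose the embedding model per knowledge base. A minimal sketch of exercising the updated coroutine directly, bypassing FastAPI routing (the name "samples" is a placeholder, and the .code/.msg attributes are assumed from how BaseResponse is constructed above):

    import asyncio

    # Hypothetical direct call to the updated endpoint function; "samples" is a placeholder name.
    resp = asyncio.run(create_kb(knowledge_base_name="samples",
                                 vector_store_type="faiss",
                                 embed_model="m3e-base"))
    print(resp.code, resp.msg)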

View File

@@ -7,7 +7,7 @@ from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_
                                          get_file_path, refresh_vs_cache, get_vs_path)
 from fastapi.responses import StreamingResponse
 import json
-from server.knowledge_base.knowledge_file import KnowledgeFile
+from server.knowledge_base import KnowledgeFile, KnowledgeBase


 async def list_docs(knowledge_base_name: str):
@@ -59,8 +59,8 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
     kb_file = KnowledgeFile(filename=file.filename,
                             knowledge_base_name=knowledge_base_name)
-    kb_file.file2text()
-    kb_file.docs2vs()
+    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
+    kb.add_file(kb_file.file2text())
     return BaseResponse(code=200, msg=f"成功上传文件 {file.filename}")
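With this change the vector-store write moves out of KnowledgeFile: file2text() now returns the parsed documents and the loaded KnowledgeBase persists them. A minimal sketch of the same flow outside the endpoint (the file and knowledge base names are placeholders):

    # Placeholders: "test.txt" must already exist in the knowledge base's content folder.
    kb_file = KnowledgeFile(filename="test.txt",
                            knowledge_base_name="samples")
    docs = kb_file.file2text()                              # load and split into Document objects
    kb = KnowledgeBase.load(knowledge_base_name="samples")
    kb.add_file(docs)                                       # embed and write into the vector store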
@@ -108,20 +108,20 @@ async def recreate_vector_store(knowledge_base_name: str):
     recreate vector store from the content.
     this is usefull when user can copy files to content folder directly instead of upload through network.
     '''
-    async def output(kb):
-        vs_path = get_vs_path(kb)
+    async def output(kb_name):
+        vs_path = get_vs_path(kb_name)
         if os.path.isdir(vs_path):
             shutil.rmtree(vs_path)
         os.mkdir(vs_path)
         print(f"start to recreate vectore in {vs_path}")
-        docs = (await list_docs(kb)).data
+        docs = (await list_docs(kb_name)).data
         for i, filename in enumerate(docs):
             kb_file = KnowledgeFile(filename=filename,
-                                    knowledge_base_name=kb)
-            print(f"processing {get_file_path(kb, filename)} to vector store.")
-            kb_file.file2text()
-            kb_file.docs2vs()
+                                    knowledge_base_name=kb_name)
+            print(f"processing {kb_file.filepath} to vector store.")
+            kb = KnowledgeBase.load(knowledge_base_name=kb_name)
+            kb.add_file(kb_file.file2text())
             yield json.dumps({
                 "total": len(docs),
                 "finished": i + 1,

View File

@@ -1,9 +1,16 @@
-from server.knowledge_base.utils import (get_vs_path, get_kb_path, get_doc_path)
 import os
 import sqlite3
-from configs.model_config import KB_ROOT_PATH
 import datetime
 import shutil
+from langchain.vectorstores import FAISS
+from server.knowledge_base.utils import (get_vs_path, get_kb_path, get_doc_path,
+                                         refresh_vs_cache, load_embeddings)
+from configs.model_config import (KB_ROOT_PATH, embedding_model_dict, EMBEDDING_MODEL,
+                                  EMBEDDING_DEVICE)
+from server.utils import torch_gc
+from typing import List
+from langchain.docstore.document import Document

 SUPPORTED_VS_TYPES = ["faiss", "milvus"]
 DB_ROOT = os.path.join(KB_ROOT_PATH, "info.db")
@@ -22,7 +29,8 @@ def list_kbs_from_db():
     conn.close()
     return kbs


-def add_kb_to_db(kb_name, vs_type):
+def add_kb_to_db(kb_name, vs_type, embed_model):
     conn = sqlite3.connect(DB_ROOT)
     c = conn.cursor()
     # Create table
@@ -30,13 +38,15 @@ def add_kb_to_db(kb_name, vs_type):
                  (ID INTEGER PRIMARY KEY AUTOINCREMENT,
                  KB_NAME TEXT,
                  VS_TYPE TEXT,
+                 EMBED_MODEL TEXT,
                  FILE_COUNT INTEGER,
                  CREATE_TIME DATETIME) ''')
     # Insert a row of data
     c.execute(f"""INSERT INTO KNOWLEDGE_BASE
-                  (KB_NAME, VS_TYPE, FILE_COUNT, CREATE_TIME)
+                  (KB_NAME, VS_TYPE, EMBED_MODEL, FILE_COUNT, CREATE_TIME)
                   VALUES
-                  ('{kb_name}','{vs_type}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
+                  ('{kb_name}','{vs_type}','{embed_model}',
+                  0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
     conn.commit()
     conn.close()
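The commit keeps f-string interpolation for the INSERT values. As a side note only (not part of this commit), the same insert can be written with sqlite3 parameter binding, which avoids quoting problems if a knowledge base name contains quotes:

    # Sketch of an equivalent, parameterized insert; behaviour matches the f-string version above.
    c.execute("""INSERT INTO KNOWLEDGE_BASE
                 (KB_NAME, VS_TYPE, EMBED_MODEL, FILE_COUNT, CREATE_TIME)
                 VALUES (?, ?, ?, 0, ?)""",
              (kb_name, vs_type, embed_model,
               datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))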
@@ -56,17 +66,17 @@ def kb_exists(kb_name):
 def load_kb_from_db(kb_name):
     conn = sqlite3.connect(DB_ROOT)
     c = conn.cursor()
-    c.execute(f'''SELECT KB_NAME, VS_TYPE
+    c.execute(f'''SELECT KB_NAME, VS_TYPE, EMBED_MODEL
                   FROM KNOWLEDGE_BASE
                   WHERE KB_NAME="{kb_name}" ''')
     resp = c.fetchone()
     if resp:
-        kb_name, vs_type = resp
+        kb_name, vs_type, embed_model = resp
     else:
-        kb_name, vs_type = None, None
+        kb_name, vs_type, embed_model = None, None, None
     conn.commit()
     conn.close()
-    return kb_name, vs_type
+    return kb_name, vs_type, embed_model


 def delete_kb_from_db(kb_name):
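load_kb_from_db now returns a 3-tuple, with all three values set to None when the name is not registered in info.db. A small usage sketch ("samples" is a placeholder knowledge base name):

    kb_name, vs_type, embed_model = load_kb_from_db("samples")   # placeholder name
    if kb_name is None:
        print("knowledge base not registered in info.db")
    else:
        print(f"{kb_name}: vs_type={vs_type}, embed_model={embed_model}")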
@@ -84,11 +94,15 @@ class KnowledgeBase:
     def __init__(self,
                  knowledge_base_name: str,
                  vector_store_type: str,
+                 embed_model: str,
                  ):
         self.kb_name = knowledge_base_name
         if vector_store_type not in SUPPORTED_VS_TYPES:
             raise ValueError(f"暂未支持向量库类型 {vector_store_type}")
         self.vs_type = vector_store_type
+        if embed_model not in embedding_model_dict.keys():
+            raise ValueError(f"暂未支持embedding模型 {embed_model}")
+        self.embed_model = embed_model
         self.kb_path = get_kb_path(self.kb_name)
         self.doc_path = get_doc_path(self.kb_name)
         if self.vs_type in ["faiss"]:
@@ -102,12 +116,31 @@ class KnowledgeBase:
         if self.vs_type in ["faiss"]:
             if not os.path.exists(self.vs_path):
                 os.makedirs(self.vs_path)
-            add_kb_to_db(self.kb_name, self.vs_type)
+            add_kb_to_db(self.kb_name, self.vs_type, self.embed_model)
         elif self.vs_type in ["milvus"]:
             # TODO: 创建milvus库
             pass
         return True

+    def add_file(self, docs: List[Document]):
+        vs_path = get_vs_path(self.kb_name)
+        embeddings = load_embeddings(embedding_model_dict[self.embed_model], EMBEDDING_DEVICE)
+        if self.vs_type in ["faiss"]:
+            if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
+                vector_store = FAISS.load_local(vs_path, embeddings)
+                vector_store.add_documents(docs)
+                torch_gc()
+            else:
+                if not os.path.exists(vs_path):
+                    os.makedirs(vs_path)
+                vector_store = FAISS.from_documents(docs, embeddings)  # docs 为Document列表
+                torch_gc()
+            vector_store.save_local(vs_path)
+            refresh_vs_cache(self.kb_name)
+        elif self.vs_type in ["milvus"]:
+            # TODO: 向milvus库中增加文件
+            pass
+
     @classmethod
     def exists(cls,
                knowledge_base_name: str):
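add_file takes a list of langchain Document objects and either extends an existing FAISS index or builds a new one, embedding with the knowledge base's own embed_model. A minimal usage sketch, assuming a knowledge base named "samples" has already been created:

    from langchain.docstore.document import Document

    kb = KnowledgeBase.load(knowledge_base_name="samples")          # placeholder name
    docs = [Document(page_content="some text", metadata={"source": "manual"})]
    kb.add_file(docs)   # embeds the docs and saves/refreshes the FAISS store on disk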
@@ -116,8 +149,8 @@ class KnowledgeBase:
     @classmethod
     def load(cls,
              knowledge_base_name: str):
-        kb_name, vs_type = load_kb_from_db(knowledge_base_name)
-        return cls(kb_name, vs_type)
+        kb_name, vs_type, embed_model = load_kb_from_db(knowledge_base_name)
+        return cls(kb_name, vs_type, embed_model)

     @classmethod
     def delete(cls,

View File

@@ -4,6 +4,7 @@ from server.knowledge_base.utils import (get_file_path, get_vs_path,
 from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
 from langchain.vectorstores import FAISS
 from server.utils import torch_gc
+from server.knowledge_base import KnowledgeBase


 class KnowledgeFile:
@@ -12,7 +13,7 @@ class KnowledgeFile:
                  filename: str,
                  knowledge_base_name: str
                  ):
-        self.knowledge_base_name = knowledge_base_name
+        self.kb = KnowledgeBase.load(knowledge_base_name)
         self.knowledge_base_type = "faiss"
         self.filename = filename
         self.ext = os.path.splitext(filename)[-1]
@@ -28,12 +29,12 @@ class KnowledgeFile:
         from langchain.text_splitter import CharacterTextSplitter
         text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200)
-        self.docs = loader.load_and_split(text_splitter)
-        return True
+        return loader.load_and_split(text_splitter)

     def docs2vs(self):
-        vs_path = get_vs_path(self.knowledge_base_name)
+        vs_path = get_vs_path(self.kb.kb_name)
         embeddings = load_embeddings(embedding_model_dict[EMBEDDING_MODEL], EMBEDDING_DEVICE)
         if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
             vector_store = FAISS.load_local(vs_path, embeddings)
             vector_store.add_documents(self.docs)
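file2text() now returns the split documents to its caller instead of storing them on the instance; the splitting behaviour itself is unchanged. A standalone sketch of what load_and_split does with these settings (TextLoader and the sample path are placeholders, not taken from this file):

    from langchain.document_loaders import TextLoader
    from langchain.text_splitter import CharacterTextSplitter

    loader = TextLoader("/path/to/sample.txt")                          # placeholder path
    text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200)
    docs = loader.load_and_split(text_splitter)   # -> list of Document chunks (~500 chars, 200 overlap)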
@@ -44,5 +45,5 @@ class KnowledgeFile:
             vector_store = FAISS.from_documents(self.docs, embeddings)  # docs 为Document列表
             torch_gc()
         vector_store.save_local(vs_path)
-        refresh_vs_cache(self.knowledge_base_name)
+        refresh_vs_cache(self.kb.kb_name)
         return True