update class method of KnowledgeBase and KnowledgeFile
This commit is contained in:
parent
6c7adfbaeb
commit
206261cd0c
|
|
@ -10,7 +10,9 @@ async def list_kbs():
|
|||
|
||||
|
||||
async def create_kb(knowledge_base_name: str,
|
||||
vector_store_type: str = "faiss"):
|
||||
vector_store_type: str = "faiss",
|
||||
embed_model: str = "m3e-base",
|
||||
):
|
||||
# Create selected knowledge base
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
|
@ -19,7 +21,8 @@ async def create_kb(knowledge_base_name: str,
|
|||
if KnowledgeBase.exists(knowledge_base_name):
|
||||
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
||||
kb = KnowledgeBase(knowledge_base_name=knowledge_base_name,
|
||||
vector_store_type=vector_store_type)
|
||||
vector_store_type=vector_store_type,
|
||||
embed_model=embed_model)
|
||||
kb.create()
|
||||
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_
|
|||
get_file_path, refresh_vs_cache, get_vs_path)
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
from server.knowledge_base.knowledge_file import KnowledgeFile
|
||||
from server.knowledge_base import KnowledgeFile, KnowledgeBase
|
||||
|
||||
|
||||
async def list_docs(knowledge_base_name: str):
|
||||
|
|
@ -59,8 +59,8 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
|
|||
|
||||
kb_file = KnowledgeFile(filename=file.filename,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
kb_file.file2text()
|
||||
kb_file.docs2vs()
|
||||
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
|
||||
kb.add_file(kb_file.file2text())
|
||||
|
||||
return BaseResponse(code=200, msg=f"成功上传文件 {file.filename}")
|
||||
|
||||
|
|
@ -108,20 +108,20 @@ async def recreate_vector_store(knowledge_base_name: str):
|
|||
recreate vector store from the content.
|
||||
this is useful when users copy files into the content folder directly instead of uploading them through the network.
|
||||
'''
|
||||
async def output(kb):
|
||||
vs_path = get_vs_path(kb)
|
||||
async def output(kb_name):
|
||||
vs_path = get_vs_path(kb_name)
|
||||
if os.path.isdir(vs_path):
|
||||
shutil.rmtree(vs_path)
|
||||
os.mkdir(vs_path)
|
||||
print(f"start to recreate vectore in {vs_path}")
|
||||
|
||||
docs = (await list_docs(kb)).data
|
||||
docs = (await list_docs(kb_name)).data
|
||||
for i, filename in enumerate(docs):
|
||||
kb_file = KnowledgeFile(filename=filename,
|
||||
knowledge_base_name=kb)
|
||||
print(f"processing {get_file_path(kb, filename)} to vector store.")
|
||||
kb_file.file2text()
|
||||
kb_file.docs2vs()
|
||||
knowledge_base_name=kb_name)
|
||||
print(f"processing {kb_file.filepath} to vector store.")
|
||||
kb = KnowledgeBase.load(knowledge_base_name=kb_name)
|
||||
kb.add_file(kb_file.file2text())
|
||||
yield json.dumps({
|
||||
"total": len(docs),
|
||||
"finished": i + 1,
|
||||
|
|
|
|||
|
|
@ -1,9 +1,16 @@
|
|||
from server.knowledge_base.utils import (get_vs_path, get_kb_path, get_doc_path)
|
||||
import os
|
||||
import sqlite3
|
||||
from configs.model_config import KB_ROOT_PATH
|
||||
import datetime
|
||||
import shutil
|
||||
from langchain.vectorstores import FAISS
|
||||
from server.knowledge_base.utils import (get_vs_path, get_kb_path, get_doc_path,
|
||||
refresh_vs_cache, load_embeddings)
|
||||
from configs.model_config import (KB_ROOT_PATH, embedding_model_dict, EMBEDDING_MODEL,
|
||||
EMBEDDING_DEVICE)
|
||||
from server.utils import torch_gc
|
||||
from typing import List
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
|
||||
SUPPORTED_VS_TYPES = ["faiss", "milvus"]
|
||||
DB_ROOT = os.path.join(KB_ROOT_PATH, "info.db")
|
||||
|
|
@ -22,7 +29,8 @@ def list_kbs_from_db():
|
|||
conn.close()
|
||||
return kbs
|
||||
|
||||
def add_kb_to_db(kb_name, vs_type):
|
||||
|
||||
def add_kb_to_db(kb_name, vs_type, embed_model):
|
||||
conn = sqlite3.connect(DB_ROOT)
|
||||
c = conn.cursor()
|
||||
# Create table
|
||||
|
|
@ -30,13 +38,15 @@ def add_kb_to_db(kb_name, vs_type):
|
|||
(ID INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
KB_NAME TEXT,
|
||||
VS_TYPE TEXT,
|
||||
EMBED_MODEL TEXT,
|
||||
FILE_COUNT INTEGER,
|
||||
CREATE_TIME DATETIME) ''')
|
||||
# Insert a row of data
|
||||
c.execute(f"""INSERT INTO KNOWLEDGE_BASE
|
||||
(KB_NAME, VS_TYPE, FILE_COUNT, CREATE_TIME)
|
||||
(KB_NAME, VS_TYPE, EMBED_MODEL, FILE_COUNT, CREATE_TIME)
|
||||
VALUES
|
||||
('{kb_name}','{vs_type}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
|
||||
('{kb_name}','{vs_type}','{embed_model}',
|
||||
0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
|
@ -56,17 +66,17 @@ def kb_exists(kb_name):
|
|||
def load_kb_from_db(kb_name):
|
||||
conn = sqlite3.connect(DB_ROOT)
|
||||
c = conn.cursor()
|
||||
c.execute(f'''SELECT KB_NAME, VS_TYPE
|
||||
c.execute(f'''SELECT KB_NAME, VS_TYPE, EMBED_MODEL
|
||||
FROM KNOWLEDGE_BASE
|
||||
WHERE KB_NAME="{kb_name}" ''')
|
||||
resp = c.fetchone()
|
||||
if resp:
|
||||
kb_name, vs_type = resp
|
||||
kb_name, vs_type, embed_model = resp
|
||||
else:
|
||||
kb_name, vs_type = None, None
|
||||
kb_name, vs_type, embed_model = None, None, None
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return kb_name, vs_type
|
||||
return kb_name, vs_type, embed_model
|
||||
|
||||
|
||||
def delete_kb_from_db(kb_name):
|
||||
|
|
@ -84,11 +94,15 @@ class KnowledgeBase:
|
|||
def __init__(self,
|
||||
knowledge_base_name: str,
|
||||
vector_store_type: str,
|
||||
embed_model: str,
|
||||
):
|
||||
self.kb_name = knowledge_base_name
|
||||
if vector_store_type not in SUPPORTED_VS_TYPES:
|
||||
raise ValueError(f"暂未支持向量库类型 {vector_store_type}")
|
||||
self.vs_type = vector_store_type
|
||||
if embed_model not in embedding_model_dict.keys():
|
||||
raise ValueError(f"暂未支持embedding模型 {embed_model}")
|
||||
self.embed_model = embed_model
|
||||
self.kb_path = get_kb_path(self.kb_name)
|
||||
self.doc_path = get_doc_path(self.kb_name)
|
||||
if self.vs_type in ["faiss"]:
|
||||
|
|
@ -102,12 +116,31 @@ class KnowledgeBase:
|
|||
if self.vs_type in ["faiss"]:
|
||||
if not os.path.exists(self.vs_path):
|
||||
os.makedirs(self.vs_path)
|
||||
add_kb_to_db(self.kb_name, self.vs_type)
|
||||
add_kb_to_db(self.kb_name, self.vs_type, self.embed_model)
|
||||
elif self.vs_type in ["milvus"]:
|
||||
# TODO: 创建milvus库
|
||||
pass
|
||||
return True
|
||||
|
||||
def add_file(self, docs: List[Document]):
    """Embed ``docs`` and add them to this knowledge base's vector store.

    For the "faiss" store type, the index stored under the path returned by
    ``get_vs_path(self.kb_name)`` is loaded and extended when an
    ``index.faiss`` file is already present there; otherwise a new index is
    built from ``docs``. The index is then saved back to that path and
    ``refresh_vs_cache`` is called for this knowledge base.

    The "milvus" store type is not implemented yet and does nothing.

    Args:
        docs: Documents to add. An empty list is a no-op.
    """
    if not docs:
        # Nothing to embed: skip loading the model and touching the index.
        return

    if self.vs_type in ["faiss"]:
        vs_path = get_vs_path(self.kb_name)
        # Load the embedding model only on the branch that actually uses it.
        embeddings = load_embeddings(embedding_model_dict[self.embed_model], EMBEDDING_DEVICE)
        if os.path.isfile(os.path.join(vs_path, "index.faiss")):
            # Extend the index that is already on disk.
            vector_store = FAISS.load_local(vs_path, embeddings)
            vector_store.add_documents(docs)
        else:
            # First file for this knowledge base: create the folder and index.
            os.makedirs(vs_path, exist_ok=True)
            vector_store = FAISS.from_documents(docs, embeddings)  # docs is a list of Document
        torch_gc()
        vector_store.save_local(vs_path)
        refresh_vs_cache(self.kb_name)
    elif self.vs_type in ["milvus"]:
        # TODO: add files to the milvus store
        pass
|
||||
|
||||
@classmethod
|
||||
def exists(cls,
|
||||
knowledge_base_name: str):
|
||||
|
|
@ -116,8 +149,8 @@ class KnowledgeBase:
|
|||
@classmethod
|
||||
def load(cls,
|
||||
knowledge_base_name: str):
|
||||
kb_name, vs_type = load_kb_from_db(knowledge_base_name)
|
||||
return cls(kb_name, vs_type)
|
||||
kb_name, vs_type, embed_model = load_kb_from_db(knowledge_base_name)
|
||||
return cls(kb_name, vs_type, embed_model)
|
||||
|
||||
@classmethod
|
||||
def delete(cls,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from server.knowledge_base.utils import (get_file_path, get_vs_path,
|
|||
from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
|
||||
from langchain.vectorstores import FAISS
|
||||
from server.utils import torch_gc
|
||||
from server.knowledge_base import KnowledgeBase
|
||||
|
||||
|
||||
class KnowledgeFile:
|
||||
|
|
@ -12,7 +13,7 @@ class KnowledgeFile:
|
|||
filename: str,
|
||||
knowledge_base_name: str
|
||||
):
|
||||
self.knowledge_base_name = knowledge_base_name
|
||||
self.kb = KnowledgeBase.load(knowledge_base_name)
|
||||
self.knowledge_base_type = "faiss"
|
||||
self.filename = filename
|
||||
self.ext = os.path.splitext(filename)[-1]
|
||||
|
|
@ -28,12 +29,12 @@ class KnowledgeFile:
|
|||
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200)
|
||||
self.docs = loader.load_and_split(text_splitter)
|
||||
return True
|
||||
return loader.load_and_split(text_splitter)
|
||||
|
||||
def docs2vs(self):
|
||||
vs_path = get_vs_path(self.knowledge_base_name)
|
||||
vs_path = get_vs_path(self.kb.kb_name)
|
||||
embeddings = load_embeddings(embedding_model_dict[EMBEDDING_MODEL], EMBEDDING_DEVICE)
|
||||
|
||||
if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
|
||||
vector_store = FAISS.load_local(vs_path, embeddings)
|
||||
vector_store.add_documents(self.docs)
|
||||
|
|
@ -44,5 +45,5 @@ class KnowledgeFile:
|
|||
vector_store = FAISS.from_documents(self.docs, embeddings) # docs 为Document列表
|
||||
torch_gc()
|
||||
vector_store.save_local(vs_path)
|
||||
refresh_vs_cache(self.knowledge_base_name)
|
||||
refresh_vs_cache(self.kb.kb_name)
|
||||
return True
|
||||
Loading…
Reference in New Issue