1. add add_doc and list_docs to KnowledgeBase class
2. add DB_ROOT_PATH to model_config.py.example
This commit is contained in:
parent
313e590961
commit
3f045cedb9
|
|
@ -239,6 +239,9 @@ if not os.path.exists(LOG_PATH):
|
|||
# 知识库默认存储路径
|
||||
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
|
||||
|
||||
# 数据库默认存储路径
|
||||
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
|
||||
|
||||
# 缓存向量库数量
|
||||
CACHED_VS_NUM = 1
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@ from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_
|
|||
get_file_path, refresh_vs_cache, get_vs_path)
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
from server.knowledge_base import KnowledgeFile, KnowledgeBase
|
||||
from server.knowledge_base.knowledge_file import KnowledgeFile
|
||||
from server.knowledge_base.knowledge_base import KnowledgeBase
|
||||
|
||||
|
||||
async def list_docs(knowledge_base_name: str):
|
||||
|
|
@ -16,17 +17,10 @@ async def list_docs(knowledge_base_name: str):
|
|||
|
||||
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
|
||||
kb_path = get_kb_path(knowledge_base_name)
|
||||
local_doc_folder = get_doc_path(knowledge_base_name)
|
||||
if not os.path.exists(kb_path):
|
||||
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
|
||||
if not os.path.exists(local_doc_folder):
|
||||
all_doc_names = []
|
||||
else:
|
||||
all_doc_names = [
|
||||
doc
|
||||
for doc in os.listdir(local_doc_folder)
|
||||
if os.path.isfile(os.path.join(local_doc_folder, doc))
|
||||
]
|
||||
all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs()
|
||||
return ListResponse(data=all_doc_names)
|
||||
|
||||
|
||||
|
|
@ -60,7 +54,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
|
|||
kb_file = KnowledgeFile(filename=file.filename,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
|
||||
kb.add_file(kb_file.file2text())
|
||||
kb.add_file(kb_file)
|
||||
|
||||
return BaseResponse(code=200, msg=f"成功上传文件 {file.filename}")
|
||||
|
||||
|
|
@ -121,7 +115,7 @@ async def recreate_vector_store(knowledge_base_name: str):
|
|||
knowledge_base_name=kb_name)
|
||||
print(f"processing {kb_file.filepath} to vector store.")
|
||||
kb = KnowledgeBase.load(knowledge_base_name=kb_name)
|
||||
kb.add_file(kb_file.file2text())
|
||||
kb.add_file(kb_file)
|
||||
yield json.dumps({
|
||||
"total": len(docs),
|
||||
"finished": i + 1,
|
||||
|
|
|
|||
|
|
@ -5,23 +5,27 @@ import shutil
|
|||
from langchain.vectorstores import FAISS
|
||||
from server.knowledge_base.utils import (get_vs_path, get_kb_path, get_doc_path,
|
||||
refresh_vs_cache, load_embeddings)
|
||||
from configs.model_config import (KB_ROOT_PATH, embedding_model_dict, EMBEDDING_MODEL,
|
||||
EMBEDDING_DEVICE)
|
||||
from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL,
|
||||
EMBEDDING_DEVICE, DB_ROOT_PATH)
|
||||
from server.utils import torch_gc
|
||||
from typing import List
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
from server.knowledge_base.knowledge_file import KnowledgeFile
|
||||
|
||||
SUPPORTED_VS_TYPES = ["faiss", "milvus"]
|
||||
DB_ROOT = os.path.join(KB_ROOT_PATH, "info.db")
|
||||
|
||||
|
||||
def list_kbs_from_db():
|
||||
conn = sqlite3.connect(DB_ROOT)
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
c.execute(f'''SELECT KB_NAME
|
||||
FROM KNOWLEDGE_BASE
|
||||
WHERE FILE_COUNT>0 ''')
|
||||
c.execute('''CREATE TABLE if not exists knowledge_base
|
||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
kb_name TEXT,
|
||||
vs_type TEXT,
|
||||
embed_model TEXT,
|
||||
file_count INTEGER,
|
||||
create_time DATETIME) ''')
|
||||
c.execute(f'''SELECT kb_name
|
||||
FROM knowledge_base
|
||||
WHERE file_count>0 ''')
|
||||
kbs = [i[0] for i in c.fetchall() if i]
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
|
@ -29,7 +33,7 @@ def list_kbs_from_db():
|
|||
|
||||
|
||||
def add_kb_to_db(kb_name, vs_type, embed_model):
|
||||
conn = sqlite3.connect(DB_ROOT)
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
# Create table
|
||||
c.execute('''CREATE TABLE if not exists knowledge_base
|
||||
|
|
@ -50,8 +54,15 @@ def add_kb_to_db(kb_name, vs_type, embed_model):
|
|||
|
||||
|
||||
def kb_exists(kb_name):
|
||||
conn = sqlite3.connect(DB_ROOT)
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
c.execute('''CREATE TABLE if not exists knowledge_base
|
||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
kb_name TEXT,
|
||||
vs_type TEXT,
|
||||
embed_model TEXT,
|
||||
file_count INTEGER,
|
||||
create_time DATETIME) ''')
|
||||
c.execute(f'''SELECT COUNT(*)
|
||||
FROM knowledge_base
|
||||
WHERE kb_name="{kb_name}" ''')
|
||||
|
|
@ -62,8 +73,15 @@ def kb_exists(kb_name):
|
|||
|
||||
|
||||
def load_kb_from_db(kb_name):
|
||||
conn = sqlite3.connect(DB_ROOT)
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
c.execute('''CREATE TABLE if not exists knowledge_base
|
||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
kb_name TEXT,
|
||||
vs_type TEXT,
|
||||
embed_model TEXT,
|
||||
file_count INTEGER,
|
||||
create_time DATETIME) ''')
|
||||
c.execute(f'''SELECT kb_name, vs_type, embed_model
|
||||
FROM knowledge_base
|
||||
WHERE kb_name="{kb_name}" ''')
|
||||
|
|
@ -78,16 +96,82 @@ def load_kb_from_db(kb_name):
|
|||
|
||||
|
||||
def delete_kb_from_db(kb_name):
|
||||
conn = sqlite3.connect(DB_ROOT)
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
# delete kb from table knowledge_base
|
||||
c.execute('''CREATE TABLE if not exists knowledge_base
|
||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
kb_name TEXT,
|
||||
vs_type TEXT,
|
||||
embed_model TEXT,
|
||||
file_count INTEGER,
|
||||
create_time DATETIME) ''')
|
||||
c.execute(f'''DELETE
|
||||
FROM knowledge_base
|
||||
WHERE kb_name="{kb_name}" ''')
|
||||
# delete files in kb from table knowledge_files
|
||||
c.execute('''CREATE TABLE if not exists knowledge_files
|
||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
file_name TEXT,
|
||||
file_ext TEXT,
|
||||
kb_name TEXT,
|
||||
document_loader_name TEXT,
|
||||
text_splitter_name TEXT,
|
||||
file_version INTEGER,
|
||||
create_time DATETIME) ''')
|
||||
# Insert a row of data
|
||||
c.execute(f"""DELETE
|
||||
FROM knowledge_files
|
||||
WHERE kb_name="{kb_name}"
|
||||
""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return True
|
||||
|
||||
|
||||
def list_docs_from_db(kb_name):
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
c.execute('''CREATE TABLE if not exists knowledge_files
|
||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
file_name TEXT,
|
||||
file_ext TEXT,
|
||||
kb_name TEXT,
|
||||
document_loader_name TEXT,
|
||||
text_splitter_name TEXT,
|
||||
file_version INTEGER,
|
||||
create_time DATETIME) ''')
|
||||
c.execute(f'''SELECT file_name
|
||||
FROM knowledge_files
|
||||
WHERE kb_name="{kb_name}" ''')
|
||||
kbs = [i[0] for i in c.fetchall() if i]
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return kbs
|
||||
|
||||
def add_file_to_db(kb_file: KnowledgeFile):
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
# Create table
|
||||
c.execute('''CREATE TABLE if not exists knowledge_files
|
||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
file_name TEXT,
|
||||
file_ext TEXT,
|
||||
kb_name TEXT,
|
||||
document_loader_name TEXT,
|
||||
text_splitter_name TEXT,
|
||||
file_version INTEGER,
|
||||
create_time DATETIME) ''')
|
||||
# Insert a row of data
|
||||
c.execute(f"""INSERT INTO knowledge_files
|
||||
(file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time)
|
||||
VALUES
|
||||
('{kb_file.filename}','{kb_file.ext}','{kb_file.kb_name}', '{kb_file.document_loader_name}',
|
||||
'{kb_file.text_splitter_name}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
class KnowledgeBase:
|
||||
def __init__(self,
|
||||
knowledge_base_name: str,
|
||||
|
|
@ -120,9 +204,10 @@ class KnowledgeBase:
|
|||
pass
|
||||
return True
|
||||
|
||||
def add_file(self, docs: List[Document]):
|
||||
def add_doc(self, kb_file: KnowledgeFile):
|
||||
docs = kb_file.file2text()
|
||||
vs_path = get_vs_path(self.kb_name)
|
||||
embeddings = load_embeddings(embedding_model_dict[self.embed_model], EMBEDDING_DEVICE)
|
||||
embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
|
||||
if self.vs_type in ["faiss"]:
|
||||
if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
|
||||
vector_store = FAISS.load_local(vs_path, embeddings)
|
||||
|
|
@ -134,11 +219,15 @@ class KnowledgeBase:
|
|||
vector_store = FAISS.from_documents(docs, embeddings) # docs 为Document列表
|
||||
torch_gc()
|
||||
vector_store.save_local(vs_path)
|
||||
add_file_to_db(kb_file)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
elif self.vs_type in ["milvus"]:
|
||||
# TODO: 向milvus库中增加文件
|
||||
pass
|
||||
|
||||
def list_docs(self):
|
||||
return list_docs_from_db(self.kb_name)
|
||||
|
||||
@classmethod
|
||||
def exists(cls,
|
||||
knowledge_base_name: str):
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import os.path
|
||||
from server.knowledge_base.utils import (get_file_path)
|
||||
from server.knowledge_base import KnowledgeBase
|
||||
import sys
|
||||
|
||||
|
||||
LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst',
|
||||
'.rtf', '.txt', '.xml',
|
||||
'.doc', '.docx', '.epub', '.odt', '.pdf',
|
||||
|
|
@ -23,19 +23,23 @@ class KnowledgeFile:
|
|||
filename: str,
|
||||
knowledge_base_name: str
|
||||
):
|
||||
self.kb = KnowledgeBase.load(knowledge_base_name)
|
||||
self.kb_name = knowledge_base_name
|
||||
self.filename = filename
|
||||
self.ext = os.path.splitext(filename)[-1]
|
||||
if self.ext not in SUPPORTED_EXTS:
|
||||
raise ValueError(f"暂未支持的文件格式 {self.ext}")
|
||||
self.filepath = get_file_path(knowledge_base_name, filename)
|
||||
self.docs = None
|
||||
self.loader_class_name = get_LoaderClass(self.ext)
|
||||
self.document_loader_name = get_LoaderClass(self.ext)
|
||||
|
||||
# TODO: 增加依据文件格式匹配text_splitter
|
||||
self.text_splitter_name = "CharacterTextSplitter"
|
||||
|
||||
def file2text(self):
|
||||
LoaderClass = getattr(sys.modules['langchain.document_loaders'], self.loader_class_name)
|
||||
loader = LoaderClass(self.filepath)
|
||||
DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name)
|
||||
loader = DocumentLoader(self.filepath)
|
||||
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200)
|
||||
# TODO: 增加依据文件格式匹配text_splitter
|
||||
TextSplitter = getattr(sys.modules['langchain.text_splitter'], self.text_splitter_name)
|
||||
text_splitter = TextSplitter(chunk_size=500, chunk_overlap=200)
|
||||
return loader.load_and_split(text_splitter)
|
||||
|
|
|
|||
Loading…
Reference in New Issue