Add database fields; rebuild knowledge bases with multithreading (#1280)
* close #1172: add some log messages to webui_page/utils to make errors easier to locate
* Fix: the page did not show progress in real time while rebuilding the knowledge base
* Skip running model_worker when using an online model API such as ChatGPT
* Knowledge base management changes:
  1. Add three fields to KnowledgeFileModel: file_mtime (file modification time), file_size (file size), and custom_docs (whether custom docs are used), in preparation for comparing uploaded files later.
  2. Give every String column an explicit length to avoid table-creation errors on MySQL (PR #1177).
  3. Unify the [faiss/milvus/pgvector]_kb_service.add_doc interface so that it supports custom docs.
  4. Add convenience methods to faiss_kb_service.
  5. Add convenience methods to KnowledgeFile for retrieving file info and caching file2text results.
* Fix: /chat/fastchat could not stream its output
* New features:
  1. Add a `docs_count` field to KnowledgeFileModel, recording how many Documents the file contributed to the vector store; it is shown in the WebUI.
  2. Rebuilding the knowledge base with `python init_database.py --recreate-vs` now uses multiple threads.
* Other: unify knowledge-base terminology in the code: "file" refers to a file name or path, "doc" refers to a Document loaded by langchain. Functions tied to the API or with overlapping meanings are left unchanged for now.

---------

Co-authored-by: liunux4odoo <liunux@qq.com>, hongkong9771
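For orientation, a minimal sketch of the rebuilt flow these changes enable; the knowledge base name "samples" is a placeholder, and this mirrors what `--recreate-vs` does in the diff below:

```python
# Minimal sketch (placeholder kb name "samples"): reset_tables() empties the
# database, then folder2db(..., "recreate_vs") re-embeds every file, running
# file2text() in a thread pool (see migrate.py below).
from server.knowledge_base.migrate import reset_tables, folder2db

reset_tables()                        # what --recreate-vs does first
folder2db("samples", "recreate_vs")   # multithreaded re-embedding
```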
parent 89e3e9a691
commit 3acbf4d5d1
@@ -5,3 +5,5 @@ logs
 __pycache__/
 knowledge_base/
 configs/*.py
+.vscode/
+.pytest_cache/
@@ -1,8 +1,9 @@
-from server.knowledge_base.migrate import create_tables, folder2db, recreate_all_vs, list_kbs_from_folder
+from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, recreate_all_vs, list_kbs_from_folder
 from configs.model_config import NLTK_DATA_PATH
 import nltk
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 from startup import dump_server_info
+from datetime import datetime


 if __name__ == "__main__":
@@ -25,13 +26,19 @@ if __name__ == "__main__":

     dump_server_info()

-    create_tables()
-    print("database tables created")
+    start_time = datetime.now()

     if args.recreate_vs:
+        reset_tables()
+        print("database tables reset")
         print("recreating all vector stores")
         recreate_all_vs()
     else:
+        create_tables()
+        print("database tables created")
         print("filling kb infos to database")
         for kb in list_kbs_from_folder():
             folder2db(kb, "fill_info_only")
+
+    end_time = datetime.now()
+    print(f"总计用时: {end_time-start_time}")
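reset_tables() itself is not part of this diff; a plausible sketch, assuming it simply drops and recreates the SQLAlchemy tables via the Base/engine that migrate.py imports:

```python
# Hedged sketch only: the real reset_tables() lives in migrate.py and is not
# shown here. Assumed behavior: drop all ORM tables, then recreate them empty.
from server.db.base import Base, engine

def reset_tables():
    Base.metadata.drop_all(bind=engine)      # assumption: drop every ORM table
    Base.metadata.create_all(bind=engine)    # recreate them empty
```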
@@ -14,7 +14,7 @@ from starlette.responses import RedirectResponse
 from server.chat import (chat, knowledge_base_chat, openai_chat,
                          search_engine_chat)
 from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
-from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
+from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
                                               update_doc, download_doc, recreate_vector_store,
                                               search_docs, DocumentWithScore)
 from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
@@ -84,11 +84,11 @@ def create_app():
             summary="删除知识库"
             )(delete_kb)

-    app.get("/knowledge_base/list_docs",
+    app.get("/knowledge_base/list_files",
             tags=["Knowledge Base Management"],
             response_model=ListResponse,
             summary="获取知识库内的文件列表"
-            )(list_docs)
+            )(list_files)

     app.post("/knowledge_base/search_docs",
              tags=["Knowledge Base Management"],
@@ -29,23 +29,22 @@ async def openai_chat(msg: OpenAiChatMsgIn):
     print(f"{openai.api_base=}")
     print(msg)

-    async def get_response(msg):
+    def get_response(msg):
         data = msg.dict()
         data["streaming"] = True
         data.pop("stream")

         try:
             response = openai.ChatCompletion.create(**data)
             if msg.stream:
-                for chunk in response.choices[0].message.content:
-                    print(chunk)
-                    yield chunk
+                for data in response:
+                    if choices := data.choices:
+                        if chunk := choices[0].get("delta", {}).get("content"):
+                            print(chunk, end="", flush=True)
+                            yield chunk
             else:
-                answer = ""
-                for chunk in response.choices[0].message.content:
-                    answer += chunk
-                print(answer)
-                yield(answer)
+                if response.choices:
+                    answer = response.choices[0].message.content
+                    print(answer)
+                    yield(answer)
         except Exception as e:
             print(type(e))
             logger.error(e)
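The rewritten generator iterates over the streamed response and yields delta chunks as they arrive, which is what fixes the non-streaming behavior. A hedged client-side sketch; the address, port, and request-body shape are assumptions for illustration, not taken from this diff:

```python
# Hedged client sketch for the fixed streaming endpoint; URL, port and payload
# shape are assumptions.
import requests

resp = requests.post(
    "http://127.0.0.1:7861/chat/fastchat",
    json={"messages": [{"role": "user", "content": "hello"}], "stream": True},
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)  # tokens print as they are generated
```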
@@ -9,9 +9,9 @@ class KnowledgeBaseModel(Base):
     """
     __tablename__ = 'knowledge_base'
     id = Column(Integer, primary_key=True, autoincrement=True, comment='知识库ID')
-    kb_name = Column(String, comment='知识库名称')
-    vs_type = Column(String, comment='嵌入模型类型')
-    embed_model = Column(String, comment='嵌入模型名称')
+    kb_name = Column(String(50), comment='知识库名称')
+    vs_type = Column(String(50), comment='向量库类型')
+    embed_model = Column(String(50), comment='嵌入模型名称')
     file_count = Column(Integer, default=0, comment='文件数量')
     create_time = Column(DateTime, default=func.now(), comment='创建时间')

@@ -1,4 +1,4 @@
-from sqlalchemy import Column, Integer, String, DateTime, func
+from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, func

 from server.db.base import Base

@@ -9,12 +9,16 @@ class KnowledgeFileModel(Base):
     """
     __tablename__ = 'knowledge_file'
     id = Column(Integer, primary_key=True, autoincrement=True, comment='知识文件ID')
-    file_name = Column(String, comment='文件名')
-    file_ext = Column(String, comment='文件扩展名')
-    kb_name = Column(String, comment='所属知识库名称')
-    document_loader_name = Column(String, comment='文档加载器名称')
-    text_splitter_name = Column(String, comment='文本分割器名称')
+    file_name = Column(String(255), comment='文件名')
+    file_ext = Column(String(10), comment='文件扩展名')
+    kb_name = Column(String(50), comment='所属知识库名称')
+    document_loader_name = Column(String(50), comment='文档加载器名称')
+    text_splitter_name = Column(String(50), comment='文本分割器名称')
     file_version = Column(Integer, default=1, comment='文件版本')
+    file_mtime = Column(Float, default=0.0, comment="文件修改时间")
+    file_size = Column(Integer, default=0, comment="文件大小")
+    custom_docs = Column(Boolean, default=False, comment="是否自定义docs")
+    docs_count = Column(Integer, default=0, comment="切分文档数量")
     create_time = Column(DateTime, default=func.now(), comment='创建时间')

     def __repr__(self):
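The new file_mtime/file_size columns prepare for change detection on upload. A hedged sketch of how they could be compared against the stored row; the helper name `file_changed` is hypothetical:

```python
# Hypothetical helper built on the new columns: a file needs re-embedding if
# its on-disk mtime or size differs from what the database recorded.
from server.knowledge_base.utils import KnowledgeFile
from server.db.repository.knowledge_file_repository import get_file_detail

def file_changed(kb_file: KnowledgeFile) -> bool:
    detail = get_file_detail(kb_file.kb_name, kb_file.filename)
    if not detail:
        return True  # never imported into this knowledge base
    return (kb_file.get_mtime() != detail["file_mtime"]
            or kb_file.get_size() != detail["file_size"])
```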
@@ -5,20 +5,37 @@ from server.knowledge_base.utils import KnowledgeFile


 @with_session
-def list_docs_from_db(session, kb_name):
+def count_files_from_db(session, kb_name: str) -> int:
+    return session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).count()
+
+
+@with_session
+def list_files_from_db(session, kb_name):
     files = session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).all()
     docs = [f.file_name for f in files]
     return docs


 @with_session
-def add_doc_to_db(session, kb_file: KnowledgeFile):
+def add_file_to_db(session,
+                   kb_file: KnowledgeFile,
+                   docs_count: int = 0,
+                   custom_docs: bool = False,):
     kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
     if kb:
-        # 如果已经存在该文件,则更新文件版本号
-        existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename,
-                                                                    kb_name=kb_file.kb_name).first()
+        # 如果已经存在该文件,则更新文件信息与版本号
+        existing_file: KnowledgeFileModel = (session.query(KnowledgeFileModel)
+                                             .filter_by(file_name=kb_file.filename,
+                                                        kb_name=kb_file.kb_name)
+                                             .first())
+        mtime = kb_file.get_mtime()
+        size = kb_file.get_size()
+
         if existing_file:
+            existing_file.file_mtime = mtime
+            existing_file.file_size = size
+            existing_file.docs_count = docs_count
+            existing_file.custom_docs = custom_docs
             existing_file.file_version += 1
         # 否则,添加新文件
         else:
@@ -28,6 +45,10 @@ def add_doc_to_db(session, kb_file: KnowledgeFile):
                 kb_name=kb_file.kb_name,
                 document_loader_name=kb_file.document_loader_name,
                 text_splitter_name=kb_file.text_splitter_name or "SpacyTextSplitter",
+                file_mtime=mtime,
+                file_size=size,
+                docs_count=docs_count,
+                custom_docs=custom_docs,
             )
             kb.file_count += 1
             session.add(new_file)
@@ -62,7 +83,7 @@ def delete_files_from_db(session, knowledge_base_name: str):


 @with_session
-def doc_exists(session, kb_file: KnowledgeFile):
+def file_exists_in_db(session, kb_file: KnowledgeFile):
     existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename,
                                                                 kb_name=kb_file.kb_name).first()
     return True if existing_file else False
@@ -82,6 +103,10 @@ def get_file_detail(session, kb_name: str, filename: str) -> dict:
             "document_loader": file.document_loader_name,
             "text_splitter": file.text_splitter_name,
             "create_time": file.create_time,
+            "file_mtime": file.file_mtime,
+            "file_size": file.file_size,
+            "custom_docs": file.custom_docs,
+            "docs_count": file.docs_count,
         }
     else:
         return {}
@@ -3,7 +3,7 @@ import urllib
 from fastapi import File, Form, Body, Query, UploadFile
 from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
 from server.utils import BaseResponse, ListResponse
-from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
+from server.knowledge_base.utils import validate_kb_name, list_files_from_folder, KnowledgeFile
 from fastapi.responses import StreamingResponse, FileResponse
 import json
 from server.knowledge_base.kb_service.base import KBServiceFactory
@@ -29,7 +29,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
     return data


-async def list_docs(
+async def list_files(
     knowledge_base_name: str
 ) -> ListResponse:
     if not validate_kb_name(knowledge_base_name):
@@ -40,7 +40,7 @@ async def list_docs(
     if kb is None:
         return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
     else:
-        all_doc_names = kb.list_docs()
+        all_doc_names = kb.list_files()
         return ListResponse(data=all_doc_names)

@@ -190,7 +190,7 @@ async def recreate_vector_store(
     else:
         kb.create_kb()
         kb.clear_vs()
-        docs = list_docs_from_folder(knowledge_base_name)
+        docs = list_files_from_folder(knowledge_base_name)
         for i, doc in enumerate(docs):
             try:
                 kb_file = KnowledgeFile(doc, knowledge_base_name)
@@ -13,15 +13,15 @@ from server.db.repository.knowledge_base_repository import (
     load_kb_from_db, get_kb_detail,
 )
 from server.db.repository.knowledge_file_repository import (
-    add_doc_to_db, delete_file_from_db, delete_files_from_db, doc_exists,
-    list_docs_from_db, get_file_detail, delete_file_from_db
+    add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
+    count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db
 )

 from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                                   EMBEDDING_DEVICE, EMBEDDING_MODEL)
 from server.knowledge_base.utils import (
     get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
-    list_kbs_from_folder, list_docs_from_folder,
+    list_kbs_from_folder, list_files_from_folder,
 )
 from typing import List, Union, Dict
@@ -74,16 +74,22 @@ class KBService(ABC):
         status = delete_kb_from_db(self.kb_name)
         return status

-    def add_doc(self, kb_file: KnowledgeFile, **kwargs):
+    def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
         """
         向知识库添加文件
+        如果指定了docs,则不再将文本向量化,并将数据库对应条目标为custom_docs=True
         """
-        docs = kb_file.file2text()
+        if docs:
+            custom_docs = True
+        else:
+            docs = kb_file.file2text()
+            custom_docs = False
+
         if docs:
             self.delete_doc(kb_file)
             embeddings = self._load_embeddings()
-            self.do_add_doc(docs, embeddings, **kwargs)
-            status = add_doc_to_db(kb_file)
+            self.do_add_doc(docs, embeddings=embeddings, **kwargs)
+            status = add_file_to_db(kb_file, custom_docs=custom_docs, docs_count=len(docs))
         else:
             status = False
         return status
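A minimal sketch of the unified add_doc interface described above; the knowledge base and file names are placeholders, and it assumes the knowledge base already exists:

```python
# Passing pre-built docs skips file2text() and marks the DB row custom_docs=True.
from langchain.docstore.document import Document
from server.knowledge_base.utils import KnowledgeFile
from server.knowledge_base.kb_service.base import KBServiceFactory

kb = KBServiceFactory.get_service_by_name("samples")       # placeholder kb
kb_file = KnowledgeFile(filename="README.md", knowledge_base_name="samples")

# Default path: docs come from kb_file.file2text(), custom_docs=False.
kb.add_doc(kb_file)

# Custom path: the caller supplies its own Documents, custom_docs=True.
my_docs = [Document(page_content="a manually split chunk of text",
                    metadata={"source": kb_file.filepath})]
kb.add_doc(kb_file, docs=my_docs)
```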
@@ -98,20 +104,24 @@ class KBService(ABC):
             os.remove(kb_file.filepath)
         return status

-    def update_doc(self, kb_file: KnowledgeFile, **kwargs):
+    def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
         """
         使用content中的文件更新向量库
+        如果指定了docs,则使用自定义docs,并将数据库对应条目标为custom_docs=True
         """
         if os.path.exists(kb_file.filepath):
             self.delete_doc(kb_file, **kwargs)
-            return self.add_doc(kb_file, **kwargs)
+            return self.add_doc(kb_file, docs=docs, **kwargs)

     def exist_doc(self, file_name: str):
-        return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
+        return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name,
                                         filename=file_name))

-    def list_docs(self):
-        return list_docs_from_db(self.kb_name)
+    def list_files(self):
+        return list_files_from_db(self.kb_name)
+
+    def count_files(self):
+        return count_files_from_db(self.kb_name)

     def search_docs(self,
                     query: str,
@@ -264,25 +274,26 @@ def get_kb_details() -> List[Dict]:
     return data


-def get_kb_doc_details(kb_name: str) -> List[Dict]:
+def get_kb_file_details(kb_name: str) -> List[Dict]:
     kb = KBServiceFactory.get_service_by_name(kb_name)
-    docs_in_folder = list_docs_from_folder(kb_name)
-    docs_in_db = kb.list_docs()
+    files_in_folder = list_files_from_folder(kb_name)
+    files_in_db = kb.list_files()
     result = {}

-    for doc in docs_in_folder:
+    for doc in files_in_folder:
         result[doc] = {
             "kb_name": kb_name,
             "file_name": doc,
             "file_ext": os.path.splitext(doc)[-1],
             "file_version": 0,
             "document_loader": "",
+            "docs_count": 0,
             "text_splitter": "",
             "create_time": None,
             "in_folder": True,
             "in_db": False,
         }
-    for doc in docs_in_db:
+    for doc in files_in_db:
         doc_detail = get_file_detail(kb_name, doc)
         if doc_detail:
             doc_detail["in_db"] = True
@@ -13,34 +13,16 @@ from functools import lru_cache
 from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
 from langchain.vectorstores import FAISS
 from langchain.embeddings.base import Embeddings
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings,HuggingFaceBgeEmbeddings
-from langchain.embeddings.openai import OpenAIEmbeddings
 from typing import List
 from langchain.docstore.document import Document
 from server.utils import torch_gc


-# make HuggingFaceEmbeddings hashable
-def _embeddings_hash(self):
-    if isinstance(self, HuggingFaceEmbeddings):
-        return hash(self.model_name)
-    elif isinstance(self, HuggingFaceBgeEmbeddings):
-        return hash(self.model_name)
-    elif isinstance(self, OpenAIEmbeddings):
-        return hash(self.model)
-
-HuggingFaceEmbeddings.__hash__ = _embeddings_hash
-OpenAIEmbeddings.__hash__ = _embeddings_hash
-HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
-
-_VECTOR_STORE_TICKS = {}
-
-
 _VECTOR_STORE_TICKS = {}


 @lru_cache(CACHED_VS_NUM)
-def load_vector_store(
+def load_faiss_vector_store(
     knowledge_base_name: str,
     embed_model: str = EMBEDDING_MODEL,
     embed_device: str = EMBEDDING_DEVICE,
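The `tick` argument exists only to bust the lru_cache: bumping the tick for a kb_name changes the cache key, so the next load re-reads the vector store from disk. A toy illustration of that pattern, independent of this codebase:

```python
# Toy version of the tick-based cache invalidation used by
# load_faiss_vector_store(..., tick=...).
from functools import lru_cache

_TICKS = {}

@lru_cache(maxsize=8)
def load(name: str, tick: int = 0):
    print(f"(re)loading {name}")   # pretend this reads a FAISS index from disk
    return object()

def refresh(name: str):
    _TICKS[name] = _TICKS.get(name, 0) + 1   # next load() misses the cache

load("kb", _TICKS.get("kb", 0))   # loads
load("kb", _TICKS.get("kb", 0))   # cache hit
refresh("kb")
load("kb", _TICKS.get("kb", 0))   # reloads
```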
@@ -86,22 +68,30 @@ class FaissKBService(KBService):
     def vs_type(self) -> str:
         return SupportedVSType.FAISS

-    @staticmethod
-    def get_vs_path(knowledge_base_name: str):
-        return os.path.join(FaissKBService.get_kb_path(knowledge_base_name), "vector_store")
+    def get_vs_path(self):
+        return os.path.join(self.get_kb_path(), "vector_store")

-    @staticmethod
-    def get_kb_path(knowledge_base_name: str):
-        return os.path.join(KB_ROOT_PATH, knowledge_base_name)
+    def get_kb_path(self):
+        return os.path.join(KB_ROOT_PATH, self.kb_name)
+
+    def load_vector_store(self):
+        return load_faiss_vector_store(
+            knowledge_base_name=self.kb_name,
+            embed_model=self.embed_model,
+            tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
+        )
+
+    def refresh_vs_cache(self):
+        refresh_vs_cache(self.kb_name)

     def do_init(self):
-        self.kb_path = FaissKBService.get_kb_path(self.kb_name)
-        self.vs_path = FaissKBService.get_vs_path(self.kb_name)
+        self.kb_path = self.get_kb_path()
+        self.vs_path = self.get_vs_path()

     def do_create_kb(self):
         if not os.path.exists(self.vs_path):
             os.makedirs(self.vs_path)
-        load_vector_store(self.kb_name)
+        self.load_vector_store()

     def do_drop_kb(self):
         self.clear_vs()
@@ -113,9 +103,7 @@ class FaissKBService(KBService):
                   score_threshold: float = SCORE_THRESHOLD,
                   embeddings: Embeddings = None,
                   ) -> List[Document]:
-        search_index = load_vector_store(self.kb_name,
-                                         embeddings=embeddings,
-                                         tick=_VECTOR_STORE_TICKS.get(self.kb_name))
+        search_index = self.load_vector_store()
         docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
         return docs

@@ -124,22 +112,18 @@ class FaissKBService(KBService):
                    embeddings: Embeddings,
                    **kwargs,
                    ):
-        vector_store = load_vector_store(self.kb_name,
-                                         embeddings=embeddings,
-                                         tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
+        vector_store = self.load_vector_store()
         vector_store.add_documents(docs)
         torch_gc()
         if not kwargs.get("not_refresh_vs_cache"):
             vector_store.save_local(self.vs_path)
-            refresh_vs_cache(self.kb_name)
+            self.refresh_vs_cache()

     def do_delete_doc(self,
                       kb_file: KnowledgeFile,
                       **kwargs):
-        embeddings = self._load_embeddings()
-        vector_store = load_vector_store(self.kb_name,
-                                         embeddings=embeddings,
-                                         tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
+        vector_store = self.load_vector_store()

         ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
         if len(ids) == 0:
@@ -148,14 +132,14 @@ class FaissKBService(KBService):
         vector_store.delete(ids)
         if not kwargs.get("not_refresh_vs_cache"):
             vector_store.save_local(self.vs_path)
-            refresh_vs_cache(self.kb_name)
+            self.refresh_vs_cache()

         return True

     def do_clear_vs(self):
         shutil.rmtree(self.vs_path)
         os.makedirs(self.vs_path)
-        refresh_vs_cache(self.kb_name)
+        self.refresh_vs_cache()

     def exist_doc(self, file_name: str):
         if super().exist_doc(file_name):
@@ -166,10 +150,11 @@ class FaissKBService(KBService):
         return "in_folder"
     else:
         return False
-if __name__ == '__main__':

-    milvusService = FaissKBService("test")
-    milvusService.add_doc(KnowledgeFile("README.md", "test"))
-    milvusService.delete_doc(KnowledgeFile("README.md", "test"))
-    milvusService.do_drop_kb()
-    print(milvusService.search_docs("如何启动api服务"))
+
+if __name__ == '__main__':
+    faissService = FaissKBService("test")
+    faissService.add_doc(KnowledgeFile("README.md", "test"))
+    faissService.delete_doc(KnowledgeFile("README.md", "test"))
+    faissService.do_drop_kb()
+    print(faissService.search_docs("如何启动api服务"))
@@ -47,24 +47,15 @@ class MilvusKBService(KBService):
         self._load_milvus()

     def do_drop_kb(self):
-        self.milvus.col.drop()
+        if self.milvus.col:
+            self.milvus.col.drop()

     def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
         self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
         return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))

-    def add_doc(self, kb_file: KnowledgeFile, **kwargs):
-        """
-        向知识库添加文件
-        """
-        docs = kb_file.file2text()
-        self.milvus.add_documents(docs)
-        from server.db.repository.knowledge_file_repository import add_doc_to_db
-        status = add_doc_to_db(kb_file)
-        return status
-
     def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
-        pass
+        self.milvus.add_documents(docs)

     def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
         filepath = kb_file.filepath.replace('\\', '\\\\')
@@ -47,23 +47,12 @@ class PGKBService(KBService):
         connect.commit()

     def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
         # todo: support score threshold
         self._load_pg_vector(embeddings=embeddings)
         return score_threshold_process(score_threshold, top_k,
                                        self.pg_vector.similarity_search_with_score(query, top_k))

-    def add_doc(self, kb_file: KnowledgeFile, **kwargs):
-        """
-        向知识库添加文件
-        """
-        docs = kb_file.file2text()
-        self.pg_vector.add_documents(docs)
-        from server.db.repository.knowledge_file_repository import add_doc_to_db
-        status = add_doc_to_db(kb_file)
-        return status
-
     def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
-        pass
+        self.pg_vector.add_documents(docs)

     def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
         with self.pg_vector.connect() as connect:
@@ -1,10 +1,17 @@
 from configs.model_config import EMBEDDING_MODEL, DEFAULT_VS_TYPE
-from server.knowledge_base.utils import get_file_path, list_kbs_from_folder, list_docs_from_folder, KnowledgeFile
-from server.knowledge_base.kb_service.base import KBServiceFactory
-from server.db.repository.knowledge_file_repository import add_doc_to_db
+from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
+                                         list_files_from_folder, run_in_thread_pool,
+                                         files2docs_in_thread,
+                                         KnowledgeFile,)
+from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType
+from server.db.repository.knowledge_file_repository import add_file_to_db
 from server.db.base import Base, engine
 import os
-from typing import Literal, Callable, Any
+from concurrent.futures import ThreadPoolExecutor
+from typing import Literal, Callable, Any, List


+pool = ThreadPoolExecutor(os.cpu_count())
+
+
 def create_tables():
@@ -16,13 +23,22 @@ def reset_tables():
     create_tables()


+def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
+    kb_files = []
+    for file in files:
+        try:
+            kb_file = KnowledgeFile(filename=file, knowledge_base_name=kb_name)
+            kb_files.append(kb_file)
+        except Exception as e:
+            print(f"{e},已跳过")
+    return kb_files
+
+
 def folder2db(
     kb_name: str,
     mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"],
     vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
     embed_model: str = EMBEDDING_MODEL,
     callback_before: Callable = None,
     callback_after: Callable = None,
     ):
     '''
     use existing files in local folder to populate database and/or vector store.
@@ -36,70 +52,59 @@ def folder2db(
         kb.create_kb()

     if mode == "recreate_vs":
+        files_count = kb.count_files()
+        print(f"知识库 {kb_name} 中共有 {files_count} 个文档。\n即将清除向量库。")
         kb.clear_vs()
-        docs = list_docs_from_folder(kb_name)
-        for i, doc in enumerate(docs):
-            try:
-                kb_file = KnowledgeFile(doc, kb_name)
-                if callable(callback_before):
-                    callback_before(kb_file, i, docs)
-                if i == len(docs) - 1:
-                    not_refresh_vs_cache = False
-                else:
-                    not_refresh_vs_cache = True
-                kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
-                if callable(callback_after):
-                    callback_after(kb_file, i, docs)
-            except Exception as e:
-                print(e)
+        files_count = kb.count_files()
+        print(f"清理后,知识库 {kb_name} 中共有 {files_count} 个文档。")
+
+        kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
+        for success, result in files2docs_in_thread(kb_files, pool=pool):
+            if success:
+                _, filename, docs = result
+                print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档")
+                kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
+                kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
+            else:
+                print(result)
+
+        if kb.vs_type() == SupportedVSType.FAISS:
+            kb.refresh_vs_cache()
     elif mode == "fill_info_only":
-        docs = list_docs_from_folder(kb_name)
-        for i, doc in enumerate(docs):
-            try:
-                kb_file = KnowledgeFile(doc, kb_name)
-                if callable(callback_before):
-                    callback_before(kb_file, i, docs)
-                add_doc_to_db(kb_file)
-                if callable(callback_after):
-                    callback_after(kb_file, i, docs)
-            except Exception as e:
-                print(e)
+        files = list_files_from_folder(kb_name)
+        kb_files = file_to_kbfile(kb_name, files)
+
+        for kb_file in kb_files:
+            add_file_to_db(kb_file)
+            print(f"已将 {kb_name}/{kb_file.filename} 添加到数据库")
     elif mode == "update_in_db":
-        docs = kb.list_docs()
-        for i, doc in enumerate(docs):
-            try:
-                kb_file = KnowledgeFile(doc, kb_name)
-                if callable(callback_before):
-                    callback_before(kb_file, i, docs)
-                if i == len(docs) - 1:
-                    not_refresh_vs_cache = False
-                else:
-                    not_refresh_vs_cache = True
-                kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
-                if callable(callback_after):
-                    callback_after(kb_file, i, docs)
-            except Exception as e:
-                print(e)
+        files = kb.list_files()
+        kb_files = file_to_kbfile(kb_name, files)
+
+        for kb_file in kb_files:
+            kb.update_doc(kb_file, not_refresh_vs_cache=True)
+
+        if kb.vs_type() == SupportedVSType.FAISS:
+            kb.refresh_vs_cache()
     elif mode == "increament":
-        db_docs = kb.list_docs()
-        folder_docs = list_docs_from_folder(kb_name)
-        docs = list(set(folder_docs) - set(db_docs))
-        for i, doc in enumerate(docs):
-            try:
-                kb_file = KnowledgeFile(doc, kb_name)
-                if callable(callback_before):
-                    callback_before(kb_file, i, docs)
-                if i == len(docs) - 1:
-                    not_refresh_vs_cache = False
-                else:
-                    not_refresh_vs_cache = True
-                kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
-                if callable(callback_after):
-                    callback_after(kb_file, i, docs)
-            except Exception as e:
-                print(e)
+        db_files = kb.list_files()
+        folder_files = list_files_from_folder(kb_name)
+        files = list(set(folder_files) - set(db_files))
+        kb_files = file_to_kbfile(kb_name, files)
+
+        for success, result in files2docs_in_thread(kb_files, pool=pool):
+            if success:
+                _, filename, docs = result
+                print(f"正在将 {kb_name}/{filename} 添加到向量库")
+                kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
+                kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
+            else:
+                print(result)
+
+        if kb.vs_type() == SupportedVSType.FAISS:
+            kb.refresh_vs_cache()
     else:
-        raise ValueError(f"unsupported migrate mode: {mode}")
+        print(f"unsupported migrate mode: {mode}")


 def recreate_all_vs(
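For reference, a small sketch of consuming files2docs_in_thread directly, as folder2db now does; file and knowledge base names are placeholders. Results arrive in completion order, so progress can be reported per file:

```python
# Sketch of driving the new multithreaded file-to-Document conversion.
from server.knowledge_base.utils import KnowledgeFile, files2docs_in_thread

kb_files = [KnowledgeFile(filename=f, knowledge_base_name="samples")
            for f in ["README.md", "INSTALL.md"]]   # placeholder file names
for success, result in files2docs_in_thread(kb_files):
    if success:
        kb_name, filename, docs = result
        print(f"{filename}: {len(docs)} docs")
    else:
        print(f"failed: {result}")   # result is the raised exception
```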
@@ -114,30 +119,31 @@ def recreate_all_vs(
         folder2db(kb_name, "recreate_vs", vs_type, embed_mode, **kwargs)


-def prune_db_docs(kb_name: str):
+def prune_db_files(kb_name: str):
     '''
-    delete docs in database that not existed in local folder.
-    it is used to delete database docs after user deleted some doc files in file browser
+    delete files in database that do not exist in the local folder.
+    it is used to delete database files after the user deleted some doc files in the file browser
     '''
     kb = KBServiceFactory.get_service_by_name(kb_name)
     if kb.exists():
-        docs_in_db = kb.list_docs()
-        docs_in_folder = list_docs_from_folder(kb_name)
-        docs = list(set(docs_in_db) - set(docs_in_folder))
-        for doc in docs:
-            kb.delete_doc(KnowledgeFile(doc, kb_name))
-        return docs
+        files_in_db = kb.list_files()
+        files_in_folder = list_files_from_folder(kb_name)
+        files = list(set(files_in_db) - set(files_in_folder))
+        kb_files = file_to_kbfile(kb_name, files)
+        for kb_file in kb_files:
+            kb.delete_doc(kb_file)
+        return kb_files

-def prune_folder_docs(kb_name: str):
+def prune_folder_files(kb_name: str):
     '''
     delete doc files in the local folder that do not exist in the database.
     it is used to free local disk space by deleting unused doc files.
     '''
     kb = KBServiceFactory.get_service_by_name(kb_name)
     if kb.exists():
-        docs_in_db = kb.list_docs()
-        docs_in_folder = list_docs_from_folder(kb_name)
-        docs = list(set(docs_in_folder) - set(docs_in_db))
-        for doc in docs:
-            os.remove(get_file_path(kb_name, doc))
-        return docs
+        files_in_db = kb.list_files()
+        files_in_folder = list_files_from_folder(kb_name)
+        files = list(set(files_in_folder) - set(files_in_db))
+        for file in files:
+            os.remove(get_file_path(kb_name, file))
+        return files
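Usage sketch for the renamed housekeeping helpers; "samples" is a placeholder and the knowledge base is assumed to exist:

```python
# Drop DB rows for files deleted on disk, then delete on-disk files that the
# database no longer tracks.
from server.knowledge_base.migrate import prune_db_files, prune_folder_files

removed_rows = prune_db_files("samples")       # DB entries without a file on disk
removed_files = prune_folder_files("samples")  # files on disk without a DB entry
```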
@@ -16,7 +16,22 @@ import langchain.document_loaders
 from langchain.docstore.document import Document
 from pathlib import Path
 import json
-from typing import List, Union, Callable, Dict, Optional
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
+
+
+# make HuggingFaceEmbeddings hashable
+def _embeddings_hash(self):
+    if isinstance(self, HuggingFaceEmbeddings):
+        return hash(self.model_name)
+    elif isinstance(self, HuggingFaceBgeEmbeddings):
+        return hash(self.model_name)
+    elif isinstance(self, OpenAIEmbeddings):
+        return hash(self.model)
+
+HuggingFaceEmbeddings.__hash__ = _embeddings_hash
+OpenAIEmbeddings.__hash__ = _embeddings_hash
+HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash


 def validate_kb_name(knowledge_base_id: str) -> bool:
@@ -47,7 +62,7 @@ def list_kbs_from_folder():
             if os.path.isdir(os.path.join(KB_ROOT_PATH, f))]


-def list_docs_from_folder(kb_name: str):
+def list_files_from_folder(kb_name: str):
     doc_path = get_doc_path(kb_name)
     return [file for file in os.listdir(doc_path)
             if os.path.isfile(os.path.join(doc_path, file))]
@@ -175,8 +190,11 @@ class KnowledgeFile:
         # TODO: 增加依据文件格式匹配text_splitter
         self.text_splitter_name = None

-    def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE):
-        print(self.document_loader_name)
+    def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE, refresh: bool = False):
+        if self.docs is not None and not refresh:
+            return self.docs
+
+        print(f"{self.document_loader_name} used for {self.filepath}")
         try:
             document_loaders_module = importlib.import_module('langchain.document_loaders')
             DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
@@ -193,9 +211,9 @@ class KnowledgeFile:
         elif self.document_loader_name == "CustomJSONLoader":
             loader = DocumentLoader(self.filepath, text_content=False)
         elif self.document_loader_name == "UnstructuredMarkdownLoader":
-            loader = DocumentLoader(self.filepath, mode="elements")  # TODO: 需要在实践中测试`elements`是否优于`single`
+            loader = DocumentLoader(self.filepath, mode="elements")
         elif self.document_loader_name == "UnstructuredHTMLLoader":
-            loader = DocumentLoader(self.filepath, mode="elements")  # TODO: 需要在实践中测试`elements`是否优于`single`
+            loader = DocumentLoader(self.filepath, mode="elements")
         else:
             loader = DocumentLoader(self.filepath)

@@ -231,4 +249,63 @@ class KnowledgeFile:
         print(docs[0])
         if using_zh_title_enhance:
             docs = zh_title_enhance(docs)
+        self.docs = docs
         return docs
+
+    def get_mtime(self):
+        return os.path.getmtime(self.filepath)
+
+    def get_size(self):
+        return os.path.getsize(self.filepath)
+
+
+def run_in_thread_pool(
+        func: Callable,
+        params: List[Dict] = [],
+        pool: ThreadPoolExecutor = None,
+) -> Generator:
+    '''
+    在线程池中批量运行任务,并将运行结果以生成器的形式返回。
+    请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。
+    '''
+    tasks = []
+    if pool is None:
+        pool = ThreadPoolExecutor()
+
+    for kwargs in params:
+        thread = pool.submit(func, **kwargs)
+        tasks.append(thread)
+
+    for obj in as_completed(tasks):
+        yield obj.result()
+
+
+def files2docs_in_thread(
+        files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
+        pool: ThreadPoolExecutor = None,
+) -> Generator:
+    '''
+    利用多线程批量将文件转化成langchain Document.
+    生成器返回值为 (status, (kb_name, file_name, docs))
+    '''
+    def task(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Union[Tuple[str, str, List[Document]], Exception]]:
+        try:
+            return True, (file.kb_name, file.filename, file.file2text(**kwargs))
+        except Exception as e:
+            return False, e
+
+    kwargs_list = []
+    for i, file in enumerate(files):
+        kwargs = {}
+        if isinstance(file, tuple) and len(file) >= 2:
+            files[i] = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
+        elif isinstance(file, dict):
+            filename = file.pop("filename")
+            kb_name = file.pop("kb_name")
+            files[i] = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
+            kwargs = file
+        kwargs["file"] = files[i]
+        kwargs_list.append(kwargs)
+
+    for result in run_in_thread_pool(func=task, params=kwargs_list, pool=pool):
+        yield result
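A tiny usage sketch of run_in_thread_pool: the worker takes keyword arguments only, and results come back in completion order:

```python
# Generic helper demo: submit keyword-argument tasks, consume as they finish.
from server.knowledge_base.utils import run_in_thread_pool

def square(*, n: int) -> int:
    return n * n

for value in run_in_thread_pool(square, params=[{"n": i} for i in range(5)]):
    print(value)   # 0 1 4 9 16, in completion order
```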
@@ -112,7 +112,7 @@ def test_upload_doc(api="/knowledge_base/upload_doc"):
     assert data["msg"] == f"成功上传文件 {name}"


-def test_list_docs(api="/knowledge_base/list_docs"):
+def test_list_files(api="/knowledge_base/list_files"):
     url = api_base_url + api
     print("\n获取知识库中文件列表:")
     r = requests.get(url, params={"knowledge_base_name": kb})
@@ -4,7 +4,7 @@ from st_aggrid import AgGrid, JsCode
 from st_aggrid.grid_options_builder import GridOptionsBuilder
 import pandas as pd
 from server.knowledge_base.utils import get_file_path, LOADER_DICT
-from server.knowledge_base.kb_service.base import get_kb_details, get_kb_doc_details
+from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
 from typing import Literal, Dict, Tuple
 from configs.model_config import embedding_model_dict, kbs_config, EMBEDDING_MODEL, DEFAULT_VS_TYPE
 import os
@@ -152,7 +152,7 @@ def knowledge_base_page(api: ApiRequest):

         # 知识库详情
         # st.info("请选择文件,点击按钮进行操作。")
-        doc_details = pd.DataFrame(get_kb_doc_details(kb))
+        doc_details = pd.DataFrame(get_kb_file_details(kb))
         if not len(doc_details):
             st.info(f"知识库 `{kb}` 中暂无文件")
         else:
@@ -160,7 +160,7 @@ def knowledge_base_page(api: ApiRequest):
             st.info("知识库中包含源文件与向量库,请从下表中选择文件后操作")
             doc_details.drop(columns=["kb_name"], inplace=True)
             doc_details = doc_details[[
-                "No", "file_name", "document_loader", "text_splitter", "in_folder", "in_db",
+                "No", "file_name", "document_loader", "docs_count", "in_folder", "in_db",
             ]]
             # doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×")
             # doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×")
@@ -172,7 +172,8 @@ def knowledge_base_page(api: ApiRequest):
                 # ("file_ext", "文档类型"): {},
                 # ("file_version", "文档版本"): {},
                 ("document_loader", "文档加载器"): {},
-                ("text_splitter", "分词器"): {},
+                ("docs_count", "文档数量"): {},
+                # ("text_splitter", "分词器"): {},
                 # ("create_time", "创建时间"): {},
                 ("in_folder", "源文件"): {"cellRenderer": cell_renderer},
                 ("in_db", "向量库"): {"cellRenderer": cell_renderer},
@@ -494,18 +494,18 @@ class ApiRequest:
         no_remote_api: bool = None,
     ):
         '''
-        对应api.py/knowledge_base/list_docs接口
+        对应api.py/knowledge_base/list_files接口
         '''
         if no_remote_api is None:
             no_remote_api = self.no_remote_api

         if no_remote_api:
-            from server.knowledge_base.kb_doc_api import list_docs
-            response = run_async(list_docs(knowledge_base_name))
+            from server.knowledge_base.kb_doc_api import list_files
+            response = run_async(list_files(knowledge_base_name))
             return response.data
         else:
             response = self.get(
-                "/knowledge_base/list_docs",
+                "/knowledge_base/list_files",
                 params={"knowledge_base_name": knowledge_base_name}
             )
             data = self._check_httpx_json_response(response)