增加数据库字段,重建知识库使用多线程 (#1280)

* close #1172: 给webui_page/utils添加一些log信息,方便定位错误

* 修复:重建知识库时页面未实时显示进度

* skip model_worker running when using online model api such as chatgpt

* 修改知识库管理相关内容:
1.KnowledgeFileModel增加3个字段:file_mtime(文件修改时间),file_size(文件大小),custom_docs(是否使用自定义docs)。为后面比对上传文件做准备。
2.给所有String字段加上长度,防止mysql建表错误(pr#1177)
3.统一[faiss/milvus/pgvector]_kb_service.add_doc接口,使其支持自定义docs
4.为faiss_kb_service增加一些方法,便于调用
5.为KnowledgeFile增加一些方法,便于获取文件信息,缓存file2text的结果。

* 修复/chat/fastchat无法流式输出的问题

* 新增功能:
1、KnowledgeFileModel增加"docs_count"字段,代表该文件加载到向量库中的Document数量,并在WEBUI中进行展示。
2、重建知识库`python init_database.py --recreate-vs`支持多线程。

其它:
统一代码中知识库相关函数用词:file代表一个文件名称或路径,doc代表langchain加载后的Document。部分与API接口有关或含义重叠的函数暂未修改。

---------

Co-authored-by: liunux4odoo <liunux@qq.com>, hongkong9771
This commit is contained in:
liunux4odoo 2023-08-28 13:50:35 +08:00 committed by GitHub
parent 89e3e9a691
commit 3acbf4d5d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 316 additions and 219 deletions

2
.gitignore vendored
View File

@ -5,3 +5,5 @@ logs
__pycache__/
knowledge_base/
configs/*.py
.vscode/
.pytest_cache/

View File

@ -1,8 +1,9 @@
from server.knowledge_base.migrate import create_tables, folder2db, recreate_all_vs, list_kbs_from_folder
from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, recreate_all_vs, list_kbs_from_folder
from configs.model_config import NLTK_DATA_PATH
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
from startup import dump_server_info
from datetime import datetime
if __name__ == "__main__":
@ -25,13 +26,19 @@ if __name__ == "__main__":
dump_server_info()
create_tables()
print("database talbes created")
start_time = datetime.now()
if args.recreate_vs:
reset_tables()
print("database talbes reseted")
print("recreating all vector stores")
recreate_all_vs()
else:
create_tables()
print("database talbes created")
print("filling kb infos to database")
for kb in list_kbs_from_folder():
folder2db(kb, "fill_info_only")
end_time = datetime.now()
print(f"总计用时: {end_time-start_time}")

View File

@ -14,7 +14,7 @@ from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
search_engine_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
update_doc, download_doc, recreate_vector_store,
search_docs, DocumentWithScore)
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
@ -84,11 +84,11 @@ def create_app():
summary="删除知识库"
)(delete_kb)
app.get("/knowledge_base/list_docs",
app.get("/knowledge_base/list_files",
tags=["Knowledge Base Management"],
response_model=ListResponse,
summary="获取知识库内的文件列表"
)(list_docs)
)(list_files)
app.post("/knowledge_base/search_docs",
tags=["Knowledge Base Management"],

View File

@ -29,23 +29,22 @@ async def openai_chat(msg: OpenAiChatMsgIn):
print(f"{openai.api_base=}")
print(msg)
async def get_response(msg):
def get_response(msg):
data = msg.dict()
data["streaming"] = True
data.pop("stream")
try:
response = openai.ChatCompletion.create(**data)
if msg.stream:
for chunk in response.choices[0].message.content:
print(chunk)
yield chunk
for data in response:
if choices := data.choices:
if chunk := choices[0].get("delta", {}).get("content"):
print(chunk, end="", flush=True)
yield chunk
else:
answer = ""
for chunk in response.choices[0].message.content:
answer += chunk
print(answer)
yield(answer)
if response.choices:
answer = response.choices[0].message.content
print(answer)
yield(answer)
except Exception as e:
print(type(e))
logger.error(e)

View File

@ -9,9 +9,9 @@ class KnowledgeBaseModel(Base):
"""
__tablename__ = 'knowledge_base'
id = Column(Integer, primary_key=True, autoincrement=True, comment='知识库ID')
kb_name = Column(String, comment='知识库名称')
vs_type = Column(String, comment='嵌入模型类型')
embed_model = Column(String, comment='嵌入模型名称')
kb_name = Column(String(50), comment='知识库名称')
vs_type = Column(String(50), comment='向量库类型')
embed_model = Column(String(50), comment='嵌入模型名称')
file_count = Column(Integer, default=0, comment='文件数量')
create_time = Column(DateTime, default=func.now(), comment='创建时间')

View File

@ -1,4 +1,4 @@
from sqlalchemy import Column, Integer, String, DateTime, func
from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, func
from server.db.base import Base
@ -9,12 +9,16 @@ class KnowledgeFileModel(Base):
"""
__tablename__ = 'knowledge_file'
id = Column(Integer, primary_key=True, autoincrement=True, comment='知识文件ID')
file_name = Column(String, comment='文件名')
file_ext = Column(String, comment='文件扩展名')
kb_name = Column(String, comment='所属知识库名称')
document_loader_name = Column(String, comment='文档加载器名称')
text_splitter_name = Column(String, comment='文本分割器名称')
file_name = Column(String(255), comment='文件名')
file_ext = Column(String(10), comment='文件扩展名')
kb_name = Column(String(50), comment='所属知识库名称')
document_loader_name = Column(String(50), comment='文档加载器名称')
text_splitter_name = Column(String(50), comment='文本分割器名称')
file_version = Column(Integer, default=1, comment='文件版本')
file_mtime = Column(Float, default=0.0, comment="文件修改时间")
file_size = Column(Integer, default=0, comment="文件大小")
custom_docs = Column(Boolean, default=False, comment="是否自定义docs")
docs_count = Column(Integer, default=0, comment="切分文档数量")
create_time = Column(DateTime, default=func.now(), comment='创建时间')
def __repr__(self):

View File

@ -5,20 +5,37 @@ from server.knowledge_base.utils import KnowledgeFile
@with_session
def list_docs_from_db(session, kb_name):
def count_files_from_db(session, kb_name: str) -> int:
return session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).count()
@with_session
def list_files_from_db(session, kb_name):
files = session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).all()
docs = [f.file_name for f in files]
return docs
@with_session
def add_doc_to_db(session, kb_file: KnowledgeFile):
def add_file_to_db(session,
kb_file: KnowledgeFile,
docs_count: int = 0,
custom_docs: bool = False,):
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
if kb:
# 如果已经存在该文件,则更新文件版本号
existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename,
kb_name=kb_file.kb_name).first()
# 如果已经存在该文件,则更新文件信息与版本号
existing_file: KnowledgeFileModel = (session.query(KnowledgeFileModel)
.filter_by(file_name=kb_file.filename,
kb_name=kb_file.kb_name)
.first())
mtime = kb_file.get_mtime()
size = kb_file.get_size()
if existing_file:
existing_file.file_mtime = mtime
existing_file.file_size = size
existing_file.docs_count = docs_count
existing_file.custom_docs = custom_docs
existing_file.file_version += 1
# 否则,添加新文件
else:
@ -28,6 +45,10 @@ def add_doc_to_db(session, kb_file: KnowledgeFile):
kb_name=kb_file.kb_name,
document_loader_name=kb_file.document_loader_name,
text_splitter_name=kb_file.text_splitter_name or "SpacyTextSplitter",
file_mtime=mtime,
file_size=size,
docs_count = docs_count,
custom_docs=custom_docs,
)
kb.file_count += 1
session.add(new_file)
@ -62,7 +83,7 @@ def delete_files_from_db(session, knowledge_base_name: str):
@with_session
def doc_exists(session, kb_file: KnowledgeFile):
def file_exists_in_db(session, kb_file: KnowledgeFile):
existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename,
kb_name=kb_file.kb_name).first()
return True if existing_file else False
@ -82,6 +103,10 @@ def get_file_detail(session, kb_name: str, filename: str) -> dict:
"document_loader": file.document_loader_name,
"text_splitter": file.text_splitter_name,
"create_time": file.create_time,
"file_mtime": file.file_mtime,
"file_size": file.file_size,
"custom_docs": file.custom_docs,
"docs_count": file.docs_count,
}
else:
return {}

View File

@ -3,7 +3,7 @@ import urllib
from fastapi import File, Form, Body, Query, UploadFile
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
from server.knowledge_base.utils import validate_kb_name, list_files_from_folder, KnowledgeFile
from fastapi.responses import StreamingResponse, FileResponse
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
@ -29,7 +29,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
return data
async def list_docs(
async def list_files(
knowledge_base_name: str
) -> ListResponse:
if not validate_kb_name(knowledge_base_name):
@ -40,7 +40,7 @@ async def list_docs(
if kb is None:
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
else:
all_doc_names = kb.list_docs()
all_doc_names = kb.list_files()
return ListResponse(data=all_doc_names)
@ -190,7 +190,7 @@ async def recreate_vector_store(
else:
kb.create_kb()
kb.clear_vs()
docs = list_docs_from_folder(knowledge_base_name)
docs = list_files_from_folder(knowledge_base_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, knowledge_base_name)

View File

@ -13,15 +13,15 @@ from server.db.repository.knowledge_base_repository import (
load_kb_from_db, get_kb_detail,
)
from server.db.repository.knowledge_file_repository import (
add_doc_to_db, delete_file_from_db, delete_files_from_db, doc_exists,
list_docs_from_db, get_file_detail, delete_file_from_db
add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db
)
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
EMBEDDING_DEVICE, EMBEDDING_MODEL)
from server.knowledge_base.utils import (
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
list_kbs_from_folder, list_docs_from_folder,
list_kbs_from_folder, list_files_from_folder,
)
from typing import List, Union, Dict
@ -74,16 +74,22 @@ class KBService(ABC):
status = delete_kb_from_db(self.kb_name)
return status
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
"""
向知识库添加文件
如果指定了docs则不再将文本向量化并将数据库对应条目标为custom_docs=True
"""
docs = kb_file.file2text()
if docs:
custom_docs = True
else:
docs = kb_file.file2text()
custom_docs = False
if docs:
self.delete_doc(kb_file)
embeddings = self._load_embeddings()
self.do_add_doc(docs, embeddings, **kwargs)
status = add_doc_to_db(kb_file)
self.do_add_doc(docs, embeddings=embeddings, **kwargs)
status = add_file_to_db(kb_file, custom_docs=custom_docs, docs_count=len(docs))
else:
status = False
return status
@ -98,20 +104,24 @@ class KBService(ABC):
os.remove(kb_file.filepath)
return status
def update_doc(self, kb_file: KnowledgeFile, **kwargs):
def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
"""
使用content中的文件更新向量库
如果指定了docs则使用自定义docs并将数据库对应条目标为custom_docs=True
"""
if os.path.exists(kb_file.filepath):
self.delete_doc(kb_file, **kwargs)
return self.add_doc(kb_file, **kwargs)
return self.add_doc(kb_file, docs=docs, **kwargs)
def exist_doc(self, file_name: str):
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name,
filename=file_name))
def list_docs(self):
return list_docs_from_db(self.kb_name)
def list_files(self):
return list_files_from_db(self.kb_name)
def count_files(self):
return count_files_from_db(self.kb_name)
def search_docs(self,
query: str,
@ -264,25 +274,26 @@ def get_kb_details() -> List[Dict]:
return data
def get_kb_doc_details(kb_name: str) -> List[Dict]:
def get_kb_file_details(kb_name: str) -> List[Dict]:
kb = KBServiceFactory.get_service_by_name(kb_name)
docs_in_folder = list_docs_from_folder(kb_name)
docs_in_db = kb.list_docs()
files_in_folder = list_files_from_folder(kb_name)
files_in_db = kb.list_files()
result = {}
for doc in docs_in_folder:
for doc in files_in_folder:
result[doc] = {
"kb_name": kb_name,
"file_name": doc,
"file_ext": os.path.splitext(doc)[-1],
"file_version": 0,
"document_loader": "",
"docs_count": 0,
"text_splitter": "",
"create_time": None,
"in_folder": True,
"in_db": False,
}
for doc in docs_in_db:
for doc in files_in_db:
doc_detail = get_file_detail(kb_name, doc)
if doc_detail:
doc_detail["in_db"] = True

View File

@ -13,34 +13,16 @@ from functools import lru_cache
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
from langchain.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings,HuggingFaceBgeEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from typing import List
from langchain.docstore.document import Document
from server.utils import torch_gc
# make HuggingFaceEmbeddings hashable
def _embeddings_hash(self):
if isinstance(self, HuggingFaceEmbeddings):
return hash(self.model_name)
elif isinstance(self, HuggingFaceBgeEmbeddings):
return hash(self.model_name)
elif isinstance(self, OpenAIEmbeddings):
return hash(self.model)
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
OpenAIEmbeddings.__hash__ = _embeddings_hash
HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
_VECTOR_STORE_TICKS = {}
_VECTOR_STORE_TICKS = {}
@lru_cache(CACHED_VS_NUM)
def load_vector_store(
def load_faiss_vector_store(
knowledge_base_name: str,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = EMBEDDING_DEVICE,
@ -86,22 +68,30 @@ class FaissKBService(KBService):
def vs_type(self) -> str:
return SupportedVSType.FAISS
@staticmethod
def get_vs_path(knowledge_base_name: str):
return os.path.join(FaissKBService.get_kb_path(knowledge_base_name), "vector_store")
def get_vs_path(self):
return os.path.join(self.get_kb_path(), "vector_store")
@staticmethod
def get_kb_path(knowledge_base_name: str):
return os.path.join(KB_ROOT_PATH, knowledge_base_name)
def get_kb_path(self):
return os.path.join(KB_ROOT_PATH, self.kb_name)
def load_vector_store(self):
return load_faiss_vector_store(
knowledge_base_name=self.kb_name,
embed_model=self.embed_model,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
)
def refresh_vs_cache(self):
refresh_vs_cache(self.kb_name)
def do_init(self):
self.kb_path = FaissKBService.get_kb_path(self.kb_name)
self.vs_path = FaissKBService.get_vs_path(self.kb_name)
self.kb_path = self.get_kb_path()
self.vs_path = self.get_vs_path()
def do_create_kb(self):
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
load_vector_store(self.kb_name)
self.load_vector_store()
def do_drop_kb(self):
self.clear_vs()
@ -113,9 +103,7 @@ class FaissKBService(KBService):
score_threshold: float = SCORE_THRESHOLD,
embeddings: Embeddings = None,
) -> List[Document]:
search_index = load_vector_store(self.kb_name,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name))
search_index = self.load_vector_store()
docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
return docs
@ -124,22 +112,18 @@ class FaissKBService(KBService):
embeddings: Embeddings,
**kwargs,
):
vector_store = load_vector_store(self.kb_name,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
vector_store = self.load_vector_store()
vector_store.add_documents(docs)
torch_gc()
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
self.refresh_vs_cache()
def do_delete_doc(self,
kb_file: KnowledgeFile,
**kwargs):
embeddings = self._load_embeddings()
vector_store = load_vector_store(self.kb_name,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
vector_store = self.load_vector_store()
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
if len(ids) == 0:
@ -148,14 +132,14 @@ class FaissKBService(KBService):
vector_store.delete(ids)
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
self.refresh_vs_cache()
return True
def do_clear_vs(self):
shutil.rmtree(self.vs_path)
os.makedirs(self.vs_path)
refresh_vs_cache(self.kb_name)
self.refresh_vs_cache()
def exist_doc(self, file_name: str):
if super().exist_doc(file_name):
@ -166,10 +150,11 @@ class FaissKBService(KBService):
return "in_folder"
else:
return False
if __name__ == '__main__':
milvusService = FaissKBService("test")
milvusService.add_doc(KnowledgeFile("README.md", "test"))
milvusService.delete_doc(KnowledgeFile("README.md", "test"))
milvusService.do_drop_kb()
print(milvusService.search_docs("如何启动api服务"))
if __name__ == '__main__':
faissService = FaissKBService("test")
faissService.add_doc(KnowledgeFile("README.md", "test"))
faissService.delete_doc(KnowledgeFile("README.md", "test"))
faissService.do_drop_kb()
print(faissService.search_docs("如何启动api服务"))

View File

@ -47,24 +47,15 @@ class MilvusKBService(KBService):
self._load_milvus()
def do_drop_kb(self):
self.milvus.col.drop()
if self.milvus.col:
self.milvus.col.drop()
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
向知识库添加文件
"""
docs = kb_file.file2text()
self.milvus.add_documents(docs)
from server.db.repository.knowledge_file_repository import add_doc_to_db
status = add_doc_to_db(kb_file)
return status
def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
pass
self.milvus.add_documents(docs)
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
filepath = kb_file.filepath.replace('\\', '\\\\')

View File

@ -47,23 +47,12 @@ class PGKBService(KBService):
connect.commit()
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
# todo: support score threshold
self._load_pg_vector(embeddings=embeddings)
return score_threshold_process(score_threshold, top_k,
self.pg_vector.similarity_search_with_score(query, top_k))
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
向知识库添加文件
"""
docs = kb_file.file2text()
self.pg_vector.add_documents(docs)
from server.db.repository.knowledge_file_repository import add_doc_to_db
status = add_doc_to_db(kb_file)
return status
def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
pass
self.pg_vector.add_documents(docs)
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
with self.pg_vector.connect() as connect:

View File

@ -1,10 +1,17 @@
from configs.model_config import EMBEDDING_MODEL, DEFAULT_VS_TYPE
from server.knowledge_base.utils import get_file_path, list_kbs_from_folder, list_docs_from_folder, KnowledgeFile
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_file_repository import add_doc_to_db
from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
list_files_from_folder, run_in_thread_pool,
files2docs_in_thread,
KnowledgeFile,)
from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType
from server.db.repository.knowledge_file_repository import add_file_to_db
from server.db.base import Base, engine
import os
from typing import Literal, Callable, Any
from concurrent.futures import ThreadPoolExecutor
from typing import Literal, Callable, Any, List
pool = ThreadPoolExecutor(os.cpu_count())
def create_tables():
@ -16,13 +23,22 @@ def reset_tables():
create_tables()
def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
kb_files = []
for file in files:
try:
kb_file = KnowledgeFile(filename=file, knowledge_base_name=kb_name)
kb_files.append(kb_file)
except Exception as e:
print(f"{e},已跳过")
return kb_files
def folder2db(
kb_name: str,
mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"],
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
callback_before: Callable = None,
callback_after: Callable = None,
):
'''
use existed files in local folder to populate database and/or vector store.
@ -36,70 +52,59 @@ def folder2db(
kb.create_kb()
if mode == "recreate_vs":
files_count = kb.count_files()
print(f"知识库 {kb_name} 中共有 {files_count} 个文档。\n即将清除向量库。")
kb.clear_vs()
docs = list_docs_from_folder(kb_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
print(e)
files_count = kb.count_files()
print(f"清理后,知识库 {kb_name} 中共有 {files_count} 个文档。")
kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
for success, result in files2docs_in_thread(kb_files, pool=pool):
if success:
_, filename, docs = result
print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档")
kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
else:
print(result)
if kb.vs_type() == SupportedVSType.FAISS:
kb.refresh_vs_cache()
elif mode == "fill_info_only":
docs = list_docs_from_folder(kb_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
add_doc_to_db(kb_file)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
print(e)
files = list_files_from_folder(kb_name)
kb_files = file_to_kbfile(kb_name, files)
for kb_file in kb_file:
add_file_to_db(kb_file)
print(f"已将 {kb_name}/{kb_file.filename} 添加到数据库")
elif mode == "update_in_db":
docs = kb.list_docs()
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
print(e)
files = kb.list_files()
kb_files = file_to_kbfile(kb_name, files)
for kb_file in kb_files:
kb.update_doc(kb_file, not_refresh_vs_cache=True)
if kb.vs_type() == SupportedVSType.FAISS:
kb.refresh_vs_cache()
elif mode == "increament":
db_docs = kb.list_docs()
folder_docs = list_docs_from_folder(kb_name)
docs = list(set(folder_docs) - set(db_docs))
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
print(e)
db_files = kb.list_files()
folder_files = list_files_from_folder(kb_name)
files = list(set(folder_files) - set(db_files))
kb_files = file_to_kbfile(kb_name, files)
for success, result in files2docs_in_thread(kb_files, pool=pool):
if success:
_, filename, docs = result
print(f"正在将 {kb_name}/{filename} 添加到向量库")
kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
else:
print(result)
if kb.vs_type() == SupportedVSType.FAISS:
kb.refresh_vs_cache()
else:
raise ValueError(f"unspported migrate mode: {mode}")
print(f"unspported migrate mode: {mode}")
def recreate_all_vs(
@ -114,30 +119,31 @@ def recreate_all_vs(
folder2db(kb_name, "recreate_vs", vs_type, embed_mode, **kwargs)
def prune_db_docs(kb_name: str):
def prune_db_files(kb_name: str):
'''
delete docs in database that not existed in local folder.
it is used to delete database docs after user deleted some doc files in file browser
delete files in database that not existed in local folder.
it is used to delete database files after user deleted some doc files in file browser
'''
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb.exists():
docs_in_db = kb.list_docs()
docs_in_folder = list_docs_from_folder(kb_name)
docs = list(set(docs_in_db) - set(docs_in_folder))
for doc in docs:
kb.delete_doc(KnowledgeFile(doc, kb_name))
return docs
files_in_db = kb.list_files()
files_in_folder = list_files_from_folder(kb_name)
files = list(set(files_in_db) - set(files_in_folder))
kb_files = file_to_kbfile(kb_name, files)
for kb_file in kb_files:
kb.delete_doc(kb_file)
return kb_files
def prune_folder_docs(kb_name: str):
def prune_folder_files(kb_name: str):
'''
delete doc files in local folder that not existed in database.
is is used to free local disk space by delete unused doc files.
'''
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb.exists():
docs_in_db = kb.list_docs()
docs_in_folder = list_docs_from_folder(kb_name)
docs = list(set(docs_in_folder) - set(docs_in_db))
for doc in docs:
os.remove(get_file_path(kb_name, doc))
return docs
files_in_db = kb.list_files()
files_in_folder = list_files_from_folder(kb_name)
files = list(set(files_in_folder) - set(files_in_db))
for file in files:
os.remove(get_file_path(kb_name, file))
return files

View File

@ -16,7 +16,22 @@ import langchain.document_loaders
from langchain.docstore.document import Document
from pathlib import Path
import json
from typing import List, Union, Callable, Dict, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
# make HuggingFaceEmbeddings hashable
def _embeddings_hash(self):
if isinstance(self, HuggingFaceEmbeddings):
return hash(self.model_name)
elif isinstance(self, HuggingFaceBgeEmbeddings):
return hash(self.model_name)
elif isinstance(self, OpenAIEmbeddings):
return hash(self.model)
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
OpenAIEmbeddings.__hash__ = _embeddings_hash
HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
def validate_kb_name(knowledge_base_id: str) -> bool:
@ -47,7 +62,7 @@ def list_kbs_from_folder():
if os.path.isdir(os.path.join(KB_ROOT_PATH, f))]
def list_docs_from_folder(kb_name: str):
def list_files_from_folder(kb_name: str):
doc_path = get_doc_path(kb_name)
return [file for file in os.listdir(doc_path)
if os.path.isfile(os.path.join(doc_path, file))]
@ -175,8 +190,11 @@ class KnowledgeFile:
# TODO: 增加依据文件格式匹配text_splitter
self.text_splitter_name = None
def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE):
print(self.document_loader_name)
def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE, refresh: bool = False):
if self.docs is not None and not refresh:
return self.docs
print(f"{self.document_loader_name} used for {self.filepath}")
try:
document_loaders_module = importlib.import_module('langchain.document_loaders')
DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
@ -193,9 +211,9 @@ class KnowledgeFile:
elif self.document_loader_name == "CustomJSONLoader":
loader = DocumentLoader(self.filepath, text_content=False)
elif self.document_loader_name == "UnstructuredMarkdownLoader":
loader = DocumentLoader(self.filepath, mode="elements") # TODO: 需要在实践中测试`elements`是否优于`single`
loader = DocumentLoader(self.filepath, mode="elements")
elif self.document_loader_name == "UnstructuredHTMLLoader":
loader = DocumentLoader(self.filepath, mode="elements") # TODO: 需要在实践中测试`elements`是否优于`single`
loader = DocumentLoader(self.filepath, mode="elements")
else:
loader = DocumentLoader(self.filepath)
@ -231,4 +249,63 @@ class KnowledgeFile:
print(docs[0])
if using_zh_title_enhance:
docs = zh_title_enhance(docs)
self.docs = docs
return docs
def get_mtime(self):
return os.path.getmtime(self.filepath)
def get_size(self):
return os.path.getsize(self.filepath)
def run_in_thread_pool(
func: Callable,
params: List[Dict] = [],
pool: ThreadPoolExecutor = None,
) -> Generator:
'''
在线程池中批量运行任务并将运行结果以生成器的形式返回
请确保任务中的所有操作是线程安全的任务函数请全部使用关键字参数
'''
tasks = []
if pool is None:
pool = ThreadPoolExecutor()
for kwargs in params:
thread = pool.submit(func, **kwargs)
tasks.append(thread)
for obj in as_completed(tasks):
yield obj.result()
def files2docs_in_thread(
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
pool: ThreadPoolExecutor = None,
) -> Generator:
'''
利用多线程批量将文件转化成langchain Document.
生成器返回值为{(kb_name, file_name): docs}
'''
def task(*, file: KnowledgeFile, **kwargs) -> Dict[Tuple[str, str], List[Document]]:
try:
return True, (file.kb_name, file.filename, file.file2text(**kwargs))
except Exception as e:
return False, e
kwargs_list = []
for i, file in enumerate(files):
kwargs = {}
if isinstance(file, tuple) and len(file) >= 2:
files[i] = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
elif isinstance(file, dict):
filename = file.pop("filename")
kb_name = file.pop("kb_name")
files[i] = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kwargs = file
kwargs["file"] = file
kwargs_list.append(kwargs)
for result in run_in_thread_pool(func=task, params=kwargs_list, pool=pool):
yield result

View File

@ -112,7 +112,7 @@ def test_upload_doc(api="/knowledge_base/upload_doc"):
assert data["msg"] == f"成功上传文件 {name}"
def test_list_docs(api="/knowledge_base/list_docs"):
def test_list_files(api="/knowledge_base/list_files"):
url = api_base_url + api
print("\n获取知识库中文件列表:")
r = requests.get(url, params={"knowledge_base_name": kb})

View File

@ -4,7 +4,7 @@ from st_aggrid import AgGrid, JsCode
from st_aggrid.grid_options_builder import GridOptionsBuilder
import pandas as pd
from server.knowledge_base.utils import get_file_path, LOADER_DICT
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_doc_details
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
from typing import Literal, Dict, Tuple
from configs.model_config import embedding_model_dict, kbs_config, EMBEDDING_MODEL, DEFAULT_VS_TYPE
import os
@ -152,7 +152,7 @@ def knowledge_base_page(api: ApiRequest):
# 知识库详情
# st.info("请选择文件,点击按钮进行操作。")
doc_details = pd.DataFrame(get_kb_doc_details(kb))
doc_details = pd.DataFrame(get_kb_file_details(kb))
if not len(doc_details):
st.info(f"知识库 `{kb}` 中暂无文件")
else:
@ -160,7 +160,7 @@ def knowledge_base_page(api: ApiRequest):
st.info("知识库中包含源文件与向量库,请从下表中选择文件后操作")
doc_details.drop(columns=["kb_name"], inplace=True)
doc_details = doc_details[[
"No", "file_name", "document_loader", "text_splitter", "in_folder", "in_db",
"No", "file_name", "document_loader", "docs_count", "in_folder", "in_db",
]]
# doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×")
# doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×")
@ -172,7 +172,8 @@ def knowledge_base_page(api: ApiRequest):
# ("file_ext", "文档类型"): {},
# ("file_version", "文档版本"): {},
("document_loader", "文档加载器"): {},
("text_splitter", "分词器"): {},
("docs_count", "文档数量"): {},
# ("text_splitter", "分词器"): {},
# ("create_time", "创建时间"): {},
("in_folder", "源文件"): {"cellRenderer": cell_renderer},
("in_db", "向量库"): {"cellRenderer": cell_renderer},

View File

@ -494,18 +494,18 @@ class ApiRequest:
no_remote_api: bool = None,
):
'''
对应api.py/knowledge_base/list_docs接口
对应api.py/knowledge_base/list_files接口
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api
if no_remote_api:
from server.knowledge_base.kb_doc_api import list_docs
response = run_async(list_docs(knowledge_base_name))
from server.knowledge_base.kb_doc_api import list_files
response = run_async(list_files(knowledge_base_name))
return response.data
else:
response = self.get(
"/knowledge_base/list_docs",
"/knowledge_base/list_files",
params={"knowledge_base_name": knowledge_base_name}
)
data = self._check_httpx_json_response(response)