From 3acbf4d5d1f1a76d86c5302cc6486d261edfbee4 Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Mon, 28 Aug 2023 13:50:35 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E5=AD=97=E6=AE=B5=EF=BC=8C=E9=87=8D=E5=BB=BA=E7=9F=A5=E8=AF=86?= =?UTF-8?q?=E5=BA=93=E4=BD=BF=E7=94=A8=E5=A4=9A=E7=BA=BF=E7=A8=8B=20(#1280?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * close #1172: 给webui_page/utils添加一些log信息,方便定位错误 * 修复:重建知识库时页面未实时显示进度 * skip model_worker running when using online model api such as chatgpt * 修改知识库管理相关内容: 1.KnowledgeFileModel增加3个字段:file_mtime(文件修改时间),file_size(文件大小),custom_docs(是否使用自定义docs)。为后面比对上传文件做准备。 2.给所有String字段加上长度,防止mysql建表错误(pr#1177) 3.统一[faiss/milvus/pgvector]_kb_service.add_doc接口,使其支持自定义docs 4.为faiss_kb_service增加一些方法,便于调用 5.为KnowledgeFile增加一些方法,便于获取文件信息,缓存file2text的结果。 * 修复/chat/fastchat无法流式输出的问题 * 新增功能: 1、KnowledgeFileModel增加"docs_count"字段,代表该文件加载到向量库中的Document数量,并在WEBUI中进行展示。 2、重建知识库`python init_database.py --recreate-vs`支持多线程。 其它: 统一代码中知识库相关函数用词:file代表一个文件名称或路径,doc代表langchain加载后的Document。部分与API接口有关或含义重叠的函数暂未修改。 --------- Co-authored-by: liunux4odoo , hongkong9771 --- .gitignore | 2 + init_database.py | 13 +- server/api.py | 6 +- server/chat/openai_chat.py | 21 ++- server/db/models/knowledge_base_model.py | 6 +- server/db/models/knowledge_file_model.py | 16 +- .../repository/knowledge_file_repository.py | 37 +++- server/knowledge_base/kb_doc_api.py | 8 +- server/knowledge_base/kb_service/base.py | 45 +++-- .../kb_service/faiss_kb_service.py | 77 ++++---- .../kb_service/milvus_kb_service.py | 15 +- .../kb_service/pg_kb_service.py | 13 +- server/knowledge_base/migrate.py | 168 +++++++++--------- server/knowledge_base/utils.py | 89 +++++++++- tests/api/test_kb_api.py | 2 +- webui_pages/knowledge_base/knowledge_base.py | 9 +- webui_pages/utils.py | 8 +- 17 files changed, 316 insertions(+), 219 deletions(-) diff --git a/.gitignore b/.gitignore index b5918ee..c4178a9 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ logs __pycache__/ knowledge_base/ configs/*.py +.vscode/ +.pytest_cache/ diff --git a/init_database.py b/init_database.py index 7fc8494..42a18c5 100644 --- a/init_database.py +++ b/init_database.py @@ -1,8 +1,9 @@ -from server.knowledge_base.migrate import create_tables, folder2db, recreate_all_vs, list_kbs_from_folder +from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, recreate_all_vs, list_kbs_from_folder from configs.model_config import NLTK_DATA_PATH import nltk nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path from startup import dump_server_info +from datetime import datetime if __name__ == "__main__": @@ -25,13 +26,19 @@ if __name__ == "__main__": dump_server_info() - create_tables() - print("database talbes created") + start_time = datetime.now() if args.recreate_vs: + reset_tables() + print("database talbes reseted") print("recreating all vector stores") recreate_all_vs() else: + create_tables() + print("database talbes created") print("filling kb infos to database") for kb in list_kbs_from_folder(): folder2db(kb, "fill_info_only") + + end_time = datetime.now() + print(f"总计用时: {end_time-start_time}") diff --git a/server/api.py b/server/api.py index ecadd7c..fe5e156 100644 --- a/server/api.py +++ b/server/api.py @@ -14,7 +14,7 @@ from starlette.responses import RedirectResponse from server.chat import (chat, knowledge_base_chat, openai_chat, search_engine_chat) from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb -from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc, +from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store, search_docs, DocumentWithScore) from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline @@ -84,11 +84,11 @@ def create_app(): summary="删除知识库" )(delete_kb) - app.get("/knowledge_base/list_docs", + app.get("/knowledge_base/list_files", tags=["Knowledge Base Management"], response_model=ListResponse, summary="获取知识库内的文件列表" - )(list_docs) + )(list_files) app.post("/knowledge_base/search_docs", tags=["Knowledge Base Management"], diff --git a/server/chat/openai_chat.py b/server/chat/openai_chat.py index a7ad807..a799c62 100644 --- a/server/chat/openai_chat.py +++ b/server/chat/openai_chat.py @@ -29,23 +29,22 @@ async def openai_chat(msg: OpenAiChatMsgIn): print(f"{openai.api_base=}") print(msg) - async def get_response(msg): + def get_response(msg): data = msg.dict() - data["streaming"] = True - data.pop("stream") try: response = openai.ChatCompletion.create(**data) if msg.stream: - for chunk in response.choices[0].message.content: - print(chunk) - yield chunk + for data in response: + if choices := data.choices: + if chunk := choices[0].get("delta", {}).get("content"): + print(chunk, end="", flush=True) + yield chunk else: - answer = "" - for chunk in response.choices[0].message.content: - answer += chunk - print(answer) - yield(answer) + if response.choices: + answer = response.choices[0].message.content + print(answer) + yield(answer) except Exception as e: print(type(e)) logger.error(e) diff --git a/server/db/models/knowledge_base_model.py b/server/db/models/knowledge_base_model.py index 37abd4e..478bc1f 100644 --- a/server/db/models/knowledge_base_model.py +++ b/server/db/models/knowledge_base_model.py @@ -9,9 +9,9 @@ class KnowledgeBaseModel(Base): """ __tablename__ = 'knowledge_base' id = Column(Integer, primary_key=True, autoincrement=True, comment='知识库ID') - kb_name = Column(String, comment='知识库名称') - vs_type = Column(String, comment='嵌入模型类型') - embed_model = Column(String, comment='嵌入模型名称') + kb_name = Column(String(50), comment='知识库名称') + vs_type = Column(String(50), comment='向量库类型') + embed_model = Column(String(50), comment='嵌入模型名称') file_count = Column(Integer, default=0, comment='文件数量') create_time = Column(DateTime, default=func.now(), comment='创建时间') diff --git a/server/db/models/knowledge_file_model.py b/server/db/models/knowledge_file_model.py index 7fffdfb..3d885ea 100644 --- a/server/db/models/knowledge_file_model.py +++ b/server/db/models/knowledge_file_model.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, Integer, String, DateTime, func +from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, func from server.db.base import Base @@ -9,12 +9,16 @@ class KnowledgeFileModel(Base): """ __tablename__ = 'knowledge_file' id = Column(Integer, primary_key=True, autoincrement=True, comment='知识文件ID') - file_name = Column(String, comment='文件名') - file_ext = Column(String, comment='文件扩展名') - kb_name = Column(String, comment='所属知识库名称') - document_loader_name = Column(String, comment='文档加载器名称') - text_splitter_name = Column(String, comment='文本分割器名称') + file_name = Column(String(255), comment='文件名') + file_ext = Column(String(10), comment='文件扩展名') + kb_name = Column(String(50), comment='所属知识库名称') + document_loader_name = Column(String(50), comment='文档加载器名称') + text_splitter_name = Column(String(50), comment='文本分割器名称') file_version = Column(Integer, default=1, comment='文件版本') + file_mtime = Column(Float, default=0.0, comment="文件修改时间") + file_size = Column(Integer, default=0, comment="文件大小") + custom_docs = Column(Boolean, default=False, comment="是否自定义docs") + docs_count = Column(Integer, default=0, comment="切分文档数量") create_time = Column(DateTime, default=func.now(), comment='创建时间') def __repr__(self): diff --git a/server/db/repository/knowledge_file_repository.py b/server/db/repository/knowledge_file_repository.py index 404910f..6277ad6 100644 --- a/server/db/repository/knowledge_file_repository.py +++ b/server/db/repository/knowledge_file_repository.py @@ -5,20 +5,37 @@ from server.knowledge_base.utils import KnowledgeFile @with_session -def list_docs_from_db(session, kb_name): +def count_files_from_db(session, kb_name: str) -> int: + return session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).count() + + +@with_session +def list_files_from_db(session, kb_name): files = session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).all() docs = [f.file_name for f in files] return docs @with_session -def add_doc_to_db(session, kb_file: KnowledgeFile): +def add_file_to_db(session, + kb_file: KnowledgeFile, + docs_count: int = 0, + custom_docs: bool = False,): kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first() if kb: - # 如果已经存在该文件,则更新文件版本号 - existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename, - kb_name=kb_file.kb_name).first() + # 如果已经存在该文件,则更新文件信息与版本号 + existing_file: KnowledgeFileModel = (session.query(KnowledgeFileModel) + .filter_by(file_name=kb_file.filename, + kb_name=kb_file.kb_name) + .first()) + mtime = kb_file.get_mtime() + size = kb_file.get_size() + if existing_file: + existing_file.file_mtime = mtime + existing_file.file_size = size + existing_file.docs_count = docs_count + existing_file.custom_docs = custom_docs existing_file.file_version += 1 # 否则,添加新文件 else: @@ -28,6 +45,10 @@ def add_doc_to_db(session, kb_file: KnowledgeFile): kb_name=kb_file.kb_name, document_loader_name=kb_file.document_loader_name, text_splitter_name=kb_file.text_splitter_name or "SpacyTextSplitter", + file_mtime=mtime, + file_size=size, + docs_count = docs_count, + custom_docs=custom_docs, ) kb.file_count += 1 session.add(new_file) @@ -62,7 +83,7 @@ def delete_files_from_db(session, knowledge_base_name: str): @with_session -def doc_exists(session, kb_file: KnowledgeFile): +def file_exists_in_db(session, kb_file: KnowledgeFile): existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename, kb_name=kb_file.kb_name).first() return True if existing_file else False @@ -82,6 +103,10 @@ def get_file_detail(session, kb_name: str, filename: str) -> dict: "document_loader": file.document_loader_name, "text_splitter": file.text_splitter_name, "create_time": file.create_time, + "file_mtime": file.file_mtime, + "file_size": file.file_size, + "custom_docs": file.custom_docs, + "docs_count": file.docs_count, } else: return {} diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 8a40058..7ea5d27 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -3,7 +3,7 @@ import urllib from fastapi import File, Form, Body, Query, UploadFile from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD) from server.utils import BaseResponse, ListResponse -from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile +from server.knowledge_base.utils import validate_kb_name, list_files_from_folder, KnowledgeFile from fastapi.responses import StreamingResponse, FileResponse import json from server.knowledge_base.kb_service.base import KBServiceFactory @@ -29,7 +29,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=[" return data -async def list_docs( +async def list_files( knowledge_base_name: str ) -> ListResponse: if not validate_kb_name(knowledge_base_name): @@ -40,7 +40,7 @@ async def list_docs( if kb is None: return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[]) else: - all_doc_names = kb.list_docs() + all_doc_names = kb.list_files() return ListResponse(data=all_doc_names) @@ -190,7 +190,7 @@ async def recreate_vector_store( else: kb.create_kb() kb.clear_vs() - docs = list_docs_from_folder(knowledge_base_name) + docs = list_files_from_folder(knowledge_base_name) for i, doc in enumerate(docs): try: kb_file = KnowledgeFile(doc, knowledge_base_name) diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 6fa259d..e3a0743 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -13,15 +13,15 @@ from server.db.repository.knowledge_base_repository import ( load_kb_from_db, get_kb_detail, ) from server.db.repository.knowledge_file_repository import ( - add_doc_to_db, delete_file_from_db, delete_files_from_db, doc_exists, - list_docs_from_db, get_file_detail, delete_file_from_db + add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db, + count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db ) from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, EMBEDDING_DEVICE, EMBEDDING_MODEL) from server.knowledge_base.utils import ( get_kb_path, get_doc_path, load_embeddings, KnowledgeFile, - list_kbs_from_folder, list_docs_from_folder, + list_kbs_from_folder, list_files_from_folder, ) from typing import List, Union, Dict @@ -74,16 +74,22 @@ class KBService(ABC): status = delete_kb_from_db(self.kb_name) return status - def add_doc(self, kb_file: KnowledgeFile, **kwargs): + def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs): """ 向知识库添加文件 + 如果指定了docs,则不再将文本向量化,并将数据库对应条目标为custom_docs=True """ - docs = kb_file.file2text() + if docs: + custom_docs = True + else: + docs = kb_file.file2text() + custom_docs = False + if docs: self.delete_doc(kb_file) embeddings = self._load_embeddings() - self.do_add_doc(docs, embeddings, **kwargs) - status = add_doc_to_db(kb_file) + self.do_add_doc(docs, embeddings=embeddings, **kwargs) + status = add_file_to_db(kb_file, custom_docs=custom_docs, docs_count=len(docs)) else: status = False return status @@ -98,20 +104,24 @@ class KBService(ABC): os.remove(kb_file.filepath) return status - def update_doc(self, kb_file: KnowledgeFile, **kwargs): + def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs): """ 使用content中的文件更新向量库 + 如果指定了docs,则使用自定义docs,并将数据库对应条目标为custom_docs=True """ if os.path.exists(kb_file.filepath): self.delete_doc(kb_file, **kwargs) - return self.add_doc(kb_file, **kwargs) + return self.add_doc(kb_file, docs=docs, **kwargs) def exist_doc(self, file_name: str): - return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name, + return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name, filename=file_name)) - def list_docs(self): - return list_docs_from_db(self.kb_name) + def list_files(self): + return list_files_from_db(self.kb_name) + + def count_files(self): + return count_files_from_db(self.kb_name) def search_docs(self, query: str, @@ -264,25 +274,26 @@ def get_kb_details() -> List[Dict]: return data -def get_kb_doc_details(kb_name: str) -> List[Dict]: +def get_kb_file_details(kb_name: str) -> List[Dict]: kb = KBServiceFactory.get_service_by_name(kb_name) - docs_in_folder = list_docs_from_folder(kb_name) - docs_in_db = kb.list_docs() + files_in_folder = list_files_from_folder(kb_name) + files_in_db = kb.list_files() result = {} - for doc in docs_in_folder: + for doc in files_in_folder: result[doc] = { "kb_name": kb_name, "file_name": doc, "file_ext": os.path.splitext(doc)[-1], "file_version": 0, "document_loader": "", + "docs_count": 0, "text_splitter": "", "create_time": None, "in_folder": True, "in_db": False, } - for doc in docs_in_db: + for doc in files_in_db: doc_detail = get_file_detail(kb_name, doc) if doc_detail: doc_detail["in_db"] = True diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 40c67b5..936ffad 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -13,34 +13,16 @@ from functools import lru_cache from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile from langchain.vectorstores import FAISS from langchain.embeddings.base import Embeddings -from langchain.embeddings.huggingface import HuggingFaceEmbeddings,HuggingFaceBgeEmbeddings -from langchain.embeddings.openai import OpenAIEmbeddings from typing import List from langchain.docstore.document import Document from server.utils import torch_gc -# make HuggingFaceEmbeddings hashable -def _embeddings_hash(self): - if isinstance(self, HuggingFaceEmbeddings): - return hash(self.model_name) - elif isinstance(self, HuggingFaceBgeEmbeddings): - return hash(self.model_name) - elif isinstance(self, OpenAIEmbeddings): - return hash(self.model) - -HuggingFaceEmbeddings.__hash__ = _embeddings_hash -OpenAIEmbeddings.__hash__ = _embeddings_hash -HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash - -_VECTOR_STORE_TICKS = {} - - _VECTOR_STORE_TICKS = {} @lru_cache(CACHED_VS_NUM) -def load_vector_store( +def load_faiss_vector_store( knowledge_base_name: str, embed_model: str = EMBEDDING_MODEL, embed_device: str = EMBEDDING_DEVICE, @@ -86,22 +68,30 @@ class FaissKBService(KBService): def vs_type(self) -> str: return SupportedVSType.FAISS - @staticmethod - def get_vs_path(knowledge_base_name: str): - return os.path.join(FaissKBService.get_kb_path(knowledge_base_name), "vector_store") + def get_vs_path(self): + return os.path.join(self.get_kb_path(), "vector_store") - @staticmethod - def get_kb_path(knowledge_base_name: str): - return os.path.join(KB_ROOT_PATH, knowledge_base_name) + def get_kb_path(self): + return os.path.join(KB_ROOT_PATH, self.kb_name) + + def load_vector_store(self): + return load_faiss_vector_store( + knowledge_base_name=self.kb_name, + embed_model=self.embed_model, + tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0), + ) + + def refresh_vs_cache(self): + refresh_vs_cache(self.kb_name) def do_init(self): - self.kb_path = FaissKBService.get_kb_path(self.kb_name) - self.vs_path = FaissKBService.get_vs_path(self.kb_name) + self.kb_path = self.get_kb_path() + self.vs_path = self.get_vs_path() def do_create_kb(self): if not os.path.exists(self.vs_path): os.makedirs(self.vs_path) - load_vector_store(self.kb_name) + self.load_vector_store() def do_drop_kb(self): self.clear_vs() @@ -113,9 +103,7 @@ class FaissKBService(KBService): score_threshold: float = SCORE_THRESHOLD, embeddings: Embeddings = None, ) -> List[Document]: - search_index = load_vector_store(self.kb_name, - embeddings=embeddings, - tick=_VECTOR_STORE_TICKS.get(self.kb_name)) + search_index = self.load_vector_store() docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold) return docs @@ -124,22 +112,18 @@ class FaissKBService(KBService): embeddings: Embeddings, **kwargs, ): - vector_store = load_vector_store(self.kb_name, - embeddings=embeddings, - tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0)) + vector_store = self.load_vector_store() vector_store.add_documents(docs) torch_gc() if not kwargs.get("not_refresh_vs_cache"): vector_store.save_local(self.vs_path) - refresh_vs_cache(self.kb_name) + self.refresh_vs_cache() def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): embeddings = self._load_embeddings() - vector_store = load_vector_store(self.kb_name, - embeddings=embeddings, - tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0)) + vector_store = self.load_vector_store() ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] if len(ids) == 0: @@ -148,14 +132,14 @@ class FaissKBService(KBService): vector_store.delete(ids) if not kwargs.get("not_refresh_vs_cache"): vector_store.save_local(self.vs_path) - refresh_vs_cache(self.kb_name) + self.refresh_vs_cache() return True def do_clear_vs(self): shutil.rmtree(self.vs_path) os.makedirs(self.vs_path) - refresh_vs_cache(self.kb_name) + self.refresh_vs_cache() def exist_doc(self, file_name: str): if super().exist_doc(file_name): @@ -166,10 +150,11 @@ class FaissKBService(KBService): return "in_folder" else: return False -if __name__ == '__main__': - milvusService = FaissKBService("test") - milvusService.add_doc(KnowledgeFile("README.md", "test")) - milvusService.delete_doc(KnowledgeFile("README.md", "test")) - milvusService.do_drop_kb() - print(milvusService.search_docs("如何启动api服务")) \ No newline at end of file + +if __name__ == '__main__': + faissService = FaissKBService("test") + faissService.add_doc(KnowledgeFile("README.md", "test")) + faissService.delete_doc(KnowledgeFile("README.md", "test")) + faissService.do_drop_kb() + print(faissService.search_docs("如何启动api服务")) \ No newline at end of file diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index 296ae44..dc916c9 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -47,24 +47,15 @@ class MilvusKBService(KBService): self._load_milvus() def do_drop_kb(self): - self.milvus.col.drop() + if self.milvus.col: + self.milvus.col.drop() def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings): self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings)) return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k)) - def add_doc(self, kb_file: KnowledgeFile, **kwargs): - """ - 向知识库添加文件 - """ - docs = kb_file.file2text() - self.milvus.add_documents(docs) - from server.db.repository.knowledge_file_repository import add_doc_to_db - status = add_doc_to_db(kb_file) - return status - def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs): - pass + self.milvus.add_documents(docs) def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): filepath = kb_file.filepath.replace('\\', '\\\\') diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index 31cc908..906818f 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -47,23 +47,12 @@ class PGKBService(KBService): connect.commit() def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings): - # todo: support score threshold self._load_pg_vector(embeddings=embeddings) return score_threshold_process(score_threshold, top_k, self.pg_vector.similarity_search_with_score(query, top_k)) - def add_doc(self, kb_file: KnowledgeFile, **kwargs): - """ - 向知识库添加文件 - """ - docs = kb_file.file2text() - self.pg_vector.add_documents(docs) - from server.db.repository.knowledge_file_repository import add_doc_to_db - status = add_doc_to_db(kb_file) - return status - def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs): - pass + self.pg_vector.add_documents(docs) def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): with self.pg_vector.connect() as connect: diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index c96d386..af506e2 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -1,10 +1,17 @@ from configs.model_config import EMBEDDING_MODEL, DEFAULT_VS_TYPE -from server.knowledge_base.utils import get_file_path, list_kbs_from_folder, list_docs_from_folder, KnowledgeFile -from server.knowledge_base.kb_service.base import KBServiceFactory -from server.db.repository.knowledge_file_repository import add_doc_to_db +from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder, + list_files_from_folder, run_in_thread_pool, + files2docs_in_thread, + KnowledgeFile,) +from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType +from server.db.repository.knowledge_file_repository import add_file_to_db from server.db.base import Base, engine import os -from typing import Literal, Callable, Any +from concurrent.futures import ThreadPoolExecutor +from typing import Literal, Callable, Any, List + + +pool = ThreadPoolExecutor(os.cpu_count()) def create_tables(): @@ -16,13 +23,22 @@ def reset_tables(): create_tables() +def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]: + kb_files = [] + for file in files: + try: + kb_file = KnowledgeFile(filename=file, knowledge_base_name=kb_name) + kb_files.append(kb_file) + except Exception as e: + print(f"{e},已跳过") + return kb_files + + def folder2db( kb_name: str, mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"], vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE, embed_model: str = EMBEDDING_MODEL, - callback_before: Callable = None, - callback_after: Callable = None, ): ''' use existed files in local folder to populate database and/or vector store. @@ -36,70 +52,59 @@ def folder2db( kb.create_kb() if mode == "recreate_vs": + files_count = kb.count_files() + print(f"知识库 {kb_name} 中共有 {files_count} 个文档。\n即将清除向量库。") kb.clear_vs() - docs = list_docs_from_folder(kb_name) - for i, doc in enumerate(docs): - try: - kb_file = KnowledgeFile(doc, kb_name) - if callable(callback_before): - callback_before(kb_file, i, docs) - if i == len(docs) - 1: - not_refresh_vs_cache = False - else: - not_refresh_vs_cache = True - kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) - if callable(callback_after): - callback_after(kb_file, i, docs) - except Exception as e: - print(e) + files_count = kb.count_files() + print(f"清理后,知识库 {kb_name} 中共有 {files_count} 个文档。") + + kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name)) + for success, result in files2docs_in_thread(kb_files, pool=pool): + if success: + _, filename, docs = result + print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档") + kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name) + kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True) + else: + print(result) + + if kb.vs_type() == SupportedVSType.FAISS: + kb.refresh_vs_cache() elif mode == "fill_info_only": - docs = list_docs_from_folder(kb_name) - for i, doc in enumerate(docs): - try: - kb_file = KnowledgeFile(doc, kb_name) - if callable(callback_before): - callback_before(kb_file, i, docs) - add_doc_to_db(kb_file) - if callable(callback_after): - callback_after(kb_file, i, docs) - except Exception as e: - print(e) + files = list_files_from_folder(kb_name) + kb_files = file_to_kbfile(kb_name, files) + + for kb_file in kb_file: + add_file_to_db(kb_file) + print(f"已将 {kb_name}/{kb_file.filename} 添加到数据库") elif mode == "update_in_db": - docs = kb.list_docs() - for i, doc in enumerate(docs): - try: - kb_file = KnowledgeFile(doc, kb_name) - if callable(callback_before): - callback_before(kb_file, i, docs) - if i == len(docs) - 1: - not_refresh_vs_cache = False - else: - not_refresh_vs_cache = True - kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) - if callable(callback_after): - callback_after(kb_file, i, docs) - except Exception as e: - print(e) + files = kb.list_files() + kb_files = file_to_kbfile(kb_name, files) + + for kb_file in kb_files: + kb.update_doc(kb_file, not_refresh_vs_cache=True) + + if kb.vs_type() == SupportedVSType.FAISS: + kb.refresh_vs_cache() elif mode == "increament": - db_docs = kb.list_docs() - folder_docs = list_docs_from_folder(kb_name) - docs = list(set(folder_docs) - set(db_docs)) - for i, doc in enumerate(docs): - try: - kb_file = KnowledgeFile(doc, kb_name) - if callable(callback_before): - callback_before(kb_file, i, docs) - if i == len(docs) - 1: - not_refresh_vs_cache = False - else: - not_refresh_vs_cache = True - kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) - if callable(callback_after): - callback_after(kb_file, i, docs) - except Exception as e: - print(e) + db_files = kb.list_files() + folder_files = list_files_from_folder(kb_name) + files = list(set(folder_files) - set(db_files)) + kb_files = file_to_kbfile(kb_name, files) + + for success, result in files2docs_in_thread(kb_files, pool=pool): + if success: + _, filename, docs = result + print(f"正在将 {kb_name}/{filename} 添加到向量库") + kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name) + kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True) + else: + print(result) + + if kb.vs_type() == SupportedVSType.FAISS: + kb.refresh_vs_cache() else: - raise ValueError(f"unspported migrate mode: {mode}") + print(f"unspported migrate mode: {mode}") def recreate_all_vs( @@ -114,30 +119,31 @@ def recreate_all_vs( folder2db(kb_name, "recreate_vs", vs_type, embed_mode, **kwargs) -def prune_db_docs(kb_name: str): +def prune_db_files(kb_name: str): ''' - delete docs in database that not existed in local folder. - it is used to delete database docs after user deleted some doc files in file browser + delete files in database that not existed in local folder. + it is used to delete database files after user deleted some doc files in file browser ''' kb = KBServiceFactory.get_service_by_name(kb_name) if kb.exists(): - docs_in_db = kb.list_docs() - docs_in_folder = list_docs_from_folder(kb_name) - docs = list(set(docs_in_db) - set(docs_in_folder)) - for doc in docs: - kb.delete_doc(KnowledgeFile(doc, kb_name)) - return docs + files_in_db = kb.list_files() + files_in_folder = list_files_from_folder(kb_name) + files = list(set(files_in_db) - set(files_in_folder)) + kb_files = file_to_kbfile(kb_name, files) + for kb_file in kb_files: + kb.delete_doc(kb_file) + return kb_files -def prune_folder_docs(kb_name: str): +def prune_folder_files(kb_name: str): ''' delete doc files in local folder that not existed in database. is is used to free local disk space by delete unused doc files. ''' kb = KBServiceFactory.get_service_by_name(kb_name) if kb.exists(): - docs_in_db = kb.list_docs() - docs_in_folder = list_docs_from_folder(kb_name) - docs = list(set(docs_in_folder) - set(docs_in_db)) - for doc in docs: - os.remove(get_file_path(kb_name, doc)) - return docs + files_in_db = kb.list_files() + files_in_folder = list_files_from_folder(kb_name) + files = list(set(files_in_folder) - set(files_in_db)) + for file in files: + os.remove(get_file_path(kb_name, file)) + return files diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 1b25293..8cab754 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -16,7 +16,22 @@ import langchain.document_loaders from langchain.docstore.document import Document from pathlib import Path import json -from typing import List, Union, Callable, Dict, Optional +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Union, Callable, Dict, Optional, Tuple, Generator + + +# make HuggingFaceEmbeddings hashable +def _embeddings_hash(self): + if isinstance(self, HuggingFaceEmbeddings): + return hash(self.model_name) + elif isinstance(self, HuggingFaceBgeEmbeddings): + return hash(self.model_name) + elif isinstance(self, OpenAIEmbeddings): + return hash(self.model) + +HuggingFaceEmbeddings.__hash__ = _embeddings_hash +OpenAIEmbeddings.__hash__ = _embeddings_hash +HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash def validate_kb_name(knowledge_base_id: str) -> bool: @@ -47,7 +62,7 @@ def list_kbs_from_folder(): if os.path.isdir(os.path.join(KB_ROOT_PATH, f))] -def list_docs_from_folder(kb_name: str): +def list_files_from_folder(kb_name: str): doc_path = get_doc_path(kb_name) return [file for file in os.listdir(doc_path) if os.path.isfile(os.path.join(doc_path, file))] @@ -175,8 +190,11 @@ class KnowledgeFile: # TODO: 增加依据文件格式匹配text_splitter self.text_splitter_name = None - def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE): - print(self.document_loader_name) + def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE, refresh: bool = False): + if self.docs is not None and not refresh: + return self.docs + + print(f"{self.document_loader_name} used for {self.filepath}") try: document_loaders_module = importlib.import_module('langchain.document_loaders') DocumentLoader = getattr(document_loaders_module, self.document_loader_name) @@ -193,9 +211,9 @@ class KnowledgeFile: elif self.document_loader_name == "CustomJSONLoader": loader = DocumentLoader(self.filepath, text_content=False) elif self.document_loader_name == "UnstructuredMarkdownLoader": - loader = DocumentLoader(self.filepath, mode="elements") # TODO: 需要在实践中测试`elements`是否优于`single` + loader = DocumentLoader(self.filepath, mode="elements") elif self.document_loader_name == "UnstructuredHTMLLoader": - loader = DocumentLoader(self.filepath, mode="elements") # TODO: 需要在实践中测试`elements`是否优于`single` + loader = DocumentLoader(self.filepath, mode="elements") else: loader = DocumentLoader(self.filepath) @@ -231,4 +249,63 @@ class KnowledgeFile: print(docs[0]) if using_zh_title_enhance: docs = zh_title_enhance(docs) + self.docs = docs return docs + + def get_mtime(self): + return os.path.getmtime(self.filepath) + + def get_size(self): + return os.path.getsize(self.filepath) + + +def run_in_thread_pool( + func: Callable, + params: List[Dict] = [], + pool: ThreadPoolExecutor = None, +) -> Generator: + ''' + 在线程池中批量运行任务,并将运行结果以生成器的形式返回。 + 请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。 + ''' + tasks = [] + if pool is None: + pool = ThreadPoolExecutor() + + for kwargs in params: + thread = pool.submit(func, **kwargs) + tasks.append(thread) + + for obj in as_completed(tasks): + yield obj.result() + + +def files2docs_in_thread( + files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], + pool: ThreadPoolExecutor = None, +) -> Generator: + ''' + 利用多线程批量将文件转化成langchain Document. + 生成器返回值为{(kb_name, file_name): docs} + ''' + def task(*, file: KnowledgeFile, **kwargs) -> Dict[Tuple[str, str], List[Document]]: + try: + return True, (file.kb_name, file.filename, file.file2text(**kwargs)) + except Exception as e: + return False, e + + kwargs_list = [] + for i, file in enumerate(files): + kwargs = {} + if isinstance(file, tuple) and len(file) >= 2: + files[i] = KnowledgeFile(filename=file[0], knowledge_base_name=file[1]) + elif isinstance(file, dict): + filename = file.pop("filename") + kb_name = file.pop("kb_name") + files[i] = KnowledgeFile(filename=filename, knowledge_base_name=kb_name) + kwargs = file + kwargs["file"] = file + kwargs_list.append(kwargs) + + for result in run_in_thread_pool(func=task, params=kwargs_list, pool=pool): + yield result diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py index 5a8b97d..56142fa 100644 --- a/tests/api/test_kb_api.py +++ b/tests/api/test_kb_api.py @@ -112,7 +112,7 @@ def test_upload_doc(api="/knowledge_base/upload_doc"): assert data["msg"] == f"成功上传文件 {name}" -def test_list_docs(api="/knowledge_base/list_docs"): +def test_list_files(api="/knowledge_base/list_files"): url = api_base_url + api print("\n获取知识库中文件列表:") r = requests.get(url, params={"knowledge_base_name": kb}) diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index 7b6b35e..0889ca5 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -4,7 +4,7 @@ from st_aggrid import AgGrid, JsCode from st_aggrid.grid_options_builder import GridOptionsBuilder import pandas as pd from server.knowledge_base.utils import get_file_path, LOADER_DICT -from server.knowledge_base.kb_service.base import get_kb_details, get_kb_doc_details +from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details from typing import Literal, Dict, Tuple from configs.model_config import embedding_model_dict, kbs_config, EMBEDDING_MODEL, DEFAULT_VS_TYPE import os @@ -152,7 +152,7 @@ def knowledge_base_page(api: ApiRequest): # 知识库详情 # st.info("请选择文件,点击按钮进行操作。") - doc_details = pd.DataFrame(get_kb_doc_details(kb)) + doc_details = pd.DataFrame(get_kb_file_details(kb)) if not len(doc_details): st.info(f"知识库 `{kb}` 中暂无文件") else: @@ -160,7 +160,7 @@ def knowledge_base_page(api: ApiRequest): st.info("知识库中包含源文件与向量库,请从下表中选择文件后操作") doc_details.drop(columns=["kb_name"], inplace=True) doc_details = doc_details[[ - "No", "file_name", "document_loader", "text_splitter", "in_folder", "in_db", + "No", "file_name", "document_loader", "docs_count", "in_folder", "in_db", ]] # doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×") # doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×") @@ -172,7 +172,8 @@ def knowledge_base_page(api: ApiRequest): # ("file_ext", "文档类型"): {}, # ("file_version", "文档版本"): {}, ("document_loader", "文档加载器"): {}, - ("text_splitter", "分词器"): {}, + ("docs_count", "文档数量"): {}, + # ("text_splitter", "分词器"): {}, # ("create_time", "创建时间"): {}, ("in_folder", "源文件"): {"cellRenderer": cell_renderer}, ("in_db", "向量库"): {"cellRenderer": cell_renderer}, diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 042cb6b..827cb30 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -494,18 +494,18 @@ class ApiRequest: no_remote_api: bool = None, ): ''' - 对应api.py/knowledge_base/list_docs接口 + 对应api.py/knowledge_base/list_files接口 ''' if no_remote_api is None: no_remote_api = self.no_remote_api if no_remote_api: - from server.knowledge_base.kb_doc_api import list_docs - response = run_async(list_docs(knowledge_base_name)) + from server.knowledge_base.kb_doc_api import list_files + response = run_async(list_files(knowledge_base_name)) return response.data else: response = self.get( - "/knowledge_base/list_docs", + "/knowledge_base/list_files", params={"knowledge_base_name": knowledge_base_name} ) data = self._check_httpx_json_response(response)