From 0f46185cfb640135f34f954b58c28b30c3661ab3 Mon Sep 17 00:00:00 2001 From: zqt996 <67185303+zqt996@users.noreply.github.com> Date: Mon, 7 Aug 2023 16:56:57 +0800 Subject: [PATCH 01/10] =?UTF-8?q?=E6=B7=BB=E5=8A=A0Milvus=E5=BA=93=20(#101?= =?UTF-8?q?1)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- configs/model_config.py.example | 11 +- server/knowledge_base/kb_service/base.py | 12 +- .../kb_service/default_kb_service.py | 4 +- .../kb_service/faiss_kb_service.py | 5 +- .../kb_service/milvus_kb_service.py | 122 ++++++++++-------- .../knowledge_base/knowledge_base_factory.py | 8 +- 6 files changed, 90 insertions(+), 72 deletions(-) diff --git a/configs/model_config.py.example b/configs/model_config.py.example index ecbabac..b73ccd9 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -280,10 +280,13 @@ BING_SUBSCRIPTION_KEY = "" kbs_config = { "faiss": { - }, "milvus": { - "milvus_host": "192.168.50.128", - "milvus_port": 19530 + "host": "127.0.0.1", + "port": "19530", + "user": "", + "password": "", + "secure": False, } -} \ No newline at end of file +} +DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db") \ No newline at end of file diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 006ca81..c7fb24e 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -195,11 +195,9 @@ class KBService(ABC): def __init__(self, knowledge_base_name: str, - vector_store_type: str = "faiss", embed_model: str = EMBEDDING_MODEL, ): self.kb_name = knowledge_base_name - self.vs_type = vector_store_type self.embed_model = embed_model self.kb_path = get_kb_path(self.kb_name) self.doc_path = get_doc_path(self.kb_name) @@ -212,7 +210,7 @@ class KBService(ABC): if not os.path.exists(self.doc_path): os.makedirs(self.doc_path) self.do_create_kb() - status = add_kb_to_db(self.kb_name, self.vs_type, self.embed_model) + status = add_kb_to_db(self.kb_name, self.vs_type(), self.embed_model) return status def clear_vs(self): @@ -225,7 +223,7 @@ class KBService(ABC): """ 删除知识库 """ - self.do_remove_kb() + self.do_drop_kb() status = delete_kb_from_db(self.kb_name) return status @@ -245,7 +243,7 @@ class KBService(ABC): """ if os.path.exists(kb_file.filepath): os.remove(kb_file.filepath) - self.do_delete(kb_file) + self.do_delete_doc(kb_file) status = delete_file_from_db(kb_file) return status @@ -293,7 +291,7 @@ class KBService(ABC): pass @abstractmethod - def do_remove_kb(self): + def do_drop_kb(self): """ 删除知识库子类实自己逻辑 """ @@ -320,7 +318,7 @@ class KBService(ABC): pass @abstractmethod - def do_delete(self, + def do_delete_doc(self, kb_file: KnowledgeFile): """ 从知识库删除文档子类实自己逻辑 diff --git a/server/knowledge_base/kb_service/default_kb_service.py b/server/knowledge_base/kb_service/default_kb_service.py index 3a6e0a5..fb31859 100644 --- a/server/knowledge_base/kb_service/default_kb_service.py +++ b/server/knowledge_base/kb_service/default_kb_service.py @@ -11,7 +11,7 @@ class DefaultKBService(KBService): def do_init(self): pass - def do_remove_kbs(self): + def do_drop_kbs(self): pass def do_search(self): @@ -23,5 +23,5 @@ class DefaultKBService(KBService): def do_insert_one_knowledge(self): pass - def do_delete(self): + def do_delete_doc(self): pass diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 3141754..64ed9ed 100644 --- 
a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -76,7 +76,7 @@ class FaissKBService(KBService): if not os.path.exists(self.vs_path): os.makedirs(self.vs_path) - def do_remove_kb(self): + def do_drop_kb(self): shutil.rmtree(self.kb_path) def do_search(self, @@ -105,7 +105,7 @@ class FaissKBService(KBService): vector_store.save_local(self.vs_path) refresh_vs_cache(self.kb_name) - def do_delete(self, + def do_delete_doc(self, kb_file: KnowledgeFile): embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE) if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path): @@ -113,7 +113,6 @@ class FaissKBService(KBService): ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] if len(ids) == 0: return None - print(len(ids)) vector_store = delete_doc_from_faiss(vector_store, ids) vector_store.save_local(self.vs_path) refresh_vs_cache(self.kb_name) diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index 5d43633..7dcceb4 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -1,65 +1,79 @@ -from pymilvus import ( - connections, - utility, - FieldSchema, - CollectionSchema, - DataType, - Collection, -) +from typing import List -from server.knowledge_base.kb_service.base import KBService +from langchain.embeddings.base import Embeddings +from langchain.schema import Document +from langchain.vectorstores import Milvus + +from configs.model_config import EMBEDDING_DEVICE, kbs_config +from server.knowledge_base import KnowledgeFile +from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings, add_doc_to_db -def get_collection(milvus_name): - return Collection(milvus_name) +class MilvusKBService(KBService): + milvus: Milvus + @staticmethod + def get_collection(milvus_name): + from pymilvus import Collection + return Collection(milvus_name) -def search(milvus_name, content, limit=3): - search_params = { - "metric_type": "L2", - "params": {"nprobe": 10}, - } - c = get_collection(milvus_name) - return c.search(content, "embeddings", search_params, limit=limit, output_fields=["random"]) - - -class MilvusKBService(): - milvus_host: str - milvus_port: int - dim: int - - def __init__(self, knowledge_base_name: str, vector_store_type: str, milvus_host="localhost", milvus_port=19530, - dim=8): - - super().__init__(knowledge_base_name, vector_store_type) - self.milvus_host = milvus_host - self.milvus_port = milvus_port - self.dim = dim - - def connect(self): - connections.connect("default", host=self.milvus_host, port=self.milvus_port) - - def create_collection(self, milvus_name): - fields = [ - FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=False), - FieldSchema(name="content", dtype=DataType.STRING), - FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=self.dim) - ] - schema = CollectionSchema(fields) - collection = Collection(milvus_name, schema) - index = { - "index_type": "IVF_FLAT", + @staticmethod + def search(milvus_name, content, limit=3): + search_params = { "metric_type": "L2", - "params": {"nlist": 128}, + "params": {"nprobe": 10}, } - collection.create_index("embeddings", index) - collection.load() - return collection + c = MilvusKBService.get_collection(milvus_name) + return c.search(content, "embeddings", search_params, limit=limit, 
output_fields=["content"]) - def insert_collection(self, milvus_name, content=[]): - get_collection(milvus_name).insert(dataset) + def do_create_kb(self): + pass + + def vs_type(self) -> str: + return SupportedVSType.MILVUS + + def _load_milvus(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None): + _embeddings = embeddings + if _embeddings is None: + _embeddings = load_embeddings(self.embed_model, embedding_device) + self.milvus = Milvus(embedding_function=_embeddings, + collection_name=self.kb_name, connection_args=kbs_config.get("milvus")) + + def do_init(self): + self._load_milvus() + + def do_drop_kb(self): + self.milvus.col.drop() + + def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]: + self._load_milvus(embeddings=embeddings) + return self.milvus.similarity_search(query, top_k) + + def add_doc(self, kb_file: KnowledgeFile): + """ + 向知识库添加文件 + """ + docs = kb_file.file2text() + self.milvus.add_documents(docs) + status = add_doc_to_db(kb_file) + return status + + def do_add_doc(self, docs: List[Document], embeddings: Embeddings): + pass + + def do_delete_doc(self, kb_file: KnowledgeFile): + filepath = kb_file.filepath.replace('\\', '\\\\') + delete_list = [item.get("pk") for item in + self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])] + self.milvus.col.delete(expr=f'pk in {delete_list}') + + def do_clear_vs(self): + self.milvus.col.drop() if __name__ == '__main__': - milvusService = MilvusService(milvus_host='192.168.50.128') - milvusService.insert_collection(test,dataset) + milvusService = MilvusKBService("test") + # milvusService.add_doc(KnowledgeFile("test.pdf", "test")) + # milvusService.delete_doc(KnowledgeFile("test.pdf", "test")) + milvusService.do_drop_kb() + print(milvusService.search_docs("测试")) diff --git a/server/knowledge_base/knowledge_base_factory.py b/server/knowledge_base/knowledge_base_factory.py index 793074e..7d61748 100644 --- a/server/knowledge_base/knowledge_base_factory.py +++ b/server/knowledge_base/knowledge_base_factory.py @@ -1,5 +1,6 @@ from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db from server.knowledge_base.kb_service.default_kb_service import DefaultKBService +from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService class KBServiceFactory: @@ -11,6 +12,9 @@ class KBServiceFactory: if SupportedVSType.FAISS == vector_store_type: from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService return FaissKBService(kb_name) + elif SupportedVSType.MILVUS == vector_store_type: + from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService + return MilvusKBService(kb_name) elif SupportedVSType.DEFAULT == vector_store_type: return DefaultKBService(kb_name) @@ -28,9 +32,9 @@ class KBServiceFactory: if __name__ == '__main__': KBService = KBServiceFactory.get_service("test", SupportedVSType.FAISS) init_db() - KBService.create_kbs() + KBService.create_kb() KBService = KBServiceFactory.get_default() print(KBService.list_kbs()) KBService = KBServiceFactory.get_service_by_name("test") print(KBService.list_docs()) - KBService.drop_kbs() + KBService.drop_kb() From 44c713ef989ae640079a710d28cb48cd8a5685bd Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Mon, 7 Aug 2023 20:37:16 +0800 Subject: [PATCH 02/10] use KBServiceFactory to replace all the KnowledgeBase. make KBServiceFactory support embed_model parameter. rewrite api: recreate_vector_store. fix some bugs. 
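
A rough usage sketch of the factory after this change (illustrative only: the
method names follow the code further down in this patch, "samples" is just an
example knowledge base name, and the fallback assumes get_service_by_name
returns None for an unknown name):

    from configs.model_config import EMBEDDING_MODEL
    from server.knowledge_base.knowledge_base_factory import KBServiceFactory
    from server.knowledge_base.knowledge_file import KnowledgeFile

    # resolve an existing knowledge base; vs_type and embed_model are read back from the DB
    kb = KBServiceFactory.get_service_by_name("samples")
    if kb is None:
        # or build one explicitly with a vector store type and embedding model
        kb = KBServiceFactory.get_service("samples", "faiss", embed_model=EMBEDDING_MODEL)
        kb.create_kb()

    # add a file that already sits in the knowledge base content folder, then query it
    kb.add_doc(KnowledgeFile(filename="README.md", knowledge_base_name="samples"))
    print(kb.search_docs("测试"))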
--- requirements.txt | 3 ++ server/chat/knowledge_base_chat.py | 9 ++-- server/knowledge_base/__init__.py | 2 +- server/knowledge_base/kb_api.py | 18 ++++---- server/knowledge_base/kb_doc_api.py | 41 +++++++++++-------- server/knowledge_base/kb_service/base.py | 5 +++ .../kb_service/default_kb_service.py | 7 +++- .../kb_service/faiss_kb_service.py | 1 + .../knowledge_base/knowledge_base_factory.py | 22 ++++++---- 9 files changed, 68 insertions(+), 40 deletions(-) diff --git a/requirements.txt b/requirements.txt index 765cf9f..7ec558f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,6 @@ streamlit-option-menu streamlit-antd-components streamlit-chatbox>=1.1.6 httpx + +faiss-cpu +pymilvus==2.1.3 # requires milvus==2.1.3 diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 4eaf3fd..0a85df3 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -10,7 +10,8 @@ from langchain.callbacks import AsyncIteratorCallbackHandler from typing import AsyncIterable import asyncio from langchain.prompts import PromptTemplate -from server.knowledge_base.knowledge_base import KnowledgeBase +from server.knowledge_base.knowledge_base_factory import KBServiceFactory +from server.knowledge_base.kb_service.base import KBService import json @@ -18,12 +19,12 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp knowledge_base_name: str = Body(..., description="知识库名称", example="samples"), top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), ): - if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name) async def knowledge_base_chat_iterator(query: str, - kb: KnowledgeBase, + kb: KBService, top_k: int, ) -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() diff --git a/server/knowledge_base/__init__.py b/server/knowledge_base/__init__.py index 14b7898..b64e921 100644 --- a/server/knowledge_base/__init__.py +++ b/server/knowledge_base/__init__.py @@ -1,4 +1,4 @@ from .kb_api import list_kbs, create_kb, delete_kb from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store -from .knowledge_base import KnowledgeBase from .knowledge_file import KnowledgeFile +from .knowledge_base_factory import KBServiceFactory diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py index d5ef57b..9a92ea6 100644 --- a/server/knowledge_base/kb_api.py +++ b/server/knowledge_base/kb_api.py @@ -1,13 +1,14 @@ import urllib from server.utils import BaseResponse, ListResponse from server.knowledge_base.utils import validate_kb_name -from server.knowledge_base.knowledge_base import KnowledgeBase +from server.knowledge_base.knowledge_base_factory import KBServiceFactory +from server.knowledge_base.kb_service.base import list_kbs_from_db from configs.model_config import EMBEDDING_MODEL async def list_kbs(): # Get List of Knowledge Base - return ListResponse(data=KnowledgeBase.list_kbs()) + return ListResponse(data=list_kbs_from_db()) async def create_kb(knowledge_base_name: str, @@ -19,11 +20,10 @@ async def create_kb(knowledge_base_name: str, return BaseResponse(code=403, msg="Don't attack me") if knowledge_base_name is None or knowledge_base_name.strip() == "": return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称") - if 
KnowledgeBase.exists(knowledge_base_name): + + kb = KBServiceFactory.get_service(knowledge_base_name, "faiss") + if kb is not None: return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") - kb = KnowledgeBase(knowledge_base_name=knowledge_base_name, - vector_store_type=vector_store_type, - embed_model=embed_model) kb.create() return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") @@ -34,10 +34,12 @@ async def delete_kb(knowledge_base_name: str): return BaseResponse(code=403, msg="Don't attack me") knowledge_base_name = urllib.parse.unquote(knowledge_base_name) - if not KnowledgeBase.exists(knowledge_base_name): + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + + if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - status = KnowledgeBase.delete(knowledge_base_name) + status = kb.drop_kb() if status: return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}") else: diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 62de2d9..2ebe828 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -6,7 +6,9 @@ from server.knowledge_base.utils import (validate_kb_name) from fastapi.responses import StreamingResponse import json from server.knowledge_base.knowledge_file import KnowledgeFile -from server.knowledge_base.knowledge_base import KnowledgeBase +from server.knowledge_base.knowledge_base_factory import KBServiceFactory +from server.knowledge_base.kb_service.base import SupportedVSType, list_docs_from_folder +from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache async def list_docs(knowledge_base_name: str): @@ -14,10 +16,11 @@ async def list_docs(knowledge_base_name: str): return ListResponse(code=403, msg="Don't attack me", data=[]) knowledge_base_name = urllib.parse.unquote(knowledge_base_name) - if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[]) else: - all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs() + all_doc_names = kb.list_docs() return ListResponse(data=all_doc_names) @@ -28,11 +31,10 @@ async def upload_doc(file: UploadFile = File(description="上传文件"), if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") - if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name) - file_content = await file.read() # 读取上传文件的内容 kb_file = KnowledgeFile(filename=file.filename, @@ -63,10 +65,10 @@ async def delete_doc(knowledge_base_name: str, return BaseResponse(code=403, msg="Don't attack me") knowledge_base_name = urllib.parse.unquote(knowledge_base_name) - if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name) if not kb.exist_doc(doc_name): return BaseResponse(code=404, msg=f"未找到文件 {doc_name}") kb_file = KnowledgeFile(filename=doc_name, @@ -92,21 +94,26 @@ async def recreate_vector_store(knowledge_base_name: str): 
recreate vector store from the content. this is usefull when user can copy files to content folder directly instead of upload through network. ''' - kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name) + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - async def output(kb: KnowledgeBase): - kb.recreate_vs() + async def output(kb): + kb.clear_vs() print(f"start to recreate vector store of {kb.kb_name}") - docs = kb.list_docs() + docs = list_docs_from_folder(knowledge_base_name) + print(docs) for i, filename in enumerate(docs): + yield json.dumps({ + "total": len(docs), + "finished": i, + "doc": filename, + }) kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb.kb_name) print(f"processing {kb_file.filepath} to vector store.") kb.add_doc(kb_file) - yield json.dumps({ - "total": len(docs), - "finished": i + 1, - "doc": filename, - }) + if kb.vs_type == SupportedVSType.FAISS: + refresh_vs_cache(knowledge_base_name) return StreamingResponse(output(kb), media_type="text/event-stream") diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index c7fb24e..243ceda 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -14,6 +14,7 @@ import datetime from server.knowledge_base.utils import (get_kb_path, get_doc_path) from server.knowledge_base.knowledge_file import KnowledgeFile from typing import List +import os class SupportedVSType: @@ -125,6 +126,10 @@ def list_docs_from_db(kb_name): conn.close() return kbs +def list_docs_from_folder(kb_name: str): + doc_path = get_doc_path(kb_name) + return [file for file in os.listdir(doc_path) + if os.path.isfile(os.path.join(doc_path, file))] def add_doc_to_db(kb_file: KnowledgeFile): conn = sqlite3.connect(DB_ROOT_PATH) diff --git a/server/knowledge_base/kb_service/default_kb_service.py b/server/knowledge_base/kb_service/default_kb_service.py index fb31859..36ebfcf 100644 --- a/server/knowledge_base/kb_service/default_kb_service.py +++ b/server/knowledge_base/kb_service/default_kb_service.py @@ -5,13 +5,13 @@ class DefaultKBService(KBService): def vs_type(self) -> str: return "default" - def do_create_kbs(self): + def do_create_kb(self): pass def do_init(self): pass - def do_drop_kbs(self): + def do_drop_kb(self): pass def do_search(self): @@ -25,3 +25,6 @@ class DefaultKBService(KBService): def do_delete_doc(self): pass + + def kb_exists(self): + return False diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 64ed9ed..5e9e8b7 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -122,3 +122,4 @@ class FaissKBService(KBService): def do_clear_vs(self): shutil.rmtree(self.vs_path) + os.makedirs(self.vs_path) diff --git a/server/knowledge_base/knowledge_base_factory.py b/server/knowledge_base/knowledge_base_factory.py index 7d61748..0ae93c0 100644 --- a/server/knowledge_base/knowledge_base_factory.py +++ b/server/knowledge_base/knowledge_base_factory.py @@ -1,28 +1,34 @@ +from typing import Union from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db from server.knowledge_base.kb_service.default_kb_service import DefaultKBService from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService +from configs.model_config import 
EMBEDDING_MODEL class KBServiceFactory: @staticmethod def get_service(kb_name: str, - vector_store_type: SupportedVSType + vector_store_type: Union[str, SupportedVSType], + embed_model: str = EMBEDDING_MODEL, ) -> KBService: + if isinstance(vector_store_type, str): + vector_store_type = getattr(SupportedVSType, vector_store_type.upper()) if SupportedVSType.FAISS == vector_store_type: from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService - return FaissKBService(kb_name) - elif SupportedVSType.MILVUS == vector_store_type: - from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService - return MilvusKBService(kb_name) - elif SupportedVSType.DEFAULT == vector_store_type: + return FaissKBService(kb_name, embed_model=embed_model) + # todo: Milvus has different init params + # elif SupportedVSType.MILVUS == vector_store_type: + # from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService + # return MilvusKBService(kb_name,) + elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier. return DefaultKBService(kb_name) @staticmethod def get_service_by_name(kb_name: str ) -> KBService: - kb_name, vs_type = load_kb_from_db(kb_name) - return KBServiceFactory.get_service(kb_name, vs_type) + kb_name, vs_type, embed_model = load_kb_from_db(kb_name) + return KBServiceFactory.get_service(kb_name, vs_type, embed_model) @staticmethod def get_default(): From c981cdb0422ead18dbea683801f4ca9400df6a12 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Mon, 7 Aug 2023 20:41:32 +0800 Subject: [PATCH 03/10] update model_config.py.example --- configs/model_config.py.example | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/configs/model_config.py.example b/configs/model_config.py.example index b73ccd9..cd942d4 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -288,5 +288,4 @@ kbs_config = { "password": "", "secure": False, } -} -DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db") \ No newline at end of file +} \ No newline at end of file From 035a199c326c3171cd6878c2fdde8d086846283a Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Mon, 7 Aug 2023 20:48:06 +0800 Subject: [PATCH 04/10] update kb_server.base --- server/knowledge_base/kb_service/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 243ceda..89f3c3e 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -14,7 +14,6 @@ import datetime from server.knowledge_base.utils import (get_kb_path, get_doc_path) from server.knowledge_base.knowledge_file import KnowledgeFile from typing import List -import os class SupportedVSType: From 08493bffbbe215401b2df4edc672d4dc6a62700a Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Mon, 7 Aug 2023 21:00:55 +0800 Subject: [PATCH 05/10] add a basic knowledge base management ui. 
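
The page below drives everything through an ApiRequest client, and the same
calls work from a plain script. A minimal sketch, assuming ApiRequest is
exported by webui_pages.utils (the page itself only uses a wildcard import);
the method names, the base_url and the in-process no_remote_api=True mode are
taken from the page code:

    from webui_pages.utils import ApiRequest

    api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True)

    # create a knowledge base, list its documents, then rebuild its vector store
    print(api.create_knowledge_base("samples")["msg"])
    print(api.list_kb_docs("samples"))
    for d in api.recreate_vector_store("samples"):
        print(f'{d["finished"]}/{d["total"]} {d["doc"]}')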
--- .../kb_service/milvus_kb_service.py | 2 +- webui_pages/dialogue/dialogue.py | 2 +- webui_pages/knowledge_base/knowledge_base.py | 122 +++++++++++++++++- 3 files changed, 122 insertions(+), 4 deletions(-) diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index 7dcceb4..c346d65 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -5,7 +5,7 @@ from langchain.schema import Document from langchain.vectorstores import Milvus from configs.model_config import EMBEDDING_DEVICE, kbs_config -from server.knowledge_base import KnowledgeFile +from server.knowledge_base.knowledge_file import KnowledgeFile from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings, add_doc_to_db diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index b421fb6..e03a539 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -118,7 +118,7 @@ def dialogue_page(api: ApiRequest): chat_box.update_msg(text, 0, streaming=False) now = datetime.now() - cols[0].download_button( + export_btn.download_button( "Export", "".join(chat_box.export2md(cur_chat_name)), file_name=f"{now:%Y-%m-%d %H.%M}_{cur_chat_name}.md", diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index 119a4a2..f36bae4 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -1,6 +1,124 @@ +from pydoc import Helper import streamlit as st from webui_pages.utils import * +import streamlit_antd_components as sac +from st_aggrid import AgGrid +from st_aggrid.grid_options_builder import GridOptionsBuilder +import pandas as pd +from server.knowledge_base.utils import get_file_path +from streamlit_chatbox import * + + +SENTENCE_SIZE = 100 + def knowledge_base_page(api: ApiRequest): - st.write(123) - pass \ No newline at end of file + api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True) + chat_box = ChatBox(session_key="kb_messages") + + kb_list = api.list_knowledge_bases() + kb_docs = {} + for kb in kb_list: + kb_docs[kb] = api.list_kb_docs(kb) + + with st.sidebar: + def on_new_kb(): + if name := st.session_state.get("new_kb_name"): + if name in kb_list: + st.error(f"名为 {name} 的知识库已经存在!") + else: + ret = api.create_knowledge_base(name) + st.toast(ret["msg"]) + + def on_del_kb(): + if name := st.session_state.get("new_kb_name"): + if name in kb_list: + ret = api.delete_knowledge_base(name) + st.toast(ret["msg"]) + else: + st.error(f"名为 {name} 的知识库不存在!") + + cols = st.columns([2, 1, 1]) + new_kb_name = cols[0].text_input( + "新知识库名称", + placeholder="新知识库名称", + label_visibility="collapsed", + key="new_kb_name", + ) + cols[1].button("新建", on_click=on_new_kb, disabled=not bool(new_kb_name)) + cols[2].button("删除", on_click=on_del_kb, disabled=not bool(new_kb_name)) + + st.write("知识库:") + if kb_list: + try: + index = kb_list.index(st.session_state.get("cur_kb")) + except: + index = 0 + kb = sac.buttons( + kb_list, + index, + format_func=lambda x: f"{x} ({len(kb_docs[x])})", + ) + st.session_state["cur_kb"] = kb + sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True) + files = st.file_uploader("上传知识文件", + ["docx", "txt", "md", "csv", "xlsx", "pdf"], + accept_multiple_files=True, + key="files", + ) + if st.button( + "添加文件到知识库", + help="请先上传文件,再点击添加", + use_container_width=True, + 
disabled=len(files)==0, + ): + for f in files: + ret = api.upload_kb_doc(f, kb) + if ret["code"] == 200: + st.toast(ret["msg"], icon="✔") + else: + st.toast(ret["msg"], icon="❌") + st.session_state.files = [] + + if st.button( + "重建知识库", + help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。", + use_container_width=True, + disabled=True, + ): + progress = st.progress(0.0, "") + for d in api.recreate_vector_store(kb): + progress.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}") + + if kb_list: + # 知识库详情 + st.subheader(f"知识库 {kb} 详情") + df = pd.DataFrame([[i + 1, x] for i, x in enumerate(kb_docs[kb])], columns=["No", "文档名称"]) + gb = GridOptionsBuilder.from_dataframe(df) + gb.configure_column("No", width=50) + gb.configure_selection() + + cols = st.columns([1, 2]) + + with cols[0]: + docs = AgGrid(df, gb.build()) + + with cols[1]: + cols = st.columns(3) + selected_rows = docs.get("selected_rows", []) + + cols = st.columns([2, 3, 2]) + if selected_rows: + file_name = selected_rows[0]["文档名称"] + file_path = get_file_path(kb, file_name) + with open(file_path, "rb") as fp: + cols[0].download_button("下载选中文档", fp, file_name=file_name) + else: + cols[0].download_button("下载选中文档", "", disabled=True) + if cols[2].button("删除选中文档!", type="primary"): + for row in selected_rows: + ret = api.delete_kb_doc(kb, row["文档名称"]) + st.toast(ret["msg"]) + st.experimental_rerun() + + st.write("本文档包含以下知识条目:(待定内容)") From 823eb06c5d5a74987c5c903554b9ca1f1b0dd435 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Mon, 7 Aug 2023 22:57:13 +0800 Subject: [PATCH 06/10] =?UTF-8?q?BaseChatOpenAIChain,=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=9F=BA=E7=A1=80=E7=9A=84ChatOpenAI=E5=AF=B9=E8=AF=9D?= =?UTF-8?q?=E7=9A=84Chain=E6=8E=A5=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chat_openai_chain/chat_openai_chain.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/server/chat/chat_openai_chain/chat_openai_chain.py b/server/chat/chat_openai_chain/chat_openai_chain.py index 17f1866..a989ccd 100644 --- a/server/chat/chat_openai_chain/chat_openai_chain.py +++ b/server/chat/chat_openai_chain/chat_openai_chain.py @@ -4,10 +4,14 @@ from typing import Any, Dict, List, Optional from langchain.chains.base import Chain from langchain.schema import ( BaseMessage, - messages_from_dict, + AIMessage, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, LLMResult ) -from langchain.chat_models import ChatOpenAI +from langchain.chat_models import ChatOpenAI, _convert_dict_to_message from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, @@ -16,6 +20,18 @@ from langchain.callbacks.manager import ( from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto +def _convert_dict_to_message(_dict: dict) -> BaseMessage: + role = _dict["role"] + if role == "user": + return HumanMessage(content=_dict["content"]) + elif role == "assistant": + return AIMessage(content=_dict["content"]) + elif role == "system": + return SystemMessage(content=_dict["content"]) + else: + return ChatMessage(content=_dict["content"], role=role) + + def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]: """ 前端消息传输对象DTO转换为chat消息传输对象DTO @@ -25,7 +41,7 @@ def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[Bas messages = [] for message_datum in message_data: messages.append(message_datum.dict()) - return 
messages_from_dict(messages) + return _convert_dict_to_message(messages) class BaseChatOpenAIChain(Chain, ABC): From de8db40f4b2b3c0c1f4464e1412a6aef7145af27 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Mon, 7 Aug 2023 23:05:54 +0800 Subject: [PATCH 07/10] =?UTF-8?q?BaseChatOpenAIChain,=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=9F=BA=E7=A1=80=E7=9A=84ChatOpenAI=E5=AF=B9=E8=AF=9D?= =?UTF-8?q?=E7=9A=84Chain=E6=8E=A5=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/chat/chat_openai_chain/chat_openai_chain.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/chat/chat_openai_chain/chat_openai_chain.py b/server/chat/chat_openai_chain/chat_openai_chain.py index a989ccd..7757f98 100644 --- a/server/chat/chat_openai_chain/chat_openai_chain.py +++ b/server/chat/chat_openai_chain/chat_openai_chain.py @@ -11,7 +11,7 @@ from langchain.schema import ( SystemMessage, LLMResult ) -from langchain.chat_models import ChatOpenAI, _convert_dict_to_message +from langchain.chat_models import ChatOpenAI from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, From 40b6d1f17816f2736ade1589ed7611d45711bf75 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Tue, 8 Aug 2023 12:03:33 +0800 Subject: [PATCH 08/10] add python-magic-bin to requirements on windows. or document loader failed at `from unstructured.partion.auto import partion` on windows --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 7ec558f..257a62d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,7 @@ starlette~=0.27.0 numpy~=1.24.4 pydantic~=1.10.11 unstructured[all-docs] +python-magic-bin; sys_platform == 'win32' streamlit>=1.25.0 streamlit-option-menu From 0746272525516b4f27dfa5d5326cc3b716524196 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Tue, 8 Aug 2023 13:36:20 +0800 Subject: [PATCH 09/10] remove server/knowledge_base/knowledge_base.py --- server/knowledge_base/knowledge_base.py | 407 ------------------------ server/knowledge_base/utils.py | 15 +- 2 files changed, 12 insertions(+), 410 deletions(-) delete mode 100644 server/knowledge_base/knowledge_base.py diff --git a/server/knowledge_base/knowledge_base.py b/server/knowledge_base/knowledge_base.py deleted file mode 100644 index 10ccd9d..0000000 --- a/server/knowledge_base/knowledge_base.py +++ /dev/null @@ -1,407 +0,0 @@ -import os -import sqlite3 -import datetime -import shutil -from langchain.vectorstores import FAISS -from langchain.embeddings.huggingface import HuggingFaceEmbeddings -from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE, - DB_ROOT_PATH, VECTOR_SEARCH_TOP_K, CACHED_VS_NUM) -from server.utils import torch_gc -from functools import lru_cache -from server.knowledge_base.knowledge_file import KnowledgeFile -from typing import List -import numpy as np -from server.knowledge_base.utils import (get_kb_path, get_doc_path, get_vs_path) - -SUPPORTED_VS_TYPES = ["faiss", "milvus"] - -_VECTOR_STORE_TICKS = {} - -@lru_cache(1) -def load_embeddings(model: str, device: str): - embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], - model_kwargs={'device': device}) - return embeddings - - -@lru_cache(CACHED_VS_NUM) -def load_vector_store( - knowledge_base_name: str, - embedding_model: str, - embedding_device: str, - tick: int, # tick will be changed by upload_doc etc. and make cache refreshed. 
-): - print(f"loading vector store in '{knowledge_base_name}' with '{embedding_model}' embeddings.") - embeddings = load_embeddings(embedding_model, embedding_device) - vs_path = get_vs_path(knowledge_base_name) - search_index = FAISS.load_local(vs_path, embeddings) - return search_index - - -def refresh_vs_cache(kb_name: str): - """ - make vector store cache refreshed when next loading - """ - _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1 - - -def list_kbs_from_db(): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - c.execute('''CREATE TABLE if not exists knowledge_base - (id INTEGER PRIMARY KEY AUTOINCREMENT, - kb_name TEXT, - vs_type TEXT, - embed_model TEXT, - file_count INTEGER, - create_time DATETIME) ''') - c.execute(f'''SELECT kb_name - FROM knowledge_base - WHERE file_count>0 ''') - kbs = [i[0] for i in c.fetchall() if i] - conn.commit() - conn.close() - return kbs - - -def add_kb_to_db(kb_name, vs_type, embed_model): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - # Create table - c.execute('''CREATE TABLE if not exists knowledge_base - (id INTEGER PRIMARY KEY AUTOINCREMENT, - kb_name TEXT, - vs_type TEXT, - embed_model TEXT, - file_count INTEGER, - create_time DATETIME) ''') - # Insert a row of data - c.execute(f"""INSERT INTO knowledge_base - (kb_name, vs_type, embed_model, file_count, create_time) - VALUES - ('{kb_name}','{vs_type}','{embed_model}', - 0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""") - conn.commit() - conn.close() - - -def kb_exists(kb_name): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - c.execute('''CREATE TABLE if not exists knowledge_base - (id INTEGER PRIMARY KEY AUTOINCREMENT, - kb_name TEXT, - vs_type TEXT, - embed_model TEXT, - file_count INTEGER, - create_time DATETIME) ''') - c.execute(f'''SELECT COUNT(*) - FROM knowledge_base - WHERE kb_name="{kb_name}" ''') - status = True if c.fetchone()[0] else False - conn.commit() - conn.close() - return status - - -def load_kb_from_db(kb_name): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - c.execute('''CREATE TABLE if not exists knowledge_base - (id INTEGER PRIMARY KEY AUTOINCREMENT, - kb_name TEXT, - vs_type TEXT, - embed_model TEXT, - file_count INTEGER, - create_time DATETIME) ''') - c.execute(f'''SELECT kb_name, vs_type, embed_model - FROM knowledge_base - WHERE kb_name="{kb_name}" ''') - resp = c.fetchone() - if resp: - kb_name, vs_type, embed_model = resp - else: - kb_name, vs_type, embed_model = None, None, None - conn.commit() - conn.close() - return kb_name, vs_type, embed_model - - -def delete_kb_from_db(kb_name): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - # delete kb from table knowledge_base - c.execute('''CREATE TABLE if not exists knowledge_base - (id INTEGER PRIMARY KEY AUTOINCREMENT, - kb_name TEXT, - vs_type TEXT, - embed_model TEXT, - file_count INTEGER, - create_time DATETIME) ''') - c.execute(f'''DELETE - FROM knowledge_base - WHERE kb_name="{kb_name}" ''') - # delete files in kb from table knowledge_files - c.execute('''CREATE TABLE if not exists knowledge_files - (id INTEGER PRIMARY KEY AUTOINCREMENT, - file_name TEXT, - file_ext TEXT, - kb_name TEXT, - document_loader_name TEXT, - text_splitter_name TEXT, - file_version INTEGER, - create_time DATETIME) ''') - # Insert a row of data - c.execute(f"""DELETE - FROM knowledge_files - WHERE kb_name="{kb_name}" - """) - conn.commit() - conn.close() - return True - - -def list_docs_from_db(kb_name): - conn = 
sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - c.execute('''CREATE TABLE if not exists knowledge_files - (id INTEGER PRIMARY KEY AUTOINCREMENT, - file_name TEXT, - file_ext TEXT, - kb_name TEXT, - document_loader_name TEXT, - text_splitter_name TEXT, - file_version INTEGER, - create_time DATETIME) ''') - c.execute(f'''SELECT file_name - FROM knowledge_files - WHERE kb_name="{kb_name}" ''') - kbs = [i[0] for i in c.fetchall() if i] - conn.commit() - conn.close() - return kbs - - -def add_doc_to_db(kb_file: KnowledgeFile): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - # Create table - c.execute('''CREATE TABLE if not exists knowledge_files - (id INTEGER PRIMARY KEY AUTOINCREMENT, - file_name TEXT, - file_ext TEXT, - kb_name TEXT, - document_loader_name TEXT, - text_splitter_name TEXT, - file_version INTEGER, - create_time DATETIME) ''') - # Insert a row of data - # TODO: 同名文件添加至知识库时,file_version增加 - c.execute(f"""SELECT 1 FROM knowledge_files WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}" """) - record_exist = c.fetchone() - if record_exist is not None: - c.execute(f"""UPDATE knowledge_files - SET file_version = file_version + 1 - WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}" - """) - else: - c.execute(f"""INSERT INTO knowledge_files - (file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time) - VALUES - ('{kb_file.filename}','{kb_file.ext}','{kb_file.kb_name}', '{kb_file.document_loader_name}', - '{kb_file.text_splitter_name}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""") - conn.commit() - conn.close() - - -def delete_file_from_db(kb_file: KnowledgeFile): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - # delete files in kb from table knowledge_files - c.execute('''CREATE TABLE if not exists knowledge_files - (id INTEGER PRIMARY KEY AUTOINCREMENT, - file_name TEXT, - file_ext TEXT, - kb_name TEXT, - document_loader_name TEXT, - text_splitter_name TEXT, - file_version INTEGER, - create_time DATETIME) ''') - # Insert a row of data - c.execute(f"""DELETE - FROM knowledge_files - WHERE file_name="{kb_file.filename}" - AND kb_name="{kb_file.kb_name}" - """) - conn.commit() - conn.close() - return True - - -def doc_exists(kb_file: KnowledgeFile): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - c.execute('''CREATE TABLE if not exists knowledge_files - (id INTEGER PRIMARY KEY AUTOINCREMENT, - file_name TEXT, - file_ext TEXT, - kb_name TEXT, - document_loader_name TEXT, - text_splitter_name TEXT, - file_version INTEGER, - create_time DATETIME) ''') - c.execute(f'''SELECT COUNT(*) - FROM knowledge_files - WHERE file_name="{kb_file.filename}" - AND kb_name="{kb_file.kb_name}" ''') - status = True if c.fetchone()[0] else False - conn.commit() - conn.close() - return status - - -def delete_doc_from_faiss(vector_store: FAISS, ids: List[str]): - overlapping = set(ids).intersection(vector_store.index_to_docstore_id.values()) - if not overlapping: - raise ValueError("ids do not exist in the current object") - _reversed_index = {v: k for k, v in vector_store.index_to_docstore_id.items()} - index_to_delete = [_reversed_index[i] for i in ids] - vector_store.index.remove_ids(np.array(index_to_delete, dtype=np.int64)) - for _id in index_to_delete: - del vector_store.index_to_docstore_id[_id] - # Remove items from docstore. 
- overlapping2 = set(ids).intersection(vector_store.docstore._dict) - if not overlapping2: - raise ValueError(f"Tried to delete ids that does not exist: {ids}") - for _id in ids: - vector_store.docstore._dict.pop(_id) - return vector_store - - -class KnowledgeBase: - def __init__(self, - knowledge_base_name: str, - vector_store_type: str = "faiss", - embed_model: str = EMBEDDING_MODEL, - ): - self.kb_name = knowledge_base_name - if vector_store_type not in SUPPORTED_VS_TYPES: - raise ValueError(f"暂未支持向量库类型 {vector_store_type}") - self.vs_type = vector_store_type - if embed_model not in embedding_model_dict.keys(): - raise ValueError(f"暂未支持embedding模型 {embed_model}") - self.embed_model = embed_model - self.kb_path = get_kb_path(self.kb_name) - self.doc_path = get_doc_path(self.kb_name) - if self.vs_type in ["faiss"]: - self.vs_path = get_vs_path(self.kb_name) - elif self.vs_type in ["milvus"]: - pass - - def create(self): - if not os.path.exists(self.doc_path): - os.makedirs(self.doc_path) - if self.vs_type in ["faiss"]: - if not os.path.exists(self.vs_path): - os.makedirs(self.vs_path) - add_kb_to_db(self.kb_name, self.vs_type, self.embed_model) - elif self.vs_type in ["milvus"]: - # TODO: 创建milvus库 - pass - return True - - def recreate_vs(self): - if self.vs_type in ["faiss"]: - shutil.rmtree(self.vs_path) - self.create() - - def add_doc(self, kb_file: KnowledgeFile): - docs = kb_file.file2text() - embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE) - if self.vs_type in ["faiss"]: - if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path): - vector_store = FAISS.load_local(self.vs_path, embeddings) - vector_store.add_documents(docs) - torch_gc() - else: - if not os.path.exists(self.vs_path): - os.makedirs(self.vs_path) - vector_store = FAISS.from_documents(docs, embeddings) # docs 为Document列表 - torch_gc() - vector_store.save_local(self.vs_path) - add_doc_to_db(kb_file) - refresh_vs_cache(self.kb_name) - elif self.vs_type in ["milvus"]: - # TODO: 向milvus库中增加文件 - pass - - def delete_doc(self, kb_file: KnowledgeFile): - if os.path.exists(kb_file.filepath): - os.remove(kb_file.filepath) - if self.vs_type in ["faiss"]: - # TODO: 从FAISS向量库中删除文档 - embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE) - if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path): - vector_store = FAISS.load_local(self.vs_path, embeddings) - ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] - if len(ids) == 0: - return None - print(len(ids)) - vector_store = delete_doc_from_faiss(vector_store, ids) - vector_store.save_local(self.vs_path) - refresh_vs_cache(self.kb_name) - delete_file_from_db(kb_file) - return True - - def exist_doc(self, file_name: str): - return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name, - filename=file_name)) - - def list_docs(self): - return list_docs_from_db(self.kb_name) - - def search_docs(self, - query: str, - top_k: int = VECTOR_SEARCH_TOP_K, - embedding_device: str = EMBEDDING_DEVICE, ): - search_index = load_vector_store(self.kb_name, - self.embed_model, - embedding_device, - _VECTOR_STORE_TICKS.get(self.kb_name)) - docs = search_index.similarity_search(query, k=top_k) - return docs - - @classmethod - def exists(cls, - knowledge_base_name: str): - return kb_exists(knowledge_base_name) - - @classmethod - def load(cls, - knowledge_base_name: str): - kb_name, vs_type, embed_model = load_kb_from_db(knowledge_base_name) - return cls(kb_name, vs_type, embed_model) 
- - @classmethod - def delete(cls, - knowledge_base_name: str): - kb = cls.load(knowledge_base_name) - if kb.vs_type in ["faiss"]: - shutil.rmtree(kb.kb_path) - elif kb.vs_type in ["milvus"]: - # TODO: 删除milvus库 - pass - status = delete_kb_from_db(knowledge_base_name) - return status - - @classmethod - def list_kbs(cls): - return list_kbs_from_db() - - -if __name__ == "__main__": - # kb = KnowledgeBase("123", "faiss") - # kb.create() - kb = KnowledgeBase.load(knowledge_base_name="123") - kb.delete_doc(KnowledgeFile(knowledge_base_name="123", filename="README.md")) - print() diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 215ab36..fa6f132 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -1,5 +1,8 @@ -import os.path -from configs.model_config import KB_ROOT_PATH +import os +from langchain.embeddings.huggingface import HuggingFaceEmbeddings +from configs.model_config import (embedding_model_dict, KB_ROOT_PATH) +from functools import lru_cache + def validate_kb_name(knowledge_base_id: str) -> bool: # 检查是否包含预期外的字符或路径攻击关键字 @@ -17,4 +20,10 @@ def get_vs_path(knowledge_base_name: str): return os.path.join(get_kb_path(knowledge_base_name), "vector_store") def get_file_path(knowledge_base_name: str, doc_name: str): - return os.path.join(get_doc_path(knowledge_base_name), doc_name) \ No newline at end of file + return os.path.join(get_doc_path(knowledge_base_name), doc_name) + +@lru_cache(1) +def load_embeddings(model: str, device: str): + embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], + model_kwargs={'device': device}) + return embeddings From 360bd0a559af477c7a226d5941c44b539dc8409b Mon Sep 17 00:00:00 2001 From: hzg0601 Date: Tue, 8 Aug 2023 13:39:19 +0800 Subject: [PATCH 10/10] update llm_api_sh.py --- server/llm_api_sh.py | 148 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 147 insertions(+), 1 deletion(-) diff --git a/server/llm_api_sh.py b/server/llm_api_sh.py index 904ac71..3a8e880 100644 --- a/server/llm_api_sh.py +++ b/server/llm_api_sh.py @@ -7,11 +7,157 @@ import sys import os sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from configs.model_config import LOG_PATH,controller_args,worker_args,server_args,parser + import subprocess import re +import logging import argparse +LOG_PATH = "./logs/" +LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" +logger = logging.getLogger() +logger.setLevel(logging.INFO) +logging.basicConfig(format=LOG_FORMAT) + + +parser = argparse.ArgumentParser() +#------multi worker----------------- +parser.add_argument('--model-path-address', + default="THUDM/chatglm2-6b@localhost@20002", + nargs="+", + type=str, + help="model path, host, and port, formatted as model-path@host@path") +#---------------controller------------------------- + +parser.add_argument("--controller-host", type=str, default="localhost") +parser.add_argument("--controller-port", type=int, default=21001) +parser.add_argument( + "--dispatch-method", + type=str, + choices=["lottery", "shortest_queue"], + default="shortest_queue", +) +controller_args = ["controller-host","controller-port","dispatch-method"] + +#----------------------worker------------------------------------------ + +parser.add_argument("--worker-host", type=str, default="localhost") +parser.add_argument("--worker-port", type=int, default=21002) +# parser.add_argument("--worker-address", type=str, default="http://localhost:21002") +# parser.add_argument( +# 
"--controller-address", type=str, default="http://localhost:21001" +# ) +parser.add_argument( + "--model-path", + type=str, + default="lmsys/vicuna-7b-v1.3", + help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", +) +parser.add_argument( + "--revision", + type=str, + default="main", + help="Hugging Face Hub model revision identifier", +) +parser.add_argument( + "--device", + type=str, + choices=["cpu", "cuda", "mps", "xpu"], + default="cuda", + help="The device type", +) +parser.add_argument( + "--gpus", + type=str, + default="0", + help="A single GPU like 1 or multiple GPUs like 0,2", +) +parser.add_argument("--num-gpus", type=int, default=1) +parser.add_argument( + "--max-gpu-memory", + type=str, + help="The maximum memory per gpu. Use a string like '13Gib'", +) +parser.add_argument( + "--load-8bit", action="store_true", help="Use 8-bit quantization" +) +parser.add_argument( + "--cpu-offloading", + action="store_true", + help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU", +) +parser.add_argument( + "--gptq-ckpt", + type=str, + default=None, + help="Load quantized model. The path to the local GPTQ checkpoint.", +) +parser.add_argument( + "--gptq-wbits", + type=int, + default=16, + choices=[2, 3, 4, 8, 16], + help="#bits to use for quantization", +) +parser.add_argument( + "--gptq-groupsize", + type=int, + default=-1, + help="Groupsize to use for quantization; default uses full row.", +) +parser.add_argument( + "--gptq-act-order", + action="store_true", + help="Whether to apply the activation order GPTQ heuristic", +) +parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", +) +parser.add_argument( + "--limit-worker-concurrency", + type=int, + default=5, + help="Limit the model concurrency to prevent OOM.", +) +parser.add_argument("--stream-interval", type=int, default=2) +parser.add_argument("--no-register", action="store_true") + +worker_args = [ + "worker-host","worker-port", + "model-path","revision","device","gpus","num-gpus", + "max-gpu-memory","load-8bit","cpu-offloading", + "gptq-ckpt","gptq-wbits","gptq-groupsize", + "gptq-act-order","model-names","limit-worker-concurrency", + "stream-interval","no-register", + "controller-address" + ] +#-----------------openai server--------------------------- + +parser.add_argument("--server-host", type=str, default="localhost", help="host name") +parser.add_argument("--server-port", type=int, default=8001, help="port number") +parser.add_argument( + "--allow-credentials", action="store_true", help="allow credentials" +) +# parser.add_argument( +# "--allowed-origins", type=json.loads, default=["*"], help="allowed origins" +# ) +# parser.add_argument( +# "--allowed-methods", type=json.loads, default=["*"], help="allowed methods" +# ) +# parser.add_argument( +# "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" +# ) +parser.add_argument( + "--api-keys", + type=lambda s: s.split(","), + help="Optional list of comma separated API keys", +) +server_args = ["server-host","server-port","allow-credentials","api-keys", + "controller-address" + ] + args = parser.parse_args() # 必须要加http//:,否则InvalidSchema: No connection adapters were found args = argparse.Namespace(**vars(args),**{"controller-address":f"http://{args.controller_host}:{str(args.controller_port)}"})