From 1b70fb5f9b7de5d51cc13d5ef50d2dd46975d4fa Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Wed, 9 Aug 2023 22:57:36 +0800 Subject: [PATCH] update faiss_kb_service.py --- server/chat/knowledge_base_chat.py | 18 +++++++------ server/chat/search_engine_chat.py | 17 +++++++------ server/knowledge_base/kb_doc_api.py | 4 +-- .../kb_service/faiss_kb_service.py | 25 +++---------------- server/knowledge_base/utils.py | 2 +- 5 files changed, 26 insertions(+), 40 deletions(-) diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index b0ed287..6746750 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -20,19 +20,20 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), history: List[History] = Body([], - description="历史对话", - examples=[[ - {"role": "user", - "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", - "content": "虎头虎脑"}]] - ), + description="历史对话", + examples=[[ + {"role": "user", + "content": "我们来玩成语接龙,我先来,生龙活虎"}, + {"role": "assistant", + "content": "虎头虎脑"}]] + ), ): kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") history = [History(**h) if isinstance(h, dict) else h for h in history] + async def knowledge_base_chat_iterator(query: str, kb: KBService, top_k: int, @@ -69,7 +70,8 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp async for token in callback.aiter(): # Use server-sent-events to stream the response yield json.dumps({"answer": token, - "docs": source_documents}, ensure_ascii=False) + "docs": source_documents}, + ensure_ascii=False) await task return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history), diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index 56f1db0..95b2bd0 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -61,13 +61,13 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]), top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"), history: List[History] = Body([], - description="历史对话", - examples=[[ - {"role": "user", - "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", - "content": "虎头虎脑"}]] - ), + description="历史对话", + examples=[[ + {"role": "user", + "content": "我们来玩成语接龙,我先来,生龙活虎"}, + {"role": "assistant", + "content": "虎头虎脑"}]] + ), ): if search_engine_name not in SEARCH_ENGINES.keys(): return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}") @@ -109,7 +109,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl async for token in callback.aiter(): # Use server-sent-events to stream the response yield json.dumps({"answer": token, - "docs": source_documents}, ensure_ascii=False) + "docs": source_documents}, + ensure_ascii=False) await task return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history), diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 23f8ed0..bfdfd87 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -12,7 +12,7 @@ from typing import Union async def list_docs( - knowledge_base_name: str = Body(..., examples=["samples"]) + knowledge_base_name: str ): if not validate_kb_name(knowledge_base_name): return ListResponse(code=403, msg="Don't attack me", data=[]) @@ -61,7 +61,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"), async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), - doc_name: str = Body(..., examples=["file_name"]), + doc_name: str = Body(..., examples=["file_name.md"]), delete_content: bool = Body(False), ): if not validate_kb_name(knowledge_base_name): diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 87eca2b..eafa98e 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -17,9 +17,10 @@ import numpy as np # make HuggingFaceEmbeddings hashable def _embeddings_hash(self): return hash(self.model_name) -HuggingFaceEmbeddings.__hash__ = _embeddings_hash +HuggingFaceEmbeddings.__hash__ = _embeddings_hash + _VECTOR_STORE_TICKS = {} @@ -46,24 +47,6 @@ def refresh_vs_cache(kb_name: str): _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1 -def delete_doc_from_faiss(vector_store: FAISS, ids: List[str]): - overlapping = set(ids).intersection(vector_store.index_to_docstore_id.values()) - if not overlapping: - raise ValueError("ids do not exist in the current object") - _reversed_index = {v: k for k, v in vector_store.index_to_docstore_id.items()} - index_to_delete = [_reversed_index[i] for i in ids] - vector_store.index.remove_ids(np.array(index_to_delete, dtype=np.int64)) - for _id in index_to_delete: - del vector_store.index_to_docstore_id[_id] - # Remove items from docstore. - overlapping2 = set(ids).intersection(vector_store.docstore._dict) - if not overlapping2: - raise ValueError(f"Tried to delete ids that does not exist: {ids}") - for _id in ids: - vector_store.docstore._dict.pop(_id) - return vector_store - - class FaissKBService(KBService): vs_path: str kb_path: str @@ -119,14 +102,14 @@ class FaissKBService(KBService): refresh_vs_cache(self.kb_name) def do_delete_doc(self, - kb_file: KnowledgeFile): + kb_file: KnowledgeFile): embeddings = self._load_embeddings() if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path): vector_store = FAISS.load_local(self.vs_path, embeddings) ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] if len(ids) == 0: return None - vector_store = delete_doc_from_faiss(vector_store, ids) + vector_store.delete(ids) vector_store.save_local(self.vs_path) refresh_vs_cache(self.kb_name) return True diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 5db5097..6d52360 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -79,5 +79,5 @@ class KnowledgeFile: # TODO: 增加依据文件格式匹配text_splitter TextSplitter = getattr(sys.modules['langchain.text_splitter'], self.text_splitter_name) - text_splitter = TextSplitter(chunk_size=500, chunk_overlap=200) + text_splitter = TextSplitter(chunk_size=250, chunk_overlap=200) return loader.load_and_split(text_splitter)