update faiss_kb_service.py

imClumsyPanda 2023-08-09 22:57:36 +08:00
parent db29a2fea7
commit 1b70fb5f9b
5 changed files with 26 additions and 40 deletions

View File

@@ -20,19 +20,20 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                         knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                         top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                         history: List[History] = Body([],
-                                       description="历史对话",
-                                       examples=[[
-                                           {"role": "user",
-                                            "content": "我们来玩成语接龙,我先来,生龙活虎"},
-                                           {"role": "assistant",
-                                            "content": "虎头虎脑"}]]
-                                       ),
+                                                      description="历史对话",
+                                                      examples=[[
+                                                          {"role": "user",
+                                                           "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                                          {"role": "assistant",
+                                                           "content": "虎头虎脑"}]]
+                                                      ),
                         ):
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
     if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
+    history = [History(**h) if isinstance(h, dict) else h for h in history]

     async def knowledge_base_chat_iterator(query: str,
                                            kb: KBService,
                                            top_k: int,
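Aside from the re-indentation of the `history` parameter, the one substantive line added in this hunk is the `history` coercion, which lets callers pass plain JSON dicts as well as `History` objects. A minimal sketch of the idea, using a simplified stand-in for the project's `History` model:

```python
from pydantic import BaseModel

class History(BaseModel):
    # Simplified stand-in for the project's History model
    role: str
    content: str

# Over HTTP the history items arrive as plain dicts; accept both forms.
raw = [{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
       {"role": "assistant", "content": "虎头虎脑"}]
history = [History(**h) if isinstance(h, dict) else h for h in raw]
assert all(isinstance(h, History) for h in history)
```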
@@ -69,7 +70,8 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
         async for token in callback.aiter():
             # Use server-sent-events to stream the response
             yield json.dumps({"answer": token,
-                              "docs": source_documents}, ensure_ascii=False)
+                              "docs": source_documents},
+                             ensure_ascii=False)
         await task

     return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history),
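This hunk (and its twin in search_engine_chat below) only re-wraps the `json.dumps` call, but it sits at the heart of the streaming design: each token from the LLM callback is serialized to JSON and pushed to the client via `StreamingResponse`. A self-contained sketch of the pattern; the endpoint name and token list are hypothetical stand-ins:

```python
import json
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.get("/demo_stream")  # hypothetical endpoint for illustration
async def demo_stream():
    async def event_iterator():
        for token in ["虎", "头", "虎", "脑"]:  # stand-in for callback.aiter()
            # ensure_ascii=False keeps CJK text readable instead of \uXXXX escapes
            yield json.dumps({"answer": token, "docs": []}, ensure_ascii=False)
    return StreamingResponse(event_iterator(), media_type="text/event-stream")
```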

View File

@@ -61,13 +61,13 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
                        search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
                        top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
                        history: List[History] = Body([],
-                                      description="历史对话",
-                                      examples=[[
-                                          {"role": "user",
-                                           "content": "我们来玩成语接龙,我先来,生龙活虎"},
-                                          {"role": "assistant",
-                                           "content": "虎头虎脑"}]]
-                                      ),
+                                                     description="历史对话",
+                                                     examples=[[
+                                                         {"role": "user",
+                                                          "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                                         {"role": "assistant",
+                                                          "content": "虎头虎脑"}]]
+                                                     ),
                        ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -109,7 +109,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
         async for token in callback.aiter():
             # Use server-sent-events to stream the response
             yield json.dumps({"answer": token,
-                              "docs": source_documents}, ensure_ascii=False)
+                              "docs": source_documents},
+                             ensure_ascii=False)
         await task

     return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),

View File

@@ -12,7 +12,7 @@ from typing import Union
 async def list_docs(
-        knowledge_base_name: str = Body(..., examples=["samples"])
+        knowledge_base_name: str
 ):
     if not validate_kb_name(knowledge_base_name):
         return ListResponse(code=403, msg="Don't attack me", data=[])
@@ -61,7 +61,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
 async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
-                     doc_name: str = Body(..., examples=["file_name"]),
+                     doc_name: str = Body(..., examples=["file_name.md"]),
                      delete_content: bool = Body(False),
                      ):
     if not validate_kb_name(knowledge_base_name):

View File

@@ -17,9 +17,10 @@ import numpy as np
 # make HuggingFaceEmbeddings hashable
 def _embeddings_hash(self):
     return hash(self.model_name)
-HuggingFaceEmbeddings.__hash__ = _embeddings_hash
+
+HuggingFaceEmbeddings.__hash__ = _embeddings_hash

 _VECTOR_STORE_TICKS = {}
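Background on the monkeypatch: `HuggingFaceEmbeddings` is a pydantic model and is unhashable by default, so it cannot be passed to a cached function. Hashing by `model_name` makes two instances of the same model share a cache slot. A minimal sketch of why that matters, assuming an `lru_cache`-decorated loader (the name `load_vector_store` and the cache size are illustrative):

```python
from functools import lru_cache

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

def _embeddings_hash(self):
    return hash(self.model_name)

# Without this patch, lru_cache raises "TypeError: unhashable type".
HuggingFaceEmbeddings.__hash__ = _embeddings_hash

@lru_cache(maxsize=1)  # illustrative; the cache key is (vs_path, embeddings)
def load_vector_store(vs_path: str, embeddings: HuggingFaceEmbeddings) -> FAISS:
    return FAISS.load_local(vs_path, embeddings)
```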
@@ -46,24 +47,6 @@ def refresh_vs_cache(kb_name: str):
     _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1

-def delete_doc_from_faiss(vector_store: FAISS, ids: List[str]):
-    overlapping = set(ids).intersection(vector_store.index_to_docstore_id.values())
-    if not overlapping:
-        raise ValueError("ids do not exist in the current object")
-    _reversed_index = {v: k for k, v in vector_store.index_to_docstore_id.items()}
-    index_to_delete = [_reversed_index[i] for i in ids]
-    vector_store.index.remove_ids(np.array(index_to_delete, dtype=np.int64))
-    for _id in index_to_delete:
-        del vector_store.index_to_docstore_id[_id]
-    # Remove items from docstore.
-    overlapping2 = set(ids).intersection(vector_store.docstore._dict)
-    if not overlapping2:
-        raise ValueError(f"Tried to delete ids that does not exist: {ids}")
-    for _id in ids:
-        vector_store.docstore._dict.pop(_id)
-    return vector_store

 class FaissKBService(KBService):
     vs_path: str
     kb_path: str
@@ -119,14 +102,14 @@ class FaissKBService(KBService):
         refresh_vs_cache(self.kb_name)

     def do_delete_doc(self,
-                       kb_file: KnowledgeFile):
+                      kb_file: KnowledgeFile):
         embeddings = self._load_embeddings()
         if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
             vector_store = FAISS.load_local(self.vs_path, embeddings)
             ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
             if len(ids) == 0:
                 return None
-            vector_store = delete_doc_from_faiss(vector_store, ids)
+            vector_store.delete(ids)
             vector_store.save_local(self.vs_path)
         refresh_vs_cache(self.kb_name)
         return True
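The removed `delete_doc_from_faiss` helper duplicated what recent langchain releases ship as `FAISS.delete(ids)`, which drops the given ids from both the FAISS index and the docstore, so `do_delete_doc` can now call it directly. A sketch of the replacement in isolation, with `vs_path`, `embeddings`, and `filepath` as placeholders:

```python
from langchain.vectorstores import FAISS

def delete_file_chunks(vs_path: str, embeddings, filepath: str) -> bool:
    # Hypothetical helper mirroring do_delete_doc above.
    vector_store = FAISS.load_local(vs_path, embeddings)
    ids = [k for k, v in vector_store.docstore._dict.items()
           if v.metadata["source"] == filepath]
    if not ids:
        return False
    vector_store.delete(ids)          # built-in replacement for delete_doc_from_faiss
    vector_store.save_local(vs_path)  # persist the shrunken index
    return True
```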

View File

@@ -79,5 +79,5 @@ class KnowledgeFile:
         # TODO: 增加依据文件格式匹配text_splitter
         TextSplitter = getattr(sys.modules['langchain.text_splitter'], self.text_splitter_name)
-        text_splitter = TextSplitter(chunk_size=500, chunk_overlap=200)
+        text_splitter = TextSplitter(chunk_size=250, chunk_overlap=200)
         return loader.load_and_split(text_splitter)
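The splitter change halves `chunk_size` (500 → 250) while keeping `chunk_overlap=200`, so consecutive chunks now share up to 80% of their text: finer-grained retrieval hits at the cost of a noticeably larger, more redundant index. A quick way to see the effect, using `RecursiveCharacterTextSplitter` as a stand-in for whatever `self.text_splitter_name` resolves to:

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

old = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=200)
new = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=200)

sample = "这是一个示例段落。\n\n" * 150  # placeholder document text
# The smaller chunk_size yields several times more (mostly overlapping) chunks.
print(len(old.split_text(sample)), len(new.split_text(sample)))
```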