update faiss_kb_service.py
This commit is contained in:
parent db29a2fea7
commit 1b70fb5f9b
@@ -20,19 +20,20 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                         knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                         top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
-                        history: List[History] = Body([],
-                                                description="历史对话",
-                                                examples=[[
-                                                    {"role": "user",
-                                                     "content": "我们来玩成语接龙,我先来,生龙活虎"},
-                                                    {"role": "assistant",
-                                                     "content": "虎头虎脑"}]]
-                                                ),
+                        history: List[History] = Body([],
+                                                      description="历史对话",
+                                                      examples=[[
+                                                          {"role": "user",
+                                                           "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                                          {"role": "assistant",
+                                                           "content": "虎头虎脑"}]]
+                                                      ),
                         ):
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
     if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
 
+    history = [History(**h) if isinstance(h, dict) else h for h in history]
 
     async def knowledge_base_chat_iterator(query: str,
                                            kb: KBService,
                                            top_k: int,
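The functional change in this hunk is the added normalization line, which lets history arrive either as already-parsed History models or as raw dicts from the request body. A minimal sketch of that pattern, assuming History is a pydantic model with role and content fields (as the examples in the diff suggest):

from typing import List, Union

from pydantic import BaseModel


class History(BaseModel):
    role: str
    content: str


def normalize_history(history: List[Union[dict, History]]) -> List[History]:
    # Accept raw dicts from the request body as well as already-parsed models.
    return [History(**h) if isinstance(h, dict) else h for h in history]


msgs = normalize_history([{"role": "user", "content": "你好"},
                          History(role="assistant", content="虎头虎脑")])
assert all(isinstance(m, History) for m in msgs)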
@@ -69,7 +70,8 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
         async for token in callback.aiter():
             # Use server-sent-events to stream the response
             yield json.dumps({"answer": token,
-                              "docs": source_documents}, ensure_ascii=False)
+                              "docs": source_documents},
+                             ensure_ascii=False)
         await task
 
     return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history),
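For context, the handler streams one JSON object per generated token through FastAPI's StreamingResponse, and ensure_ascii=False keeps Chinese tokens readable on the wire instead of \uXXXX escapes. A self-contained sketch of the pattern (the route name and media type below are illustrative, not from the diff):

import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/demo_stream")
async def demo_stream():
    async def iterator():
        for token in ["虎", "头", "虎", "脑"]:
            # One JSON chunk per generated token, as in the diff.
            yield json.dumps({"answer": token, "docs": []}, ensure_ascii=False)

    return StreamingResponse(iterator(), media_type="text/event-stream")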
@@ -61,13 +61,13 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
                       search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
                       top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
-                      history: List[History] = Body([],
-                                              description="历史对话",
-                                              examples=[[
-                                                  {"role": "user",
-                                                   "content": "我们来玩成语接龙,我先来,生龙活虎"},
-                                                  {"role": "assistant",
-                                                   "content": "虎头虎脑"}]]
-                                              ),
+                      history: List[History] = Body([],
+                                                    description="历史对话",
+                                                    examples=[[
+                                                        {"role": "user",
+                                                         "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                                        {"role": "assistant",
+                                                         "content": "虎头虎脑"}]]
+                                                    ),
                       ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
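The guard above is a plain registry lookup. A tiny sketch with an illustrative stand-in registry; note that membership tests on a dict already consult its keys, so the .keys() call in the diff is redundant but harmless:

SEARCH_ENGINES = {"duckduckgo": object()}  # illustrative stand-in registry


def engine_supported(name: str) -> bool:
    # Equivalent to "name not in SEARCH_ENGINES.keys()" in the diff, negated.
    return name in SEARCH_ENGINES


assert engine_supported("duckduckgo")
assert not engine_supported("bing")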
@@ -109,7 +109,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
         async for token in callback.aiter():
             # Use server-sent-events to stream the response
             yield json.dumps({"answer": token,
-                              "docs": source_documents}, ensure_ascii=False)
+                              "docs": source_documents},
+                             ensure_ascii=False)
         await task
 
     return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
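The callback.aiter() / await task pair in both chat handlers follows a producer-consumer shape: a background task feeds tokens to an async iterator while the generator drains it, and the final await surfaces any exception the producer raised. A simplified, queue-based sketch of that shape (langchain's AsyncIteratorCallbackHandler works on the same principle; this stand-in is not its implementation):

import asyncio


async def main():
    queue: asyncio.Queue = asyncio.Queue()

    async def producer():
        for token in ("虎", "头", "虎", "脑"):
            await queue.put(token)
        await queue.put(None)  # sentinel: generation finished

    task = asyncio.create_task(producer())
    while (token := await queue.get()) is not None:
        print(token)
    await task  # re-raise anything the producer failed with


asyncio.run(main())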
@@ -12,7 +12,7 @@ from typing import Union
 
 
 async def list_docs(
-        knowledge_base_name: str = Body(..., examples=["samples"])
+        knowledge_base_name: str
 ):
     if not validate_kb_name(knowledge_base_name):
         return ListResponse(code=403, msg="Don't attack me", data=[])
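Dropping Body(...) changes where FastAPI reads the value from: a bare str parameter on a GET route becomes a required query parameter, while Body(...) pulls it from the JSON request body. A minimal sketch (the routes below are illustrative):

from fastapi import Body, FastAPI

app = FastAPI()


@app.get("/list_docs")
async def list_docs(knowledge_base_name: str):
    # Read from the query string: /list_docs?knowledge_base_name=samples
    return {"kb": knowledge_base_name}


@app.post("/delete_doc")
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"])):
    # Read from the JSON request body instead.
    return {"kb": knowledge_base_name}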
@@ -61,7 +61,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
 
 
 async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
-                     doc_name: str = Body(..., examples=["file_name"]),
+                     doc_name: str = Body(..., examples=["file_name.md"]),
                      delete_content: bool = Body(False),
                      ):
     if not validate_kb_name(knowledge_base_name):
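validate_kb_name itself is not part of this diff; judging from the 403 "Don't attack me" response, it guards against path traversal in user-supplied knowledge-base names. An illustrative stand-in, not the project's actual implementation:

def validate_kb_name(knowledge_base_id: str) -> bool:
    # Reject names that could escape the knowledge-base directory.
    if "../" in knowledge_base_id:
        return False
    return True


assert validate_kb_name("samples")
assert not validate_kb_name("../../etc/passwd")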
@@ -17,9 +17,10 @@ import numpy as np
 # make HuggingFaceEmbeddings hashable
 def _embeddings_hash(self):
     return hash(self.model_name)
-HuggingFaceEmbeddings.__hash__ = _embeddings_hash
 
 
+HuggingFaceEmbeddings.__hash__ = _embeddings_hash
+
 _VECTOR_STORE_TICKS = {}
 
 
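Monkey-patching __hash__ this way makes embeddings objects usable as arguments to functools.lru_cache (and as dict keys), which is what allows loaded vector stores to be cached per embedding model. A self-contained sketch with a stand-in class (the real code patches HuggingFaceEmbeddings):

from functools import lru_cache


class StubEmbeddings:  # stand-in for HuggingFaceEmbeddings
    def __init__(self, model_name: str):
        self.model_name = model_name


def _embeddings_hash(self):
    return hash(self.model_name)


StubEmbeddings.__hash__ = _embeddings_hash


@lru_cache(maxsize=1)
def load_vector_store(vs_path: str, embeddings: StubEmbeddings):
    return object()  # stand-in for FAISS.load_local(vs_path, embeddings)


e = StubEmbeddings("text2vec")
assert load_vector_store("/tmp/vs", e) is load_vector_store("/tmp/vs", e)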
@@ -46,24 +47,6 @@ def refresh_vs_cache(kb_name: str):
     _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
 
 
-def delete_doc_from_faiss(vector_store: FAISS, ids: List[str]):
-    overlapping = set(ids).intersection(vector_store.index_to_docstore_id.values())
-    if not overlapping:
-        raise ValueError("ids do not exist in the current object")
-    _reversed_index = {v: k for k, v in vector_store.index_to_docstore_id.items()}
-    index_to_delete = [_reversed_index[i] for i in ids]
-    vector_store.index.remove_ids(np.array(index_to_delete, dtype=np.int64))
-    for _id in index_to_delete:
-        del vector_store.index_to_docstore_id[_id]
-    # Remove items from docstore.
-    overlapping2 = set(ids).intersection(vector_store.docstore._dict)
-    if not overlapping2:
-        raise ValueError(f"Tried to delete ids that does not exist: {ids}")
-    for _id in ids:
-        vector_store.docstore._dict.pop(_id)
-    return vector_store
-
-
 class FaissKBService(KBService):
     vs_path: str
     kb_path: str
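The hand-rolled helper can go because langchain's FAISS vector store exposes delete(ids) with the same effect: it drops the vectors from the faiss index, the index-to-docstore-id mapping, and the docstore entries in one call. A runnable sketch, assuming faiss-cpu is installed and a langchain version that ships FAISS.delete:

from langchain.embeddings import FakeEmbeddings
from langchain.vectorstores import FAISS

vs = FAISS.from_texts(["doc one", "doc two"], FakeEmbeddings(size=8))
ids = list(vs.index_to_docstore_id.values())

vs.delete([ids[0]])  # removes the vector, the id mapping, and the docstore entry

assert vs.index.ntotal == 1
assert len(vs.docstore._dict) == 1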
@@ -119,14 +102,14 @@ class FaissKBService(KBService):
         refresh_vs_cache(self.kb_name)
 
     def do_delete_doc(self,
-                       kb_file: KnowledgeFile):
+                      kb_file: KnowledgeFile):
         embeddings = self._load_embeddings()
         if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
             vector_store = FAISS.load_local(self.vs_path, embeddings)
             ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
             if len(ids) == 0:
                 return None
-            vector_store = delete_doc_from_faiss(vector_store, ids)
+            vector_store.delete(ids)
             vector_store.save_local(self.vs_path)
         refresh_vs_cache(self.kb_name)
         return True
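The id-selection line above works because langchain's FAISS keeps the original Document objects in docstore._dict keyed by id, so every chunk belonging to one source file can be found by matching metadata["source"]. A self-contained sketch of that step:

from langchain.docstore.document import Document

docstore_dict = {
    "id-1": Document(page_content="chunk a", metadata={"source": "/kb/a.md"}),
    "id-2": Document(page_content="chunk b", metadata={"source": "/kb/b.md"}),
}


def ids_for_source(docs: dict, filepath: str) -> list:
    return [k for k, v in docs.items() if v.metadata["source"] == filepath]


assert ids_for_source(docstore_dict, "/kb/a.md") == ["id-1"]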
@@ -79,5 +79,5 @@ class KnowledgeFile:
 
         # TODO: 增加依据文件格式匹配text_splitter
         TextSplitter = getattr(sys.modules['langchain.text_splitter'], self.text_splitter_name)
-        text_splitter = TextSplitter(chunk_size=500, chunk_overlap=200)
+        text_splitter = TextSplitter(chunk_size=250, chunk_overlap=200)
        return loader.load_and_split(text_splitter)
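The splitter class is resolved dynamically via getattr on langchain.text_splitter (the TODO notes it should eventually be matched to the file format). Halving chunk_size to 250 while keeping chunk_overlap at 200 yields many more, heavily overlapping chunks. A sketch of the effect, using RecursiveCharacterTextSplitter as a concrete splitter:

from langchain.text_splitter import RecursiveCharacterTextSplitter

text = "生龙活虎,虎头虎脑。" * 100  # roughly 1000 characters of sample text

small = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=200).split_text(text)
large = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=200).split_text(text)

assert len(small) > len(large)  # tighter chunks -> more pieces, more overlap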