From 8df00d0b042ad292e57121a6578c62c61769db8b Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Thu, 10 Aug 2023 00:36:51 +0800 Subject: [PATCH] add SCORE_THRESHOLD to faiss and milvus kb service --- configs/model_config.py.example | 3 +++ server/api.py | 2 +- .../kb_service/faiss_kb_service.py | 19 +++++++++------ .../kb_service/milvus_kb_service.py | 4 ++-- server/knowledge_base/utils.py | 24 +++++++++++++++---- 5 files changed, 38 insertions(+), 14 deletions(-) diff --git a/configs/model_config.py.example b/configs/model_config.py.example index e26fa50..7f97a7e 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -116,6 +116,9 @@ OVERLAP_SIZE = 50 # 知识库匹配向量数量 VECTOR_SEARCH_TOP_K = 5 +# 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右 +SCORE_THRESHOLD = 1 + # 搜索引擎匹配结题数量 SEARCH_ENGINE_TOP_K = 5 diff --git a/server/api.py b/server/api.py index 8a7a2c1..900ba00 100644 --- a/server/api.py +++ b/server/api.py @@ -91,7 +91,7 @@ def create_app(): app.delete("/knowledge_base/delete_doc", tags=["Knowledge Base Management"], response_model=BaseResponse, - summary="删除知识库内的文件" + summary="删除知识库内指定文件" )(delete_doc) app.post("/knowledge_base/update_doc", diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index eafa98e..0ef820a 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -1,7 +1,13 @@ import os import shutil -from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_MODEL, EMBEDDING_DEVICE +from configs.model_config import ( + KB_ROOT_PATH, + CACHED_VS_NUM, + EMBEDDING_MODEL, + EMBEDDING_DEVICE, + SCORE_THRESHOLD +) from server.knowledge_base.kb_service.base import KBService, SupportedVSType from functools import lru_cache from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile @@ -11,7 +17,6 @@ from langchain.embeddings.huggingface import HuggingFaceEmbeddings from typing import List from langchain.docstore.document import Document from server.utils import torch_gc -import numpy as np # make HuggingFaceEmbeddings hashable @@ -36,7 +41,7 @@ def load_vector_store( vs_path = get_vs_path(knowledge_base_name) if embeddings is None: embeddings = load_embeddings(embed_model, embed_device) - search_index = FAISS.load_local(vs_path, embeddings) + search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True) return search_index @@ -81,7 +86,7 @@ class FaissKBService(KBService): search_index = load_vector_store(self.kb_name, embeddings=embeddings, tick=_VECTOR_STORE_TICKS.get(self.kb_name)) - docs = search_index.similarity_search(query, k=top_k) + docs = search_index.similarity_search(query, k=top_k, score_threshold=SCORE_THRESHOLD) return docs def do_add_doc(self, @@ -89,14 +94,14 @@ class FaissKBService(KBService): embeddings: Embeddings, ): if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path): - vector_store = FAISS.load_local(self.vs_path, embeddings) + vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True) vector_store.add_documents(docs) torch_gc() else: if not os.path.exists(self.vs_path): os.makedirs(self.vs_path) vector_store = FAISS.from_documents( - docs, embeddings) # docs 为Document列表 + docs, embeddings, normalize_L2=True) # docs 为Document列表 torch_gc() vector_store.save_local(self.vs_path) refresh_vs_cache(self.kb_name) @@ -105,7 +110,7 @@ class FaissKBService(KBService): kb_file: KnowledgeFile): embeddings = self._load_embeddings() if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path): - vector_store = FAISS.load_local(self.vs_path, embeddings) + vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True) ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] if len(ids) == 0: return None diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index 8898a0e..9f4dc60 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -4,7 +4,7 @@ from langchain.embeddings.base import Embeddings from langchain.schema import Document from langchain.vectorstores import Milvus -from configs.model_config import EMBEDDING_DEVICE, kbs_config +from configs.model_config import SCORE_THRESHOLD, kbs_config from server.knowledge_base.kb_service.base import KBService, SupportedVSType from server.knowledge_base.utils import KnowledgeFile @@ -47,7 +47,7 @@ class MilvusKBService(KBService): def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]: self._load_milvus(embeddings=embeddings) - return self.milvus.similarity_search(query, top_k) + return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD) def add_doc(self, kb_file: KnowledgeFile): """ diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index dbe6886..004c748 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -1,9 +1,16 @@ import os from langchain.embeddings.huggingface import HuggingFaceEmbeddings -from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, CHUNK_SIZE, OVERLAP_SIZE) +from configs.model_config import ( + embedding_model_dict, + KB_ROOT_PATH, + CHUNK_SIZE, + OVERLAP_SIZE, + ZH_TITLE_ENHANCE +) from functools import lru_cache import sys from text_splitter import zh_title_enhance +from langchain.document_loaders import UnstructuredFileLoader def validate_kb_name(knowledge_base_id: str) -> bool: @@ -74,9 +81,17 @@ class KnowledgeFile: # TODO: 增加依据文件格式匹配text_splitter self.text_splitter_name = None - def file2text(self, using_zh_title_enhance): - DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name) - loader = DocumentLoader(self.filepath) + def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE): + print(self.document_loader_name) + try: + DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name) + except Exception as e: + print(e) + DocumentLoader = getattr(sys.modules['langchain.document_loaders'], "UnstructuredFileLoader") + if self.document_loader_name == "UnstructuredFileLoader": + loader = DocumentLoader(self.filepath, autodetect_encoding=True) + else: + loader = DocumentLoader(self.filepath) # TODO: 增加依据文件格式匹配text_splitter try: @@ -101,6 +116,7 @@ class KnowledgeFile: ) docs = loader.load_and_split(text_splitter) + print(docs[0]) if using_zh_title_enhance: docs = zh_title_enhance(docs) return docs