add SCORE_THRESHOLD to faiss and milvus kb service

This commit is contained in:
imClumsyPanda 2023-08-10 00:36:51 +08:00
parent 1a112c6908
commit 8df00d0b04
5 changed files with 38 additions and 14 deletions

View File

@@ -116,6 +116,9 @@ OVERLAP_SIZE = 50
# 知识库匹配向量数量
VECTOR_SEARCH_TOP_K = 5
# 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右
SCORE_THRESHOLD = 1
# 搜索引擎匹配结果数量
SEARCH_ENGINE_TOP_K = 5

View File

@@ -91,7 +91,7 @@ def create_app():
app.delete("/knowledge_base/delete_doc",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="删除知识库内文件"
summary="删除知识库内指定文件"
)(delete_doc)
app.post("/knowledge_base/update_doc",

View File

@@ -1,7 +1,13 @@
import os
import shutil
from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_MODEL, EMBEDDING_DEVICE
from configs.model_config import (
KB_ROOT_PATH,
CACHED_VS_NUM,
EMBEDDING_MODEL,
EMBEDDING_DEVICE,
SCORE_THRESHOLD
)
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from functools import lru_cache
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
@@ -11,7 +17,6 @@ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from typing import List
from langchain.docstore.document import Document
from server.utils import torch_gc
import numpy as np
# make HuggingFaceEmbeddings hashable
@@ -36,7 +41,7 @@ def load_vector_store(
vs_path = get_vs_path(knowledge_base_name)
if embeddings is None:
embeddings = load_embeddings(embed_model, embed_device)
search_index = FAISS.load_local(vs_path, embeddings)
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
return search_index
@@ -81,7 +86,7 @@ class FaissKBService(KBService):
search_index = load_vector_store(self.kb_name,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name))
docs = search_index.similarity_search(query, k=top_k)
docs = search_index.similarity_search(query, k=top_k, score_threshold=SCORE_THRESHOLD)
return docs
def do_add_doc(self,
@@ -89,14 +94,14 @@ class FaissKBService(KBService):
embeddings: Embeddings,
):
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
vector_store = FAISS.load_local(self.vs_path, embeddings)
vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
vector_store.add_documents(docs)
torch_gc()
else:
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
vector_store = FAISS.from_documents(
docs, embeddings) # docs 为Document列表
docs, embeddings, normalize_L2=True) # docs 为Document列表
torch_gc()
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
@@ -105,7 +110,7 @@ class FaissKBService(KBService):
kb_file: KnowledgeFile):
embeddings = self._load_embeddings()
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
vector_store = FAISS.load_local(self.vs_path, embeddings)
vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
if len(ids) == 0:
return None

View File

@@ -4,7 +4,7 @@ from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import Milvus
from configs.model_config import EMBEDDING_DEVICE, kbs_config
from configs.model_config import SCORE_THRESHOLD, kbs_config
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.utils import KnowledgeFile
@@ -47,7 +47,7 @@ class MilvusKBService(KBService):
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
self._load_milvus(embeddings=embeddings)
return self.milvus.similarity_search(query, top_k)
return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD)
def add_doc(self, kb_file: KnowledgeFile):
"""

View File

@@ -1,9 +1,16 @@
import os
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, CHUNK_SIZE, OVERLAP_SIZE)
from configs.model_config import (
embedding_model_dict,
KB_ROOT_PATH,
CHUNK_SIZE,
OVERLAP_SIZE,
ZH_TITLE_ENHANCE
)
from functools import lru_cache
import sys
from text_splitter import zh_title_enhance
from langchain.document_loaders import UnstructuredFileLoader
def validate_kb_name(knowledge_base_id: str) -> bool:
@@ -74,9 +81,17 @@ class KnowledgeFile:
# TODO: 增加依据文件格式匹配text_splitter
self.text_splitter_name = None
def file2text(self, using_zh_title_enhance):
DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name)
loader = DocumentLoader(self.filepath)
def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE):
print(self.document_loader_name)
try:
DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name)
except Exception as e:
print(e)
DocumentLoader = getattr(sys.modules['langchain.document_loaders'], "UnstructuredFileLoader")
if self.document_loader_name == "UnstructuredFileLoader":
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
else:
loader = DocumentLoader(self.filepath)
# TODO: 增加依据文件格式匹配text_splitter
try:
@@ -101,6 +116,7 @@ class KnowledgeFile:
)
docs = loader.load_and_split(text_splitter)
print(docs[0])
if using_zh_title_enhance:
docs = zh_title_enhance(docs)
return docs