add SCORE_THRESHOLD to faiss and milvus kb service
This commit is contained in:
parent
1a112c6908
commit
8df00d0b04
|
|
@ -116,6 +116,9 @@ OVERLAP_SIZE = 50
|
|||
# 知识库匹配向量数量
|
||||
VECTOR_SEARCH_TOP_K = 5
|
||||
|
||||
# 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右
|
||||
SCORE_THRESHOLD = 1
|
||||
|
||||
# 搜索引擎匹配结果数量
|
||||
SEARCH_ENGINE_TOP_K = 5
|
||||
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ def create_app():
|
|||
app.delete("/knowledge_base/delete_doc",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=BaseResponse,
|
||||
summary="删除知识库内的文件"
|
||||
summary="删除知识库内指定文件"
|
||||
)(delete_doc)
|
||||
|
||||
app.post("/knowledge_base/update_doc",
|
||||
|
|
|
|||
|
|
@ -1,7 +1,13 @@
|
|||
import os
|
||||
import shutil
|
||||
|
||||
from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_MODEL, EMBEDDING_DEVICE
|
||||
from configs.model_config import (
|
||||
KB_ROOT_PATH,
|
||||
CACHED_VS_NUM,
|
||||
EMBEDDING_MODEL,
|
||||
EMBEDDING_DEVICE,
|
||||
SCORE_THRESHOLD
|
||||
)
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||
from functools import lru_cache
|
||||
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
|
||||
|
|
@ -11,7 +17,6 @@ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
|||
from typing import List
|
||||
from langchain.docstore.document import Document
|
||||
from server.utils import torch_gc
|
||||
import numpy as np
|
||||
|
||||
|
||||
# make HuggingFaceEmbeddings hashable
|
||||
|
|
@ -36,7 +41,7 @@ def load_vector_store(
|
|||
vs_path = get_vs_path(knowledge_base_name)
|
||||
if embeddings is None:
|
||||
embeddings = load_embeddings(embed_model, embed_device)
|
||||
search_index = FAISS.load_local(vs_path, embeddings)
|
||||
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
|
||||
return search_index
|
||||
|
||||
|
||||
|
|
@ -81,7 +86,7 @@ class FaissKBService(KBService):
|
|||
search_index = load_vector_store(self.kb_name,
|
||||
embeddings=embeddings,
|
||||
tick=_VECTOR_STORE_TICKS.get(self.kb_name))
|
||||
docs = search_index.similarity_search(query, k=top_k)
|
||||
docs = search_index.similarity_search(query, k=top_k, score_threshold=SCORE_THRESHOLD)
|
||||
return docs
|
||||
|
||||
def do_add_doc(self,
|
||||
|
|
@ -89,14 +94,14 @@ class FaissKBService(KBService):
|
|||
embeddings: Embeddings,
|
||||
):
|
||||
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
|
||||
vector_store = FAISS.load_local(self.vs_path, embeddings)
|
||||
vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
|
||||
vector_store.add_documents(docs)
|
||||
torch_gc()
|
||||
else:
|
||||
if not os.path.exists(self.vs_path):
|
||||
os.makedirs(self.vs_path)
|
||||
vector_store = FAISS.from_documents(
|
||||
docs, embeddings) # docs 为Document列表
|
||||
docs, embeddings, normalize_L2=True) # docs 为Document列表
|
||||
torch_gc()
|
||||
vector_store.save_local(self.vs_path)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
|
|
@ -105,7 +110,7 @@ class FaissKBService(KBService):
|
|||
kb_file: KnowledgeFile):
|
||||
embeddings = self._load_embeddings()
|
||||
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
|
||||
vector_store = FAISS.load_local(self.vs_path, embeddings)
|
||||
vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
|
||||
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
|
||||
if len(ids) == 0:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from langchain.embeddings.base import Embeddings
|
|||
from langchain.schema import Document
|
||||
from langchain.vectorstores import Milvus
|
||||
|
||||
from configs.model_config import EMBEDDING_DEVICE, kbs_config
|
||||
from configs.model_config import SCORE_THRESHOLD, kbs_config
|
||||
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
|
|
@ -47,7 +47,7 @@ class MilvusKBService(KBService):
|
|||
|
||||
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
|
||||
self._load_milvus(embeddings=embeddings)
|
||||
return self.milvus.similarity_search(query, top_k)
|
||||
return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD)
|
||||
|
||||
def add_doc(self, kb_file: KnowledgeFile):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,9 +1,16 @@
|
|||
import os
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, CHUNK_SIZE, OVERLAP_SIZE)
|
||||
from configs.model_config import (
|
||||
embedding_model_dict,
|
||||
KB_ROOT_PATH,
|
||||
CHUNK_SIZE,
|
||||
OVERLAP_SIZE,
|
||||
ZH_TITLE_ENHANCE
|
||||
)
|
||||
from functools import lru_cache
|
||||
import sys
|
||||
from text_splitter import zh_title_enhance
|
||||
from langchain.document_loaders import UnstructuredFileLoader
|
||||
|
||||
|
||||
def validate_kb_name(knowledge_base_id: str) -> bool:
|
||||
|
|
@ -74,9 +81,17 @@ class KnowledgeFile:
|
|||
# TODO: 增加依据文件格式匹配text_splitter
|
||||
self.text_splitter_name = None
|
||||
|
||||
def file2text(self, using_zh_title_enhance):
|
||||
DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name)
|
||||
loader = DocumentLoader(self.filepath)
|
||||
def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE):
|
||||
print(self.document_loader_name)
|
||||
try:
|
||||
DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
DocumentLoader = getattr(sys.modules['langchain.document_loaders'], "UnstructuredFileLoader")
|
||||
if self.document_loader_name == "UnstructuredFileLoader":
|
||||
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
|
||||
else:
|
||||
loader = DocumentLoader(self.filepath)
|
||||
|
||||
# TODO: 增加依据文件格式匹配text_splitter
|
||||
try:
|
||||
|
|
@ -101,6 +116,7 @@ class KnowledgeFile:
|
|||
)
|
||||
|
||||
docs = loader.load_and_split(text_splitter)
|
||||
print(docs[0])
|
||||
if using_zh_title_enhance:
|
||||
docs = zh_title_enhance(docs)
|
||||
return docs
|
||||
|
|
|
|||
Loading…
Reference in New Issue