Add SCORE_THRESHOLD to the FAISS and Milvus KB services
This commit is contained in:
parent
1a112c6908
commit
8df00d0b04
|
|
@ -116,6 +116,9 @@ OVERLAP_SIZE = 50
|
||||||
# 知识库匹配向量数量
|
# 知识库匹配向量数量
|
||||||
VECTOR_SEARCH_TOP_K = 5
|
VECTOR_SEARCH_TOP_K = 5
|
||||||
|
|
||||||
|
# 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右
|
||||||
|
SCORE_THRESHOLD = 1
|
||||||
|
|
||||||
# 搜索引擎匹配结果数量
|
# 搜索引擎匹配结果数量
|
||||||
SEARCH_ENGINE_TOP_K = 5
|
SEARCH_ENGINE_TOP_K = 5
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -91,7 +91,7 @@ def create_app():
|
||||||
app.delete("/knowledge_base/delete_doc",
|
app.delete("/knowledge_base/delete_doc",
|
||||||
tags=["Knowledge Base Management"],
|
tags=["Knowledge Base Management"],
|
||||||
response_model=BaseResponse,
|
response_model=BaseResponse,
|
||||||
summary="删除知识库内的文件"
|
summary="删除知识库内指定文件"
|
||||||
)(delete_doc)
|
)(delete_doc)
|
||||||
|
|
||||||
app.post("/knowledge_base/update_doc",
|
app.post("/knowledge_base/update_doc",
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,13 @@
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_MODEL, EMBEDDING_DEVICE
|
from configs.model_config import (
|
||||||
|
KB_ROOT_PATH,
|
||||||
|
CACHED_VS_NUM,
|
||||||
|
EMBEDDING_MODEL,
|
||||||
|
EMBEDDING_DEVICE,
|
||||||
|
SCORE_THRESHOLD
|
||||||
|
)
|
||||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
|
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
|
||||||
|
|
@ -11,7 +17,6 @@ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||||
from typing import List
|
from typing import List
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from server.utils import torch_gc
|
from server.utils import torch_gc
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
# make HuggingFaceEmbeddings hashable
|
# make HuggingFaceEmbeddings hashable
|
||||||
|
|
@ -36,7 +41,7 @@ def load_vector_store(
|
||||||
vs_path = get_vs_path(knowledge_base_name)
|
vs_path = get_vs_path(knowledge_base_name)
|
||||||
if embeddings is None:
|
if embeddings is None:
|
||||||
embeddings = load_embeddings(embed_model, embed_device)
|
embeddings = load_embeddings(embed_model, embed_device)
|
||||||
search_index = FAISS.load_local(vs_path, embeddings)
|
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
|
||||||
return search_index
|
return search_index
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -81,7 +86,7 @@ class FaissKBService(KBService):
|
||||||
search_index = load_vector_store(self.kb_name,
|
search_index = load_vector_store(self.kb_name,
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
tick=_VECTOR_STORE_TICKS.get(self.kb_name))
|
tick=_VECTOR_STORE_TICKS.get(self.kb_name))
|
||||||
docs = search_index.similarity_search(query, k=top_k)
|
docs = search_index.similarity_search(query, k=top_k, score_threshold=SCORE_THRESHOLD)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def do_add_doc(self,
|
def do_add_doc(self,
|
||||||
|
|
@ -89,14 +94,14 @@ class FaissKBService(KBService):
|
||||||
embeddings: Embeddings,
|
embeddings: Embeddings,
|
||||||
):
|
):
|
||||||
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
|
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
|
||||||
vector_store = FAISS.load_local(self.vs_path, embeddings)
|
vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
|
||||||
vector_store.add_documents(docs)
|
vector_store.add_documents(docs)
|
||||||
torch_gc()
|
torch_gc()
|
||||||
else:
|
else:
|
||||||
if not os.path.exists(self.vs_path):
|
if not os.path.exists(self.vs_path):
|
||||||
os.makedirs(self.vs_path)
|
os.makedirs(self.vs_path)
|
||||||
vector_store = FAISS.from_documents(
|
vector_store = FAISS.from_documents(
|
||||||
docs, embeddings) # docs 为Document列表
|
docs, embeddings, normalize_L2=True) # docs 为Document列表
|
||||||
torch_gc()
|
torch_gc()
|
||||||
vector_store.save_local(self.vs_path)
|
vector_store.save_local(self.vs_path)
|
||||||
refresh_vs_cache(self.kb_name)
|
refresh_vs_cache(self.kb_name)
|
||||||
|
|
@ -105,7 +110,7 @@ class FaissKBService(KBService):
|
||||||
kb_file: KnowledgeFile):
|
kb_file: KnowledgeFile):
|
||||||
embeddings = self._load_embeddings()
|
embeddings = self._load_embeddings()
|
||||||
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
|
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
|
||||||
vector_store = FAISS.load_local(self.vs_path, embeddings)
|
vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
|
||||||
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
|
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
|
||||||
if len(ids) == 0:
|
if len(ids) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from langchain.embeddings.base import Embeddings
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
from langchain.vectorstores import Milvus
|
from langchain.vectorstores import Milvus
|
||||||
|
|
||||||
from configs.model_config import EMBEDDING_DEVICE, kbs_config
|
from configs.model_config import SCORE_THRESHOLD, kbs_config
|
||||||
|
|
||||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||||
from server.knowledge_base.utils import KnowledgeFile
|
from server.knowledge_base.utils import KnowledgeFile
|
||||||
|
|
@ -47,7 +47,7 @@ class MilvusKBService(KBService):
|
||||||
|
|
||||||
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
|
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
|
||||||
self._load_milvus(embeddings=embeddings)
|
self._load_milvus(embeddings=embeddings)
|
||||||
return self.milvus.similarity_search(query, top_k)
|
return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD)
|
||||||
|
|
||||||
def add_doc(self, kb_file: KnowledgeFile):
|
def add_doc(self, kb_file: KnowledgeFile):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,16 @@
|
||||||
import os
|
import os
|
||||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||||
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, CHUNK_SIZE, OVERLAP_SIZE)
|
from configs.model_config import (
|
||||||
|
embedding_model_dict,
|
||||||
|
KB_ROOT_PATH,
|
||||||
|
CHUNK_SIZE,
|
||||||
|
OVERLAP_SIZE,
|
||||||
|
ZH_TITLE_ENHANCE
|
||||||
|
)
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
import sys
|
import sys
|
||||||
from text_splitter import zh_title_enhance
|
from text_splitter import zh_title_enhance
|
||||||
|
from langchain.document_loaders import UnstructuredFileLoader
|
||||||
|
|
||||||
|
|
||||||
def validate_kb_name(knowledge_base_id: str) -> bool:
|
def validate_kb_name(knowledge_base_id: str) -> bool:
|
||||||
|
|
@ -74,8 +81,16 @@ class KnowledgeFile:
|
||||||
# TODO: 增加依据文件格式匹配text_splitter
|
# TODO: 增加依据文件格式匹配text_splitter
|
||||||
self.text_splitter_name = None
|
self.text_splitter_name = None
|
||||||
|
|
||||||
def file2text(self, using_zh_title_enhance):
|
def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE):
|
||||||
|
print(self.document_loader_name)
|
||||||
|
try:
|
||||||
DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name)
|
DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
DocumentLoader = getattr(sys.modules['langchain.document_loaders'], "UnstructuredFileLoader")
|
||||||
|
if self.document_loader_name == "UnstructuredFileLoader":
|
||||||
|
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
|
||||||
|
else:
|
||||||
loader = DocumentLoader(self.filepath)
|
loader = DocumentLoader(self.filepath)
|
||||||
|
|
||||||
# TODO: 增加依据文件格式匹配text_splitter
|
# TODO: 增加依据文件格式匹配text_splitter
|
||||||
|
|
@ -101,6 +116,7 @@ class KnowledgeFile:
|
||||||
)
|
)
|
||||||
|
|
||||||
docs = loader.load_and_split(text_splitter)
|
docs = loader.load_and_split(text_splitter)
|
||||||
|
print(docs[0])
|
||||||
if using_zh_title_enhance:
|
if using_zh_title_enhance:
|
||||||
docs = zh_title_enhance(docs)
|
docs = zh_title_enhance(docs)
|
||||||
return docs
|
return docs
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue