适配score_threshold

This commit is contained in:
zqt 2023-08-27 11:21:10 +08:00
parent ead2e26da1
commit 1fa4e906c7
4 changed files with 70 additions and 18 deletions

View File

@ -1,9 +1,13 @@
import operator
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import os import os
import numpy as np
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document from langchain.docstore.document import Document
from sklearn.preprocessing import normalize
from server.db.repository.knowledge_base_repository import ( from server.db.repository.knowledge_base_repository import (
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
load_kb_from_db, get_kb_detail, load_kb_from_db, get_kb_detail,
@ -62,7 +66,6 @@ class KBService(ABC):
status = delete_files_from_db(self.kb_name) status = delete_files_from_db(self.kb_name)
return status return status
def drop_kb(self): def drop_kb(self):
""" """
删除知识库 删除知识库
@ -102,7 +105,7 @@ class KBService(ABC):
if os.path.exists(kb_file.filepath): if os.path.exists(kb_file.filepath):
self.delete_doc(kb_file, **kwargs) self.delete_doc(kb_file, **kwargs)
return self.add_doc(kb_file, **kwargs) return self.add_doc(kb_file, **kwargs)
def exist_doc(self, file_name: str): def exist_doc(self, file_name: str):
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name, return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
filename=file_name)) filename=file_name))
@ -208,8 +211,9 @@ class KBServiceFactory:
return PGKBService(kb_name, embed_model=embed_model) return PGKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.MILVUS == vector_store_type: elif SupportedVSType.MILVUS == vector_store_type:
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
return MilvusKBService(kb_name, embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config return MilvusKBService(kb_name,
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier. embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
return DefaultKBService(kb_name) return DefaultKBService(kb_name)
@ -217,7 +221,7 @@ class KBServiceFactory:
def get_service_by_name(kb_name: str def get_service_by_name(kb_name: str
) -> KBService: ) -> KBService:
_, vs_type, embed_model = load_kb_from_db(kb_name) _, vs_type, embed_model = load_kb_from_db(kb_name)
if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db
vs_type = "faiss" vs_type = "faiss"
return KBServiceFactory.get_service(kb_name, vs_type, embed_model) return KBServiceFactory.get_service(kb_name, vs_type, embed_model)
@ -256,7 +260,7 @@ def get_kb_details() -> List[Dict]:
for i, v in enumerate(result.values()): for i, v in enumerate(result.values()):
v['No'] = i + 1 v['No'] = i + 1
data.append(v) data.append(v)
return data return data
@ -292,5 +296,39 @@ def get_kb_doc_details(kb_name: str) -> List[Dict]:
for i, v in enumerate(result.values()): for i, v in enumerate(result.values()):
v['No'] = i + 1 v['No'] = i + 1
data.append(v) data.append(v)
return data return data
class EmbeddingsFunAdapter(Embeddings):
    """Wrap an ``Embeddings`` implementation and normalize every vector it returns.

    Each embedding produced by the wrapped object is passed through
    ``sklearn.preprocessing.normalize`` (row-wise, default norm) before it is
    handed back, and is returned as plain Python lists so the result matches
    the annotated ``List[...]`` return types.
    """

    def __init__(self, embeddings: Embeddings):
        """Store the wrapped embeddings object that performs the actual encoding."""
        self.embeddings = embeddings

    @staticmethod
    def _normalize_rows(vectors) -> List[List[float]]:
        """Normalize each row of a 2-D batch and return it as nested lists.

        An empty batch is returned as ``[]`` because ``normalize`` rejects
        arrays that contain no samples.
        """
        if len(vectors) == 0:
            return []
        return normalize(np.asarray(vectors)).tolist()

    @classmethod
    def _normalize_vector(cls, vector) -> List[float]:
        """Normalize a single 1-D embedding and return it as a flat list."""
        # normalize() works on 2-D input, so lift the vector to one row first.
        as_row = np.reshape(vector, (1, -1))
        return cls._normalize_rows(as_row)[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with the wrapped object and normalize every row."""
        return self._normalize_rows(self.embeddings.embed_documents(texts))

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return its normalized vector."""
        return self._normalize_vector(self.embeddings.embed_query(text))

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async variant of :meth:`embed_documents`."""
        # Await the wrapped coroutine first; normalize() itself is synchronous
        # and must receive the finished vectors, not the coroutine object.
        vectors = await self.embeddings.aembed_documents(texts)
        return self._normalize_rows(vectors)

    async def aembed_query(self, text: str) -> List[float]:
        """Async variant of :meth:`embed_query`."""
        vector = await self.embeddings.aembed_query(text)
        return self._normalize_vector(vector)
def score_threshold_process(score_threshold, k, docs):
    """Filter ``(doc, score)`` pairs by a score ceiling and keep the first ``k``.

    Args:
        score_threshold: Upper bound for a pair's score; pairs whose score is
            less than or equal to it are kept. ``None`` disables filtering.
        k: Maximum number of pairs to return.
        docs: Pairs of ``(doc, score)``. When ``score_threshold`` is ``None``
            it is sliced directly, so it must support slicing in that case.

    Returns:
        The first ``k`` surviving pairs, in their original order.
    """
    # No threshold: just truncate the input as-is.
    if score_threshold is None:
        return docs[:k]

    kept = []
    for document, score in docs:
        if score <= score_threshold:
            kept.append((document, score))
    return kept[:k]

View File

@ -166,3 +166,10 @@ class FaissKBService(KBService):
return "in_folder" return "in_folder"
else: else:
return False return False
if __name__ == '__main__':
    # Manual smoke run against a knowledge base named "test": add a file,
    # delete it again, drop the knowledge base, then print a search result.
    kb_service = FaissKBService("test")

    added_file = KnowledgeFile("README.md", "test")
    kb_service.add_doc(added_file)

    removed_file = KnowledgeFile("README.md", "test")
    kb_service.delete_doc(removed_file)

    kb_service.do_drop_kb()

    hits = kb_service.search_docs("如何启动api服务")
    print(hits)

View File

@ -1,12 +1,16 @@
from typing import List from typing import List
import numpy as np
from faiss import normalize_L2
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.schema import Document from langchain.schema import Document
from langchain.vectorstores import Milvus from langchain.vectorstores import Milvus
from sklearn.preprocessing import normalize
from configs.model_config import SCORE_THRESHOLD, kbs_config from configs.model_config import SCORE_THRESHOLD, kbs_config
from server.knowledge_base.kb_service.base import KBService, SupportedVSType from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import KnowledgeFile from server.knowledge_base.utils import KnowledgeFile
@ -36,7 +40,7 @@ class MilvusKBService(KBService):
def _load_milvus(self, embeddings: Embeddings = None): def _load_milvus(self, embeddings: Embeddings = None):
if embeddings is None: if embeddings is None:
embeddings = self._load_embeddings() embeddings = self._load_embeddings()
self.milvus = Milvus(embedding_function=embeddings, self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(embeddings),
collection_name=self.kb_name, connection_args=kbs_config.get("milvus")) collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))
def do_init(self): def do_init(self):
@ -45,10 +49,9 @@ class MilvusKBService(KBService):
def do_drop_kb(self): def do_drop_kb(self):
self.milvus.col.drop() self.milvus.col.drop()
def do_search(self, query: str, top_k: int,score_threshold: float, embeddings: Embeddings): def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
# todo: support score threshold self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
self._load_milvus(embeddings=embeddings) return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
return self.milvus.similarity_search_with_score(query, top_k)
def add_doc(self, kb_file: KnowledgeFile, **kwargs): def add_doc(self, kb_file: KnowledgeFile, **kwargs):
""" """
@ -83,4 +86,4 @@ if __name__ == '__main__':
milvusService.add_doc(KnowledgeFile("README.md", "test")) milvusService.add_doc(KnowledgeFile("README.md", "test"))
milvusService.delete_doc(KnowledgeFile("README.md", "test")) milvusService.delete_doc(KnowledgeFile("README.md", "test"))
milvusService.do_drop_kb() milvusService.do_drop_kb()
print(milvusService.search_docs("测试")) print(milvusService.search_docs("如何启动api服务"))

View File

@ -3,10 +3,12 @@ from typing import List
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.schema import Document from langchain.schema import Document
from langchain.vectorstores import PGVector from langchain.vectorstores import PGVector
from langchain.vectorstores.pgvector import DistanceStrategy
from sqlalchemy import text from sqlalchemy import text
from configs.model_config import EMBEDDING_DEVICE, kbs_config from configs.model_config import EMBEDDING_DEVICE, kbs_config
from server.knowledge_base.kb_service.base import SupportedVSType, KBService from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import load_embeddings, KnowledgeFile from server.knowledge_base.utils import load_embeddings, KnowledgeFile
@ -17,8 +19,9 @@ class PGKBService(KBService):
_embeddings = embeddings _embeddings = embeddings
if _embeddings is None: if _embeddings is None:
_embeddings = load_embeddings(self.embed_model, embedding_device) _embeddings = load_embeddings(self.embed_model, embedding_device)
self.pg_vector = PGVector(embedding_function=_embeddings, self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(_embeddings),
collection_name=self.kb_name, collection_name=self.kb_name,
distance_strategy=DistanceStrategy.EUCLIDEAN,
connection_string=kbs_config.get("pg").get("connection_uri")) connection_string=kbs_config.get("pg").get("connection_uri"))
def do_init(self): def do_init(self):
@ -46,7 +49,8 @@ class PGKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings): def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
# todo: support score threshold # todo: support score threshold
self._load_pg_vector(embeddings=embeddings) self._load_pg_vector(embeddings=embeddings)
return self.pg_vector.similarity_search_with_score(query, top_k) return score_threshold_process(score_threshold, top_k,
self.pg_vector.similarity_search_with_score(query, top_k))
def add_doc(self, kb_file: KnowledgeFile, **kwargs): def add_doc(self, kb_file: KnowledgeFile, **kwargs):
""" """
@ -83,4 +87,4 @@ if __name__ == '__main__':
pGKBService.add_doc(KnowledgeFile("README.md", "test")) pGKBService.add_doc(KnowledgeFile("README.md", "test"))
pGKBService.delete_doc(KnowledgeFile("README.md", "test")) pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
pGKBService.drop_kb() pGKBService.drop_kb()
print(pGKBService.search_docs("测试")) print(pGKBService.search_docs("如何启动api服务"))