适配score_threshold

This commit is contained in:
zqt 2023-08-27 11:21:10 +08:00
parent ead2e26da1
commit 1fa4e906c7
4 changed files with 70 additions and 18 deletions

View File

@ -1,9 +1,13 @@
import operator
from abc import ABC, abstractmethod
import os
import numpy as np
from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document
from sklearn.preprocessing import normalize
from server.db.repository.knowledge_base_repository import (
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
load_kb_from_db, get_kb_detail,
@ -62,7 +66,6 @@ class KBService(ABC):
status = delete_files_from_db(self.kb_name)
return status
def drop_kb(self):
"""
删除知识库
@ -208,7 +211,8 @@ class KBServiceFactory:
return PGKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.MILVUS == vector_store_type:
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
return MilvusKBService(kb_name, embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
return MilvusKBService(kb_name,
embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
return DefaultKBService(kb_name)
@ -294,3 +298,37 @@ def get_kb_doc_details(kb_name: str) -> List[Dict]:
data.append(v)
return data
class EmbeddingsFunAdapter(Embeddings):
def __init__(self, embeddings: Embeddings):
self.embeddings = embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return normalize(self.embeddings.embed_documents(texts))
def embed_query(self, text: str) -> List[float]:
query_embed = self.embeddings.embed_query(text)
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
normalized_query_embed = normalize(query_embed_2d)
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
return await normalize(self.embeddings.aembed_documents(texts))
async def aembed_query(self, text: str) -> List[float]:
return await normalize(self.embeddings.aembed_query(text))
def score_threshold_process(score_threshold, k, docs):
if score_threshold is not None:
cmp = (
operator.le
)
docs = [
(doc, similarity)
for doc, similarity in docs
if cmp(similarity, score_threshold)
]
return docs[:k]

View File

@ -166,3 +166,10 @@ class FaissKBService(KBService):
return "in_folder"
else:
return False
if __name__ == '__main__':
milvusService = FaissKBService("test")
milvusService.add_doc(KnowledgeFile("README.md", "test"))
milvusService.delete_doc(KnowledgeFile("README.md", "test"))
milvusService.do_drop_kb()
print(milvusService.search_docs("如何启动api服务"))

View File

@ -1,12 +1,16 @@
from typing import List
import numpy as np
from faiss import normalize_L2
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import Milvus
from sklearn.preprocessing import normalize
from configs.model_config import SCORE_THRESHOLD, kbs_config
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
@ -36,7 +40,7 @@ class MilvusKBService(KBService):
def _load_milvus(self, embeddings: Embeddings = None):
if embeddings is None:
embeddings = self._load_embeddings()
self.milvus = Milvus(embedding_function=embeddings,
self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(embeddings),
collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))
def do_init(self):
@ -46,9 +50,8 @@ class MilvusKBService(KBService):
self.milvus.col.drop()
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
# todo: support score threshold
self._load_milvus(embeddings=embeddings)
return self.milvus.similarity_search_with_score(query, top_k)
self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
@ -83,4 +86,4 @@ if __name__ == '__main__':
milvusService.add_doc(KnowledgeFile("README.md", "test"))
milvusService.delete_doc(KnowledgeFile("README.md", "test"))
milvusService.do_drop_kb()
print(milvusService.search_docs("测试"))
print(milvusService.search_docs("如何启动api服务"))

View File

@ -3,10 +3,12 @@ from typing import List
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import PGVector
from langchain.vectorstores.pgvector import DistanceStrategy
from sqlalchemy import text
from configs.model_config import EMBEDDING_DEVICE, kbs_config
from server.knowledge_base.kb_service.base import SupportedVSType, KBService
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import load_embeddings, KnowledgeFile
@ -17,8 +19,9 @@ class PGKBService(KBService):
_embeddings = embeddings
if _embeddings is None:
_embeddings = load_embeddings(self.embed_model, embedding_device)
self.pg_vector = PGVector(embedding_function=_embeddings,
self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(_embeddings),
collection_name=self.kb_name,
distance_strategy=DistanceStrategy.EUCLIDEAN,
connection_string=kbs_config.get("pg").get("connection_uri"))
def do_init(self):
@ -46,7 +49,8 @@ class PGKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
# todo: support score threshold
self._load_pg_vector(embeddings=embeddings)
return self.pg_vector.similarity_search_with_score(query, top_k)
return score_threshold_process(score_threshold, top_k,
self.pg_vector.similarity_search_with_score(query, top_k))
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
@ -83,4 +87,4 @@ if __name__ == '__main__':
pGKBService.add_doc(KnowledgeFile("README.md", "test"))
pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
pGKBService.drop_kb()
print(pGKBService.search_docs("测试"))
print(pGKBService.search_docs("如何启动api服务"))