适配score_threshold
This commit is contained in:
parent
ead2e26da1
commit
1fa4e906c7
|
|
@ -1,9 +1,13 @@
|
|||
import operator
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.docstore.document import Document
|
||||
from sklearn.preprocessing import normalize
|
||||
|
||||
from server.db.repository.knowledge_base_repository import (
|
||||
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
|
||||
load_kb_from_db, get_kb_detail,
|
||||
|
|
@ -62,7 +66,6 @@ class KBService(ABC):
|
|||
status = delete_files_from_db(self.kb_name)
|
||||
return status
|
||||
|
||||
|
||||
def drop_kb(self):
|
||||
"""
|
||||
删除知识库
|
||||
|
|
@ -208,7 +211,8 @@ class KBServiceFactory:
|
|||
return PGKBService(kb_name, embed_model=embed_model)
|
||||
elif SupportedVSType.MILVUS == vector_store_type:
|
||||
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||
return MilvusKBService(kb_name, embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
|
||||
return MilvusKBService(kb_name,
|
||||
embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
|
||||
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
|
||||
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
||||
return DefaultKBService(kb_name)
|
||||
|
|
@ -294,3 +298,37 @@ def get_kb_doc_details(kb_name: str) -> List[Dict]:
|
|||
data.append(v)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class EmbeddingsFunAdapter(Embeddings):
    """Adapter that normalizes every vector produced by a wrapped ``Embeddings``.

    Each method delegates to the wrapped object and passes the result through
    ``sklearn.preprocessing.normalize`` before returning plain Python lists,
    as the ``Embeddings`` interface annotations promise.

    NOTE(review): presumably this exists so that vector-store distance scores
    are comparable against a fixed ``score_threshold`` -- confirm with callers.
    """

    def __init__(self, embeddings: Embeddings):
        """Store the embeddings implementation that does the actual encoding."""
        self.embeddings = embeddings

    @staticmethod
    def _normalize_rows(vectors) -> List[List[float]]:
        """Normalize each row of a batch of vectors and return nested lists."""
        # normalize() rejects an empty batch, so short-circuit instead of raising.
        if len(vectors) == 0:
            return []
        return normalize(np.asarray(vectors)).tolist()

    @staticmethod
    def _normalize_one(vector) -> List[float]:
        """Normalize a single 1-D vector and return it as a flat list."""
        # normalize() works on 2-D input: reshape to one row, then take that row back.
        single_row = np.reshape(vector, (1, -1))
        return normalize(single_row)[0].tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with the wrapped model and normalize every vector."""
        return self._normalize_rows(self.embeddings.embed_documents(texts))

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and normalize the resulting vector."""
        return self._normalize_one(self.embeddings.embed_query(text))

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async variant of ``embed_documents``."""
        # Await the wrapped coroutine first; normalize() itself is synchronous
        # and must receive the finished vectors, not the coroutine object.
        vectors = await self.embeddings.aembed_documents(texts)
        return self._normalize_rows(vectors)

    async def aembed_query(self, text: str) -> List[float]:
        """Async variant of ``embed_query``."""
        vector = await self.embeddings.aembed_query(text)
        return self._normalize_one(vector)
|
||||
|
||||
|
||||
def score_threshold_process(score_threshold, k, docs):
    """Drop scored hits above a threshold, then keep at most ``k`` of them.

    Args:
        score_threshold: Inclusive upper bound on a hit's score. ``None``
            disables filtering.
        k: Maximum number of hits to return (used as a slice bound).
        docs: Sequence of ``(doc, score)`` pairs.

    Returns:
        The first ``k`` pairs whose score is less than or equal to
        ``score_threshold``, in their original order; or simply the first
        ``k`` pairs of ``docs`` when no threshold is given.
    """
    # No threshold: only the top-k cut applies.
    if score_threshold is None:
        return docs[:k]

    kept = []
    for item, score in docs:
        # Lower-or-equal scores pass the filter.
        if operator.le(score, score_threshold):
            kept.append((item, score))
    return kept[:k]
|
||||
|
|
|
|||
|
|
@ -166,3 +166,10 @@ class FaissKBService(KBService):
|
|||
return "in_folder"
|
||||
else:
|
||||
return False
|
||||
if __name__ == '__main__':
    # Manual smoke test against a knowledge base named "test":
    # add a file, delete it again, drop the store, then run one search.
    faiss_service = FaissKBService("test")
    readme = KnowledgeFile("README.md", "test")
    faiss_service.add_doc(readme)
    faiss_service.delete_doc(KnowledgeFile("README.md", "test"))
    faiss_service.do_drop_kb()
    search_result = faiss_service.search_docs("如何启动api服务")
    print(search_result)
|
||||
|
|
@ -1,12 +1,16 @@
|
|||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from faiss import normalize_L2
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores import Milvus
|
||||
from sklearn.preprocessing import normalize
|
||||
|
||||
from configs.model_config import SCORE_THRESHOLD, kbs_config
|
||||
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
|
||||
score_threshold_process
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
|
||||
|
||||
|
|
@ -36,7 +40,7 @@ class MilvusKBService(KBService):
|
|||
def _load_milvus(self, embeddings: Embeddings = None):
|
||||
if embeddings is None:
|
||||
embeddings = self._load_embeddings()
|
||||
self.milvus = Milvus(embedding_function=embeddings,
|
||||
self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(embeddings),
|
||||
collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))
|
||||
|
||||
def do_init(self):
|
||||
|
|
@ -45,10 +49,9 @@ class MilvusKBService(KBService):
|
|||
def do_drop_kb(self):
|
||||
self.milvus.col.drop()
|
||||
|
||||
def do_search(self, query: str, top_k: int,score_threshold: float, embeddings: Embeddings):
|
||||
# todo: support score threshold
|
||||
self._load_milvus(embeddings=embeddings)
|
||||
return self.milvus.similarity_search_with_score(query, top_k)
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
    """Search the Milvus collection and apply the score threshold.

    Args:
        query: Text to search for.
        top_k: Maximum number of hits requested from Milvus and returned.
        score_threshold: Inclusive upper bound on a hit's score; ``None``
            disables filtering (see ``score_threshold_process``).
        embeddings: Embeddings object used to encode the query; ``None``
            lets ``_load_milvus`` fall back to the service's own embeddings.

    Returns:
        A list of ``(Document, score)`` pairs after threshold filtering.
    """
    # Pass the raw embeddings object: _load_milvus already wraps it in
    # EmbeddingsFunAdapter, and wrapping here as well would both normalize
    # twice and hide a None value from _load_milvus's fallback check.
    self._load_milvus(embeddings=embeddings)
    hits = self.milvus.similarity_search_with_score(query, top_k)
    return score_threshold_process(score_threshold, top_k, hits)
|
||||
|
||||
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||
"""
|
||||
|
|
@ -83,4 +86,4 @@ if __name__ == '__main__':
|
|||
milvusService.add_doc(KnowledgeFile("README.md", "test"))
|
||||
milvusService.delete_doc(KnowledgeFile("README.md", "test"))
|
||||
milvusService.do_drop_kb()
|
||||
print(milvusService.search_docs("测试"))
|
||||
print(milvusService.search_docs("如何启动api服务"))
|
||||
|
|
|
|||
|
|
@ -3,10 +3,12 @@ from typing import List
|
|||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores import PGVector
|
||||
from langchain.vectorstores.pgvector import DistanceStrategy
|
||||
from sqlalchemy import text
|
||||
|
||||
from configs.model_config import EMBEDDING_DEVICE, kbs_config
|
||||
from server.knowledge_base.kb_service.base import SupportedVSType, KBService
|
||||
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
|
||||
score_threshold_process
|
||||
from server.knowledge_base.utils import load_embeddings, KnowledgeFile
|
||||
|
||||
|
||||
|
|
@ -17,8 +19,9 @@ class PGKBService(KBService):
|
|||
_embeddings = embeddings
|
||||
if _embeddings is None:
|
||||
_embeddings = load_embeddings(self.embed_model, embedding_device)
|
||||
self.pg_vector = PGVector(embedding_function=_embeddings,
|
||||
self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(_embeddings),
|
||||
collection_name=self.kb_name,
|
||||
distance_strategy=DistanceStrategy.EUCLIDEAN,
|
||||
connection_string=kbs_config.get("pg").get("connection_uri"))
|
||||
|
||||
def do_init(self):
|
||||
|
|
@ -46,7 +49,8 @@ class PGKBService(KBService):
|
|||
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
|
||||
# todo: support score threshold
|
||||
self._load_pg_vector(embeddings=embeddings)
|
||||
return self.pg_vector.similarity_search_with_score(query, top_k)
|
||||
return score_threshold_process(score_threshold, top_k,
|
||||
self.pg_vector.similarity_search_with_score(query, top_k))
|
||||
|
||||
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||
"""
|
||||
|
|
@ -83,4 +87,4 @@ if __name__ == '__main__':
|
|||
pGKBService.add_doc(KnowledgeFile("README.md", "test"))
|
||||
pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
|
||||
pGKBService.drop_kb()
|
||||
print(pGKBService.search_docs("测试"))
|
||||
print(pGKBService.search_docs("如何启动api服务"))
|
||||
|
|
|
|||
Loading…
Reference in New Issue