diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 8d1de48..6fa259d 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -1,9 +1,13 @@ +import operator from abc import ABC, abstractmethod import os +import numpy as np from langchain.embeddings.base import Embeddings from langchain.docstore.document import Document +from sklearn.preprocessing import normalize + from server.db.repository.knowledge_base_repository import ( add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, load_kb_from_db, get_kb_detail, @@ -62,7 +66,6 @@ class KBService(ABC): status = delete_files_from_db(self.kb_name) return status - def drop_kb(self): """ 删除知识库 @@ -102,7 +105,7 @@ class KBService(ABC): if os.path.exists(kb_file.filepath): self.delete_doc(kb_file, **kwargs) return self.add_doc(kb_file, **kwargs) - + def exist_doc(self, file_name: str): return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name, filename=file_name)) @@ -208,8 +211,9 @@ class KBServiceFactory: return PGKBService(kb_name, embed_model=embed_model) elif SupportedVSType.MILVUS == vector_store_type: from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService - return MilvusKBService(kb_name, embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config - elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier. + return MilvusKBService(kb_name, + embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config + elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier. from server.knowledge_base.kb_service.default_kb_service import DefaultKBService return DefaultKBService(kb_name) @@ -217,7 +221,7 @@ class KBServiceFactory: def get_service_by_name(kb_name: str ) -> KBService: _, vs_type, embed_model = load_kb_from_db(kb_name) - if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db + if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db vs_type = "faiss" return KBServiceFactory.get_service(kb_name, vs_type, embed_model) @@ -256,7 +260,7 @@ def get_kb_details() -> List[Dict]: for i, v in enumerate(result.values()): v['No'] = i + 1 data.append(v) - + return data @@ -292,5 +296,39 @@ def get_kb_doc_details(kb_name: str) -> List[Dict]: for i, v in enumerate(result.values()): v['No'] = i + 1 data.append(v) - + return data + + +class EmbeddingsFunAdapter(Embeddings): + + def __init__(self, embeddings: Embeddings): + self.embeddings = embeddings + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + return normalize(self.embeddings.embed_documents(texts)) + + def embed_query(self, text: str) -> List[float]: + query_embed = self.embeddings.embed_query(text) + query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组 + normalized_query_embed = normalize(query_embed_2d) + return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回 + + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + return await normalize(self.embeddings.aembed_documents(texts)) + + async def aembed_query(self, text: str) -> List[float]: + return await normalize(self.embeddings.aembed_query(text)) + + +def score_threshold_process(score_threshold, k, docs): + if score_threshold is not None: + cmp = ( + operator.le + ) + docs = [ + (doc, similarity) + for doc, similarity in docs + if cmp(similarity, score_threshold) + ] + return docs[:k] diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index b3f5439..40c67b5 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -166,3 +166,10 @@ class FaissKBService(KBService): return "in_folder" else: return False +if __name__ == '__main__': + + milvusService = FaissKBService("test") + milvusService.add_doc(KnowledgeFile("README.md", "test")) + milvusService.delete_doc(KnowledgeFile("README.md", "test")) + milvusService.do_drop_kb() + print(milvusService.search_docs("如何启动api服务")) \ No newline at end of file diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index f2c798c..296ae44 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -1,12 +1,16 @@ from typing import List +import numpy as np +from faiss import normalize_L2 from langchain.embeddings.base import Embeddings from langchain.schema import Document from langchain.vectorstores import Milvus +from sklearn.preprocessing import normalize from configs.model_config import SCORE_THRESHOLD, kbs_config -from server.knowledge_base.kb_service.base import KBService, SupportedVSType +from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \ + score_threshold_process from server.knowledge_base.utils import KnowledgeFile @@ -36,7 +40,7 @@ class MilvusKBService(KBService): def _load_milvus(self, embeddings: Embeddings = None): if embeddings is None: embeddings = self._load_embeddings() - self.milvus = Milvus(embedding_function=embeddings, + self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(embeddings), collection_name=self.kb_name, connection_args=kbs_config.get("milvus")) def do_init(self): @@ -45,10 +49,9 @@ class MilvusKBService(KBService): def do_drop_kb(self): self.milvus.col.drop() - def do_search(self, query: str, top_k: int,score_threshold: float, embeddings: Embeddings): - # todo: support score threshold - self._load_milvus(embeddings=embeddings) - return self.milvus.similarity_search_with_score(query, top_k) + def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings): + self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings)) + return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k)) def add_doc(self, kb_file: KnowledgeFile, **kwargs): """ @@ -83,4 +86,4 @@ if __name__ == '__main__': milvusService.add_doc(KnowledgeFile("README.md", "test")) milvusService.delete_doc(KnowledgeFile("README.md", "test")) milvusService.do_drop_kb() - print(milvusService.search_docs("测试")) + print(milvusService.search_docs("如何启动api服务")) diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index 6876bd8..31cc908 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -3,10 +3,12 @@ from typing import List from langchain.embeddings.base import Embeddings from langchain.schema import Document from langchain.vectorstores import PGVector +from langchain.vectorstores.pgvector import DistanceStrategy from sqlalchemy import text from configs.model_config import EMBEDDING_DEVICE, kbs_config -from server.knowledge_base.kb_service.base import SupportedVSType, KBService +from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \ + score_threshold_process from server.knowledge_base.utils import load_embeddings, KnowledgeFile @@ -17,8 +19,9 @@ class PGKBService(KBService): _embeddings = embeddings if _embeddings is None: _embeddings = load_embeddings(self.embed_model, embedding_device) - self.pg_vector = PGVector(embedding_function=_embeddings, + self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(_embeddings), collection_name=self.kb_name, + distance_strategy=DistanceStrategy.EUCLIDEAN, connection_string=kbs_config.get("pg").get("connection_uri")) def do_init(self): @@ -46,7 +49,8 @@ class PGKBService(KBService): def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings): # todo: support score threshold self._load_pg_vector(embeddings=embeddings) - return self.pg_vector.similarity_search_with_score(query, top_k) + return score_threshold_process(score_threshold, top_k, + self.pg_vector.similarity_search_with_score(query, top_k)) def add_doc(self, kb_file: KnowledgeFile, **kwargs): """ @@ -83,4 +87,4 @@ if __name__ == '__main__': pGKBService.add_doc(KnowledgeFile("README.md", "test")) pGKBService.delete_doc(KnowledgeFile("README.md", "test")) pGKBService.drop_kb() - print(pGKBService.search_docs("测试")) + print(pGKBService.search_docs("如何启动api服务"))