from typing import List


def return_extra(scores, documents, top_k=None):
    """Rank documents by score and wrap them in the API response shape.

    Args:
        scores: relevance scores, parallel to ``documents`` (same length).
        documents: document texts.
        top_k: if truthy, truncate the ranked list to the first ``top_k``.

    Returns:
        ``{"results": [{"document": {"text", "index", "relevance_score"}}, ...]}``
        ordered by descending score. ``index`` is the position of the document
        in the *input* list, preserved through the sort.
    """
    ranked = sorted(
        ({"index": idx, "text": documents[idx], "score": score}
         for idx, score in enumerate(scores)),
        key=lambda item: item["score"],
        reverse=True,
    )
    results = [
        {"document": {"text": doc["text"],
                      "index": doc["index"],
                      "relevance_score": doc["score"]}}
        for doc in ranked
    ]
    if top_k:
        return {"results": results[:top_k]}
    return {"results": results}


class Singleton(type):
    """Metaclass caching one instance per class.

    BUG FIX: the original tested ``hasattr(cls, '_instence')`` (typo) while
    assigning ``cls._instance``, so the cache was never hit and a brand-new
    instance was created on every call. We check ``cls.__dict__`` rather than
    ``hasattr`` so each class in an inheritance chain (e.g. ``Embedding`` vs
    ``M3E_EMB``) gets its own instance instead of inheriting the parent's.
    """

    def __call__(cls, *args, **kwargs):
        if '_instance' not in cls.__dict__:
            cls._instance = super().__call__(*args, **kwargs)
        return cls._instance


class Embedding(metaclass=Singleton):
    """Thin wrapper around an embedding backend exposing ``encode(sentences)``.

    The backend is expected to return a matrix-like object (one row per
    sentence) supporting ``@`` and ``.T`` — e.g. a numpy array.
    """

    def __init__(self, emb):
        self.embedding = emb

    def compute_similarity(self, sentences1: List[str], sentences2: List[str]):
        """Return the pairwise dot-product similarity matrix between the two
        sentence lists, or ``None`` when either list is empty."""
        if not sentences1 or not sentences2:
            return None
        embeddings1 = self.encode(sentences1)
        embeddings2 = self.encode(sentences2)
        return embeddings1 @ embeddings2.T

    def encode(self, sentences):
        """Delegate encoding to the wrapped backend."""
        return self.embedding.encode(sentences)

    def get_similarity(self, em1, em2):
        """Dot-product similarity between two pre-computed embedding matrices."""
        return em1 @ em2.T


class M3E_EMB(Embedding):
    """Embedding backed by a sentence-transformers model loaded from disk."""

    def __init__(self, model_path):
        # Lazy import: sentence_transformers is heavy and only needed here.
        from sentence_transformers import SentenceTransformer
        self.embedding = SentenceTransformer(model_path)
        super().__init__(self.embedding)


class KNOW_PASS:
    """In-memory knowledge-base store with embedding-based retrieval."""

    # Eviction thresholds: once more than _MAX_BASES knowledge bases exist,
    # the _EVICT_COUNT oldest ones are dropped (original hard-coded 500/100).
    _MAX_BASES = 500
    _EVICT_COUNT = 100

    def __init__(self, embedding: Embedding):
        self.embedding = embedding
        self.knowbase = {}  # know_id -> {"contents": [...], "embedding": matrix}
        self.knownames = []  # insertion order of know_ids, used for eviction

    def load_know(self, know_id: str, contents: List[str], drop_dup=True,
                  is_cover=False):
        """Create or update the knowledge base ``know_id`` with ``contents``.

        Args:
            know_id: knowledge-base identifier.
            contents: texts to index.
            drop_dup: de-duplicate contents before encoding (order not preserved).
            is_cover: when updating an existing base, replace it instead of
                merging the old contents in.

        Returns:
            ``know_id`` on success; ``None`` when the base already holds exactly
            these contents (nothing to do).
        """
        if know_id in self.knowbase:
            cur = self.knowbase[know_id]
            if cur['contents'] == contents:
                print(f"已经存在{know_id}知识库,包含知识条数:{len(cur['contents'])}")
                return
            if is_cover:
                print(f"清空{know_id}知识库")
                self.knowbase[know_id] = {}
            else:
                # Merge without mutating the caller's list (the original
                # called contents.extend(...), a surprising side effect).
                contents = contents + cur['contents']
        else:
            self.knownames.append(know_id)
            # Bound memory: drop the oldest bases once we exceed the cap.
            if len(self.knownames) > self._MAX_BASES:
                for knowname in self.knownames[:self._EVICT_COUNT]:
                    del self.knowbase[knowname]
                self.knownames = self.knownames[self._EVICT_COUNT:]
            self.knowbase[know_id] = {}
        if drop_dup:
            contents = list(set(contents))
        em = self.embedding.encode(contents)
        self.knowbase[know_id]['contents'] = contents
        self.knowbase[know_id]['embedding'] = em
        print(f"已更新{know_id}知识库,现在包含知识条数:{len(contents)}")
        return know_id

    def get_similarity_pair(self, sentences1: List[str], sentences2: List[str]):
        """Similarity matrix between two sentence lists.

        Returns ``{"results": matrix-as-nested-lists}`` or ``None`` when either
        list is empty. BUG FIX: the original called ``.tolist()`` before the
        ``None`` check, raising ``AttributeError`` on empty input.
        """
        similarity = self.embedding.compute_similarity(sentences1, sentences2)
        if similarity is None:
            return None
        return {"results": similarity.tolist()}

    def get_similarity_know(self, query: str, know_id: str, top_k: int = 10):
        """Retrieve the ``top_k`` entries of ``know_id`` most similar to ``query``.

        Returns the ``return_extra`` response dict, or ``None`` when the
        knowledge base does not exist. BUG FIX: the warning message was missing
        its ``f`` prefix, printing the placeholder names literally; the ``None``
        check also ran after ``.tolist()`` instead of before.
        """
        if know_id not in self.knowbase:
            print(f"当前知识库中不包含{know_id},当前知识库中包括:{self.knowbase.keys()},请确定知识库名称是否正确或者创建知识库")
            return None
        em = self.embedding.encode([query])
        similarity = self.embedding.get_similarity(
            em, self.knowbase[know_id]['embedding'])
        if similarity is None:
            return None
        return return_extra(similarity.tolist()[0],
                            self.knowbase[know_id]['contents'], top_k)