from typing import List


def return_extra(scores, documents, top_k=None):
    """Pair each score with its document and return them ranked by score."""
    new_docs = []
    for idx, score in enumerate(scores):
        # Alternative scoring: squash raw scores into (0, 1) with a sigmoid
        # (would require importing numpy as np):
        # new_docs.append({"index": idx, "text": documents[idx], "score": 1 / (1 + np.exp(-score))})
        new_docs.append({"index": idx, "text": documents[idx], "score": score})
    results = [
        {"document": {"text": doc["text"], "index": doc["index"], "relevance_score": doc["score"]}}
        for doc in sorted(new_docs, key=lambda x: x["score"], reverse=True)
    ]
    if top_k:
        return {"results": results[:top_k]}
    else:
        return {"results": results}


class Singleton(type):
    """Metaclass that caches a single instance per class."""

    def __call__(cls, *args, **kwargs):
        # Build the instance once; every later call returns the cached object.
        if not hasattr(cls, '_instance'):
            cls._instance = super().__call__(*args, **kwargs)
        return cls._instance
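
# A minimal sketch (Config is a hypothetical class, not in the original)
# showing the metaclass in action: both calls return the same cached object.
# Note that hasattr also sees an instance cached on a parent class, so a
# subclass shares its parent's instance once the parent has been instantiated.
# >>> class Config(metaclass=Singleton):
# ...     def __init__(self, value): self.value = value
# >>> Config(1) is Config(2)
# True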


class Embedding(metaclass=Singleton):
    """Thin wrapper around an embedding model; the singleton metaclass keeps
    one shared instance so the model is loaded only once."""

    def __init__(self, emb):
        self.embedding = emb

    def compute_similarity(self, sentences1: List[str], sentences2: List[str]):
        if len(sentences1) > 0 and len(sentences2) > 0:
            embeddings1 = self.encode(sentences1)
            embeddings2 = self.encode(sentences2)
            # Dot-product similarity matrix; this equals cosine similarity
            # when the model returns normalized embeddings.
            similarity = embeddings1 @ embeddings2.T
            return similarity
        else:
            return None

    def encode(self, sentences):
        return self.embedding.encode(sentences)

    def get_similarity(self, em1, em2):
        return em1 @ em2.T


class M3E_EMB(Embedding):
    """Embedding backed by a sentence-transformers model, e.g. an M3E checkpoint."""

    def __init__(self, model_path):
        # Imported lazily so the module loads without sentence_transformers installed.
        from sentence_transformers import SentenceTransformer
        self.embedding = SentenceTransformer(model_path)
        super().__init__(self.embedding)
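
# A minimal usage sketch: "moka-ai/m3e-base" is an assumed checkpoint name,
# not one the original specifies; any sentence-transformers model path works.
# >>> emb = M3E_EMB("moka-ai/m3e-base")
# >>> emb.compute_similarity(["hello"], ["hi", "bye"]).shape
# (1, 2)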


class KNOW_PASS:
    """In-memory knowledge-base store with embedding-based retrieval."""

    def __init__(self, embedding: Embedding):
        self.embedding = embedding
        self.knowbase = {}   # temporary in-memory knowledge bases, keyed by know_id
        self.knownames = []  # know_ids in insertion order, used for eviction

    def load_know(self, know_id: str, contents: List[str], drop_dup=True, is_cover=False):
        if know_id in self.knowbase:
            cur = self.knowbase[know_id]
            if cur['contents'] == contents:
                print(f"Knowledge base {know_id} already exists with {len(cur['contents'])} entries")
                return
            if is_cover:
                print(f"Clearing knowledge base {know_id}")
                self.knowbase[know_id] = {}
            else:
                # Merge with the existing entries; concatenation (rather than
                # extend) avoids mutating the caller's list.
                contents = contents + cur['contents']
        else:
            self.knownames.append(know_id)
            # Evict the 100 oldest knowledge bases once more than 500 exist.
            if len(self.knownames) > 500:
                for knowname in self.knownames[:100]:
                    del self.knowbase[knowname]
                self.knownames = self.knownames[100:]
            self.knowbase[know_id] = {}
        if drop_dup:
            # set() removes duplicates but does not preserve the input order.
            contents = list(set(contents))
        em = self.embedding.encode(contents)
        self.knowbase[know_id]['contents'] = contents
        self.knowbase[know_id]['embedding'] = em
        print(f"Updated knowledge base {know_id}; it now holds {len(contents)} entries")
        return know_id

    def get_similarity_pair(self, sentences1: List[str], sentences2: List[str]):
        similarity = self.embedding.compute_similarity(sentences1, sentences2)
        # compute_similarity returns None when either list is empty, so check
        # before converting the matrix to a plain list.
        if similarity is None:
            return None
        return {"results": similarity.tolist()}

    def get_similarity_know(self, query: str, know_id: str, top_k: int = 10):
        if know_id not in self.knowbase:
            print(f"Knowledge base {know_id} not found; available: "
                  f"{list(self.knowbase.keys())}. Check the name or create it first.")
            return None
        em = self.embedding.encode([query])
        similarity = self.embedding.get_similarity(em, self.knowbase[know_id]['embedding'])
        if similarity is None:
            return None
        similarity = similarity.tolist()
        return return_extra(similarity[0], self.knowbase[know_id]['contents'], top_k)
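

# An end-to-end sketch under assumptions: the model name "moka-ai/m3e-base"
# and the sample sentences are illustrative, not from the original module.
if __name__ == "__main__":
    emb = M3E_EMB("moka-ai/m3e-base")  # hypothetical checkpoint name
    kp = KNOW_PASS(emb)
    kp.load_know("faq", ["How do I reset my password?", "Where is the office located?"])
    print(kp.get_similarity_know("password reset", "faq", top_k=1))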