# Intention/models.py
from typing import List
def return_extra(scores, documents, top_k=None):
    """Rank documents by score and wrap them in a rerank-style response.

    Args:
        scores: iterable of relevance scores, aligned by index with ``documents``.
        documents: sequence of document texts.
        top_k: if truthy, truncate the sorted results to this many entries.

    Returns:
        ``{"results": [{"document": {"text", "index", "relevance_score"}}, ...]}``
        sorted by relevance_score, highest first.
    """
    ranked = sorted(
        ({"index": idx, "text": documents[idx], "score": score}
         for idx, score in enumerate(scores)),
        key=lambda d: d["score"],
        reverse=True,
    )
    results = [
        {"document": {"text": d["text"], "index": d["index"],
                      "relevance_score": d["score"]}}
        for d in ranked
    ]
    # Preserve original semantics: a falsy top_k (None or 0) returns everything.
    return {"results": results[:top_k] if top_k else results}
class Singleton(type):
    """Metaclass that caches one shared instance per class.

    Bug fix: the original tested ``hasattr(cls, '_instence')`` (typo) while
    storing ``cls._instance`` — the check never succeeded, so a brand-new
    instance was created on every call and the singleton was never a singleton.
    """

    def __call__(cls, *args, **kwargs):
        # Test the same attribute name we actually store.
        if not hasattr(cls, '_instance'):
            cls._instance = super().__call__(*args, **kwargs)
        return cls._instance
class Embedding(metaclass=Singleton):
    """Thin wrapper around an embedding model exposing encode/similarity.

    Instances are shared process-wide via the ``Singleton`` metaclass.
    The wrapped ``emb`` object must expose ``encode(list[str])`` returning
    a matrix of row vectors (e.g. a numpy array).
    """

    def __init__(self, emb):
        # Backend model that performs the actual encoding.
        self.embedding = emb

    def compute_similarity(self, sentences1: List[str], sentences2: List[str]):
        """Return the pairwise dot-product similarity matrix between the two
        sentence lists, or ``None`` if either list is empty."""
        # Guard clause + truthiness instead of len(...) > 0 comparisons.
        if not sentences1 or not sentences2:
            return None
        embeddings1 = self.encode(sentences1)
        embeddings2 = self.encode(sentences2)
        return embeddings1 @ embeddings2.T

    def encode(self, sentences):
        """Encode a list of sentences into embedding row vectors."""
        return self.embedding.encode(sentences)

    def get_similarity(self, em1, em2):
        """Dot-product similarity between two pre-computed embedding matrices."""
        return em1 @ em2.T
class M3E_EMB(Embedding):
    """Embedding backed by a SentenceTransformer model loaded from disk."""

    def __init__(self, model_path):
        # Local import keeps sentence_transformers an optional, lazily-loaded
        # dependency — it is only required when this backend is used.
        from sentence_transformers import SentenceTransformer
        self.embedding = SentenceTransformer(model_path)
        super().__init__(self.embedding)
class KNOW_PASS():
    """In-memory knowledge-base store with embedding-based retrieval.

    Each knowledge base is keyed by ``know_id`` and holds its raw contents
    plus their precomputed embeddings. When more than 500 bases exist, the
    100 oldest (by insertion order) are evicted.
    """

    def __init__(self, embedding: "Embedding"):
        self.embedding = embedding
        self.knowbase = {}    # know_id -> {"contents": [...], "embedding": matrix}
        self.knownames = []   # insertion order, used for eviction

    def load_know(self, know_id: str, contents: List[str], drop_dup=True, is_cover=False):
        """Create or update knowledge base ``know_id`` with ``contents``.

        If the base exists with identical contents, this is a no-op (returns
        None, as in the original). With ``is_cover`` the old contents are
        discarded; otherwise old and new contents are merged. ``drop_dup``
        removes duplicates (note: ``set`` does not preserve order).

        Returns:
            ``know_id`` on success, ``None`` when contents were unchanged.
        """
        if know_id in self.knowbase:
            cur = self.knowbase[know_id]
            if cur['contents'] == contents:
                print(f"已经存在{know_id}知识库,包含知识条数:{len(cur['contents'])}")
                return
            if is_cover:
                print(f"清空{know_id}知识库")
                self.knowbase[know_id] = {}
            else:
                # Bug fix: build a new list instead of .extend(), which
                # mutated the caller's list in place.
                contents = contents + cur['contents']
        else:
            self.knownames.append(know_id)
            # Evict the 100 oldest bases once we exceed 500.
            if len(self.knownames) > 500:
                for knowname in self.knownames[:100]:
                    del self.knowbase[knowname]
                self.knownames = self.knownames[100:]
            self.knowbase[know_id] = {}
        if drop_dup:
            contents = list(set(contents))
        em = self.embedding.encode(contents)
        self.knowbase[know_id]['contents'] = contents
        self.knowbase[know_id]['embedding'] = em
        print(f"已更新{know_id}知识库,现在包含知识条数:{len(contents)}")
        return know_id

    def get_similarity_pair(self, sentences1: List[str], sentences2: List[str]):
        """Pairwise similarity between two sentence lists.

        Returns ``{"results": matrix-as-nested-lists}``, or ``None`` when
        either list is empty.
        """
        similarity = self.embedding.compute_similarity(sentences1, sentences2)
        # Bug fix: check for None BEFORE .tolist(); the original dereferenced
        # the result first and raised AttributeError on empty input.
        if similarity is None:
            return None
        return {"results": similarity.tolist()}

    def get_similarity_know(self, query: str, know_id: str, top_k: int = 10):
        """Retrieve the top_k most similar entries of ``know_id`` for ``query``.

        Returns the rerank-style dict from ``return_extra``, or ``None`` when
        the base is missing or no similarity could be computed.
        """
        if know_id not in self.knowbase:
            # Bug fix: the original string lacked the f-prefix, so the
            # placeholders were printed literally.
            print(f"当前知识库中不包含{know_id},当前知识库中包括:{self.knowbase.keys()},请确定知识库名称是否正确或者创建知识库")
            return None
        em = self.embedding.encode([query])
        similarity = self.embedding.get_similarity(em, self.knowbase[know_id]['embedding'])
        # Bug fix: None check must precede .tolist().
        if similarity is None:
            return None
        similarity = similarity.tolist()
        return return_extra(similarity[0], self.knowbase[know_id]['contents'], top_k)