update local_doc_qa.py
parent 6f8da56083
commit bca32cae83

@@ -20,46 +20,7 @@ from models.loader import LoaderCheckPoint
import models.shared as shared
from agent import bing_search
from langchain.docstore.document import Document
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from sklearn.neighbors import NearestNeighbors

class SemanticSearch:
    def __init__(self):
        # text2vec model used to embed both the corpus chunks and the query
        self.use = SentenceTransformer('GanymedeNil_text2vec-large-chinese')
        self.fitted = False

    def fit(self, data, batch=100, n_neighbors=10):
        # embed the corpus in batches and build a kNN index over the embeddings
        self.data = data
        self.embeddings = self.get_text_embedding(data, batch=batch)
        n_neighbors = min(n_neighbors, len(self.embeddings))
        self.nn = NearestNeighbors(n_neighbors=n_neighbors)
        self.nn.fit(self.embeddings)
        self.fitted = True

    def __call__(self, text, return_data=True):
        # embed the query and return its nearest chunks (or their indices)
        inp_emb = self.use.encode([text])
        neighbors = self.nn.kneighbors(inp_emb, return_distance=False)[0]
        if return_data:
            return [self.data[i] for i in neighbors]
        else:
            return neighbors

    def get_text_embedding(self, texts, batch=100):
        # encode in batches to keep memory bounded; np here is numpy,
        # presumably imported earlier in local_doc_qa.py
        embeddings = []
        for i in range(0, len(texts), batch):
            text_batch = texts[i:(i + batch)]
            emb_batch = self.use.encode(text_batch)
            embeddings.append(emb_batch)
        embeddings = np.vstack(embeddings)
        return embeddings

def get_docs_with_score(docs_with_score):
    # copy each similarity score into the matching document's metadata
    docs = []
    for doc, score in docs_with_score:
        doc.metadata["score"] = score
        docs.append(doc)
    return docs

def load_file(filepath, sentence_size=SENTENCE_SIZE):
    if filepath.lower().endswith(".md"):
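
For reference, a minimal usage sketch of the SemanticSearch class above; the chunk strings are hypothetical, and in the commit the chunks actually come from text_to_chunks:

    chunks = [
        "ChatGLM-6B is a bilingual dialogue model.",
        "FAISS performs the coarse similarity search.",
        "text2vec produces the sentence embeddings.",
    ]
    recommender = SemanticSearch()
    recommender.fit(chunks, batch=100, n_neighbors=2)
    top_chunks = recommender("Which model produces the embeddings?")
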
@@ -301,41 +262,9 @@ class LocalDocQA:
        vector_store.chunk_conent = self.chunk_conent
        vector_store.score_threshold = self.score_threshold
        related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)

        # Fine ranking. The earlier FAISS retrieval acts as coarse ranking, so the
        # model_config parameter VECTOR_SEARCH_TOP_K needs to be set to 300.
        # Idea: coarse ranking (FAISS + semantic search) retrieves a large pool of
        # related documents (VECTOR_SEARCH_TOP_K = 300); these are merged, re-split,
        # retrieved a second time with kNN + semantic search, and fed into the prompt.
        # extract the documents
        related_docs = get_docs_with_score(related_docs_with_score)
        text_batch0 = []
        for doc in related_docs:
            # strip all spaces from the page content
            text_batch0.append(doc.page_content.replace(" ", ""))
        # deduplicate the documents with a monotonic stack: keep one copy of each
        # chunk, preferring the lexicographically smaller ordering
        text_batch_new = []
        for i in range(len(text_batch0)):
            if text_batch0[i] in text_batch_new:
                continue
            else:
                while text_batch_new and text_batch_new[-1] > text_batch0[i] \
                        and text_batch_new[-1] in text_batch0[i + 1:]:
                    text_batch_new.pop()  # pop the top of the stack
                text_batch_new.append(text_batch0[i])
        text_batch_new0 = "\n".join(text_batch_new)
        # fine ranking with kNN + semantic search
        recommender = SemanticSearch()
        chunks = text_to_chunks(text_batch_new0, start_page=1)
        recommender.fit(chunks)
        topn_chunks = recommender(query)
        torch_gc()
        # remove spaces from the re-ranked chunks
        topn_chunks0 = []
        for chunk in topn_chunks:
            topn_chunks0.append(chunk.replace(" ", ""))
        # generate the prompt
        prompt = generate_prompt(topn_chunks0, query)
        ########################
        prompt = generate_prompt(related_docs_with_score, query)

        for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
                                                      streaming=streaming):
            resp = answer_result.llm_output["answer"]
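
The deduplication loop in this hunk behaves like the classic "remove duplicate letters" stack algorithm: one copy of each chunk is kept, and a previously kept chunk is dropped when a lexicographically smaller one arrives and the dropped chunk still occurs later in the batch. A self-contained sketch of the same logic, with hypothetical inputs:

    def dedup(text_batch0):
        # same stack-based logic as the loop above, extracted for illustration
        text_batch_new = []
        for i in range(len(text_batch0)):
            if text_batch0[i] in text_batch_new:
                continue
            while (text_batch_new and text_batch_new[-1] > text_batch0[i]
                   and text_batch_new[-1] in text_batch0[i + 1:]):
                text_batch_new.pop()
            text_batch_new.append(text_batch0[i])
        return text_batch_new

    print(dedup(["b", "a", "b"]))       # ['a', 'b']
    print(dedup(["c", "b", "c", "a"]))  # ['b', 'c', 'a']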