From 037b5e45021540c8b60638f857479d4254738865 Mon Sep 17 00:00:00 2001
From: yawudede <35065151+yawudede@users.noreply.github.com>
Date: Mon, 22 May 2023 10:32:58 +0800
Subject: [PATCH] local_doc_qa.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add a re-ranking stage after the coarse FAISS retrieval.
---
 chains/local_doc_qa.py | 72 ++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 70 insertions(+), 2 deletions(-)

diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index ecd5e88..6d71e64 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -20,7 +20,45 @@ from models.loader import LoaderCheckPoint
 import models.shared as shared
 from agent import bing_search
 from langchain.docstore.document import Document
+from sentence_transformers import SentenceTransformer, CrossEncoder, util
+class SemanticSearch:
+    def __init__(self):
+        self.use = SentenceTransformer('GanymedeNil_text2vec-large-chinese')
+        self.fitted = False
+
+    def fit(self, data, batch=100, n_neighbors=10):
+        self.data = data
+        self.embeddings = self.get_text_embedding(data, batch=batch)
+        n_neighbors = min(n_neighbors, len(self.embeddings))
+        self.nn = NearestNeighbors(n_neighbors=n_neighbors)
+        self.nn.fit(self.embeddings)
+        self.fitted = True
+
+    def __call__(self, text, return_data=True):
+        inp_emb = self.use.encode([text])
+        neighbors = self.nn.kneighbors(inp_emb, return_distance=False)[0]
+
+        if return_data:
+            return [self.data[i] for i in neighbors]
+        else:
+            return neighbors
+
+    def get_text_embedding(self, texts, batch=100):
+        embeddings = []
+        for i in range(0, len(texts), batch):
+            text_batch = texts[i : (i + batch)]
+            emb_batch = self.use.encode(text_batch)
+            embeddings.append(emb_batch)
+        embeddings = np.vstack(embeddings)
+        return embeddings
+
+def get_docs_with_score(docs_with_score):
+    docs = []
+    for doc, score in docs_with_score:
+        doc.metadata["score"] = score
+        docs.append(doc)
+    return docs


 def load_file(filepath, sentence_size=SENTENCE_SIZE):
     if filepath.lower().endswith(".md"):
@@ -262,9 +300,39 @@ class LocalDocQA:
         vector_store.chunk_conent = self.chunk_conent
         vector_store.score_threshold = self.score_threshold
         related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)
+
+        ########################################### Re-ranking: the FAISS retrieval above serves as the coarse ranking stage; the relevant config parameters need to be set
+        #### Extract the documents
+        related_docs = get_docs_with_score(related_docs_with_score)
+        text_batch0 = []
+        for i in range(len(related_docs)):
+            cut_txt = " ".join([w for w in list(related_docs[i].page_content)])
+            cut_txt = cut_txt.replace(" ", "")
+            text_batch0.append(cut_txt)
+        ###### Deduplicate the documents
+        text_batch_new = []
+        for i in range(len(text_batch0)):
+            if text_batch0[i] in text_batch_new:
+                continue
+            else:
+                while text_batch_new and text_batch_new[-1] > text_batch0[i] and text_batch_new[-1] in text_batch0[i + 1:]:
+                    text_batch_new.pop()  # pop the element at the top of the stack
+                text_batch_new.append(text_batch0[i])
+        text_batch_new0 = "\n".join([doc for doc in text_batch_new])
+        ### Re-rank using KNN-based semantic search
+        recommender = SemanticSearch()
+        chunks = text_to_chunks(text_batch_new0, start_page=1)
+        recommender.fit(chunks)
+        topn_chunks = recommender(query)
         torch_gc()
-        prompt = generate_prompt(related_docs_with_score, query)
-
+        # Remove spaces from the text
+        topn_chunks0 = []
+        for i in range(len(topn_chunks)):
+            cut_txt = topn_chunks[i].replace(" ", "")
+            topn_chunks0.append(cut_txt)
+        ############ Generate the prompt
+        prompt = generate_prompt(topn_chunks0, query)
+        ########################
         for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
                                                       streaming=streaming):
             resp = answer_result.llm_output["answer"]
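
The code added above relies on three names that this diff neither imports nor defines: NearestNeighbors (used in SemanticSearch.fit), np (used in get_text_embedding), and text_to_chunks (called before recommender.fit). Below is a minimal sketch of the assumed missing pieces; the text_to_chunks signature and its character-level chunking are hypothetical, inferred only from the call text_to_chunks(text_batch_new0, start_page=1), so treat it as one plausible way to make the change run rather than as part of the original patch.

# Assumed missing pieces for this patch -- a sketch, not part of the diff itself.
import numpy as np                               # SemanticSearch.get_text_embedding uses np.vstack
from sklearn.neighbors import NearestNeighbors   # SemanticSearch.fit builds a KNN index with this


def text_to_chunks(text, chunk_size=250, start_page=1):
    # Hypothetical helper matching the call text_to_chunks(text_batch_new0, start_page=1).
    # Splits the newline-joined, deduplicated document text into fixed-size character
    # chunks so the SemanticSearch re-ranker scores passages of comparable length.
    # chunk_size is an assumed default; start_page is accepted only for call compatibility.
    if not isinstance(text, str):
        text = "\n".join(text)
    chunks = []
    for doc in text.split("\n"):
        doc = doc.strip()
        if not doc:
            continue
        for i in range(0, len(doc), chunk_size):
            chunks.append(doc[i:i + chunk_size])
    return chunks

With these in scope, recommender.fit(chunks) embeds the chunks and builds the NearestNeighbors index, and recommender(query) returns the nearest chunks, which generate_prompt then consumes in place of the raw FAISS results.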