update local_doc_qa.py
parent 6f8da56083
commit bca32cae83
@@ -20,46 +20,7 @@ from models.loader import LoaderCheckPoint
 import models.shared as shared
 from agent import bing_search
 from langchain.docstore.document import Document
-from sentence_transformers import SentenceTransformer, CrossEncoder, util
-from sklearn.neighbors import NearestNeighbors
-
-
-class SemanticSearch:
-    def __init__(self):
-        self.use = SentenceTransformer('GanymedeNil_text2vec-large-chinese')
-        self.fitted = False
-
-    def fit(self, data, batch=100, n_neighbors=10):
-        self.data = data
-        self.embeddings = self.get_text_embedding(data, batch=batch)
-        n_neighbors = min(n_neighbors, len(self.embeddings))
-        self.nn = NearestNeighbors(n_neighbors=n_neighbors)
-        self.nn.fit(self.embeddings)
-        self.fitted = True
-
-    def __call__(self, text, return_data=True):
-        inp_emb = self.use.encode([text])
-        neighbors = self.nn.kneighbors(inp_emb, return_distance=False)[0]
-        if return_data:
-            return [self.data[i] for i in neighbors]
-        else:
-            return neighbors
-
-    def get_text_embedding(self, texts, batch=100):
-        embeddings = []
-        for i in range(0, len(texts), batch):
-            text_batch = texts[i : (i + batch)]
-            emb_batch = self.use.encode(text_batch)
-            embeddings.append(emb_batch)
-        embeddings = np.vstack(embeddings)
-        return embeddings
-
-
-def get_docs_with_score(docs_with_score):
-    docs = []
-    for doc, score in docs_with_score:
-        doc.metadata["score"] = score
-        docs.append(doc)
-    return docs
-
 def load_file(filepath, sentence_size=SENTENCE_SIZE):
     if filepath.lower().endswith(".md"):
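For reference, here is a minimal usage sketch of the SemanticSearch class removed above. The chunks and query are hypothetical; it assumes numpy is available as np in the module (get_text_embedding calls np.vstack) and that 'GanymedeNil_text2vec-large-chinese' resolves to a local copy of the GanymedeNil/text2vec-large-chinese model:

import numpy as np  # required by SemanticSearch.get_text_embedding (np.vstack)

# Hypothetical corpus; in the removed pipeline these would be re-chunked documents.
chunks = [
    "FAISS performs the coarse vector retrieval.",
    "text2vec-large-chinese encodes sentences into dense vectors.",
    "NearestNeighbors reranks chunks against the query embedding.",
]

searcher = SemanticSearch()                      # loads the local text2vec model
searcher.fit(chunks, batch=100, n_neighbors=2)   # embed corpus, build kNN index
top_chunks = searcher("How are chunks reranked?")  # top-2 nearest chunks
print(top_chunks)

Note that sklearn's NearestNeighbors does exact search over the chunk embeddings, which is reasonable at rerank scale (a few hundred chunks), in contrast to the approximate FAISS index used for the coarse stage.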
@@ -301,41 +262,9 @@ class LocalDocQA:
         vector_store.chunk_conent = self.chunk_conent
         vector_store.score_threshold = self.score_threshold
         related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)
-
-        ########################################### Fine ranking: the FAISS retrieval above serves as the coarse ranking; model_config must set VECTOR_SEARCH_TOP_K = 300
-        ########################################### Idea: the coarse stage (FAISS + semantic search) retrieves a large set of related documents (with VECTOR_SEARCH_TOP_K at 300), which are then merged and re-chunked;
-        ############################################# a second kNN + semantic search retrieval pass selects the chunks fed into the prompt
-        #### Extract the documents
-        related_docs = get_docs_with_score(related_docs_with_score)
-        text_batch0 = []
-        for i in range(len(related_docs)):
-            cut_txt = " ".join([w for w in list(related_docs[i].page_content)])
-            cut_txt = cut_txt.replace(" ", "")
-            text_batch0.append(cut_txt)
-        ###### Deduplicate the documents
-        text_batch_new = []
-        for i in range(len(text_batch0)):
-            if text_batch0[i] in text_batch_new:
-                continue
-            else:
-                while text_batch_new and text_batch_new[-1] > text_batch0[i] and text_batch_new[-1] in text_batch0[i + 1:]:
-                    text_batch_new.pop()  # pop the top of the stack
-                text_batch_new.append(text_batch0[i])
-        text_batch_new0 = "\n".join([doc for doc in text_batch_new])
-        ### Fine ranking with kNN + semantic search
-        recommender = SemanticSearch()
-        chunks = text_to_chunks(text_batch_new0, start_page=1)
-        recommender.fit(chunks)
-        topn_chunks = recommender(query)
         torch_gc()
-        # Strip the spaces from the text
-        topn_chunks0 = []
-        for i in range(len(topn_chunks)):
-            cut_txt = topn_chunks[i].replace(" ", "")
-            topn_chunks0.append(cut_txt)
-        ############ Generate the prompt
-        prompt = generate_prompt(topn_chunks0, query)
-        ########################
+        prompt = generate_prompt(related_docs_with_score, query)
         for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
                                                       streaming=streaming):
             resp = answer_result.llm_output["answer"]
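The least obvious part of the removed block is the deduplication loop: it is a monotonic-stack pass (the "remove duplicate letters" pattern applied to whole documents) that keeps one copy of each document, popping a kept document whenever a lexicographically smaller one arrives and the popped one still occurs later in the batch. A standalone sketch of the same logic, with hypothetical one-letter documents:

def dedup(text_batch0):
    # Monotonic stack mirroring the removed loop: keep each distinct
    # document once, preferring the lexicographically smallest ordering.
    text_batch_new = []
    for i in range(len(text_batch0)):
        if text_batch0[i] in text_batch_new:
            continue
        while (text_batch_new and text_batch_new[-1] > text_batch0[i]
               and text_batch_new[-1] in text_batch0[i + 1:]):
            text_batch_new.pop()  # drop the stack top; it reappears later
        text_batch_new.append(text_batch0[i])
    return text_batch_new

print(dedup(["b", "a", "b", "a"]))  # ['a', 'b'] - one copy of each, reordered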