Merge branch 'dev' of https://github.com/imClumsyPanda/langchain-ChatGLM into dev
commit 8742001982
@@ -20,7 +20,46 @@ from models.loader import LoaderCheckPoint
import models.shared as shared
from agent import bing_search
from langchain.docstore.document import Document
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from sklearn.neighbors import NearestNeighbors
import numpy as np  # assumed not imported earlier in this file; needed for np.vstack below


class SemanticSearch:
    def __init__(self):
        # Embedding model shared by indexing and querying.
        self.use = SentenceTransformer('GanymedeNil_text2vec-large-chinese')
        self.fitted = False

    def fit(self, data, batch=100, n_neighbors=10):
        # Embed the corpus in batches and build a KNN index over the embeddings.
        self.data = data
        self.embeddings = self.get_text_embedding(data, batch=batch)
        n_neighbors = min(n_neighbors, len(self.embeddings))
        self.nn = NearestNeighbors(n_neighbors=n_neighbors)
        self.nn.fit(self.embeddings)
        self.fitted = True

    def __call__(self, text, return_data=True):
        # Embed the query and return its nearest chunks (or their indices).
        inp_emb = self.use.encode([text])
        neighbors = self.nn.kneighbors(inp_emb, return_distance=False)[0]
        if return_data:
            return [self.data[i] for i in neighbors]
        else:
            return neighbors

    def get_text_embedding(self, texts, batch=100):
        # Encode the texts batch by batch and stack them into one matrix.
        embeddings = []
        for i in range(0, len(texts), batch):
            text_batch = texts[i : (i + batch)]
            emb_batch = self.use.encode(text_batch)
            embeddings.append(emb_batch)
        embeddings = np.vstack(embeddings)
        return embeddings

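# Usage sketch (illustrative, not part of the original diff): fit the
# recommender on a list of text chunks, then call it with a query to get the
# nearest chunks back, closest first. The chunk texts are placeholders.
#
#   recommender = SemanticSearch()
#   recommender.fit(["chunk about pandas", "chunk about FAISS", "chunk about LLMs"])
#   top_chunks = recommender("how does FAISS index vectors?")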

def get_docs_with_score(docs_with_score):
    # Copy each similarity score into its document's metadata.
    docs = []
    for doc, score in docs_with_score:
        doc.metadata["score"] = score
        docs.append(doc)
    return docs

def load_file(filepath, sentence_size=SENTENCE_SIZE):
    if filepath.lower().endswith(".md"):
@@ -262,9 +301,41 @@ class LocalDocQA:
        vector_store.chunk_conent = self.chunk_conent
        vector_store.score_threshold = self.score_threshold
        related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)
        torch_gc()
        # Previously the prompt was built directly from the coarse FAISS results:
        # prompt = generate_prompt(related_docs_with_score, query)
        # Fine ranking: the FAISS retrieval above serves as the coarse ranking;
        # the model_config parameter VECTOR_SEARCH_TOP_K needs to be set to 300.
        # Principle: coarse ranking (FAISS + semantic search) retrieves a large
        # set of related documents (hence VECTOR_SEARCH_TOP_K = 300), which are
        # merged and re-chunked; a second retrieval pass with KNN + semantic
        # search then selects the chunks that go into the prompt.
        # Extract the documents from the scored results.
        related_docs = get_docs_with_score(related_docs_with_score)
        text_batch0 = []
        for i in range(len(related_docs)):
            # Join the characters with spaces, then strip all spaces: the net
            # effect is to remove every space from the original page content.
            cut_txt = " ".join([w for w in list(related_docs[i].page_content)])
            cut_txt = cut_txt.replace(" ", "")
            text_batch0.append(cut_txt)
        # De-duplicate the documents.
        text_batch_new = []
        for i in range(len(text_batch0)):
            if text_batch0[i] in text_batch_new:
                continue
            else:
                while text_batch_new and text_batch_new[-1] > text_batch0[i] and text_batch_new[-1] in text_batch0[i + 1:]:
                    text_batch_new.pop()  # pop the top of the stack
                text_batch_new.append(text_batch0[i])
        text_batch_new0 = "\n".join([doc for doc in text_batch_new])
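        # Illustrative trace of the de-duplication loop above (not part of the
        # original diff):
        #   text_batch0 = ["b", "a", "b"]
        #   i=0: stack empty                     -> text_batch_new = ["b"]
        #   i=1: "b" > "a" and "b" recurs later  -> pop "b", push "a" -> ["a"]
        #   i=2: "a" < "b"                       -> push "b" -> ["a", "b"]
        # This is the monotonic-stack pattern from "remove duplicate letters":
        # each distinct text is kept exactly once, in the lexicographically
        # smallest order achievable.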
        # Fine ranking using KNN and semantic search.
        recommender = SemanticSearch()
        # text_to_chunks is referenced here but not included in this diff; a
        # hypothetical sketch is given after this hunk.
        chunks = text_to_chunks(text_batch_new0, start_page=1)
        recommender.fit(chunks)
        topn_chunks = recommender(query)
        torch_gc()
        # Remove the spaces from the selected chunks.
        topn_chunks0 = []
        for i in range(len(topn_chunks)):
            cut_txt = topn_chunks[i].replace(" ", "")
            topn_chunks0.append(cut_txt)
        # Generate the prompt from the re-ranked chunks.
        prompt = generate_prompt(topn_chunks0, query)
        for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
                                                      streaming=streaming):
            resp = answer_result.llm_output["answer"]
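Note: text_to_chunks is called above but not defined in this diff. A minimal
hypothetical sketch, assuming it only needs to split the merged text into
fixed-size chunks for SemanticSearch (the name of the size parameter and the
chunk size are illustrative):

    def text_to_chunks(text, chunk_size=250, start_page=1):
        # Split the merged, newline-joined text into fixed-size chunks,
        # tagging each chunk with a pseudo page number.
        merged = text.replace("\n", "")
        return [f"[{start_page + i // chunk_size}] {merged[i:i + chunk_size]}"
                for i in range(0, len(merged), chunk_size)]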