New feature: use a reranker to re-rank texts retrieved by vector search

Author: hzg0601  2023-12-21 19:05:11 +08:00
Parent: 5891f94c88
Commit: 129c765a74
3 changed files with 75 additions and 26 deletions
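In outline, the commit inserts a cross-encoder re-ranking pass between vector retrieval and prompt construction: candidate passages from the vector store are scored against the query and only the top-scoring ones are kept. A minimal sketch of that idea, using only the sentence_transformers API that the new module wraps; the query and passages below are made-up placeholders:

    from sentence_transformers import CrossEncoder

    # Placeholder inputs; in the real pipeline these come from vector retrieval.
    query = "How do I enable the reranker?"
    passages = ["Set USE_RERANKER in the model config.",
                "An unrelated passage about embeddings."]

    model = CrossEncoder("BAAI/bge-reranker-large", max_length=1024)
    scores = model.predict([[query, p] for p in passages])
    # Keep the passages sorted by descending cross-encoder score.
    reranked = [p for _, p in sorted(zip(scores, passages),
                                     key=lambda t: t[0], reverse=True)]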

File 1 of 3

@@ -12,9 +12,9 @@ EMBEDDING_MODEL = "bge-large-zh"
 EMBEDDING_DEVICE = "auto"
 # Reranker model to use
-RERANKER_MDOEL = "bge-reranker-large"
+RERANKER_MODEL = "bge-reranker-large"
 # Whether to enable the reranker model
-USE_RERANKER = True
+USE_RERANKER = False
 RERANKER_MAX_LENGTH = 1024
 # Set this when custom keywords need to be added to EMBEDDING_MODEL
 EMBEDDING_KEYWORD_FILE = "keywords.txt"
@@ -243,7 +243,7 @@ MODEL_PATH = {
         "Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat",
     },
     "reranker": {
-        "bge-reranker-large":BAAI/bge-reranker-large",
+        "bge-reranker-large":"BAAI/bge-reranker-large",
         "bge-reranker-base":"BAAI/bge-reranker-base",
         # TODO: add online rerankers such as Cohere
     }
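The new `reranker` table maps the short name configured in RERANKER_MODEL to a Hugging Face repo id, and the chat handler (next file) resolves it with `.get()` and a hard-coded fallback. A self-contained illustration of that lookup:

    MODEL_PATH = {"reranker": {"bge-reranker-large": "BAAI/bge-reranker-large",
                               "bge-reranker-base": "BAAI/bge-reranker-base"}}
    RERANKER_MODEL = "bge-reranker-large"
    # Unknown names silently fall back to the large model.
    path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
    assert path == "BAAI/bge-reranker-large"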

File 2 of 3

@@ -87,14 +87,18 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         # apply the reranker
         if USE_RERANKER:
             reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
+            print("-----------------model path------------------")
+            print(reranker_model_path)
             reranker_model = LangchainReranker(top_n=top_k,
                                                device=embedding_device(),
                                                max_length=RERANKER_MAX_LENGTH,
                                                model_name_or_path=reranker_model_path
                                                )
+            print(docs)
             docs = reranker_model.compress_documents(documents=docs,
                                                      query=query)
+            print("---------after rerank------------------")
+            print(docs)
         context = "\n".join([doc.page_content for doc in docs])
         if len(docs) == 0:  # if no relevant documents are found, use the "empty" template
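When USE_RERANKER is on, compress_documents replaces docs with the top_n highest-scoring documents and records each score in doc.metadata["relevance_score"] before the context string is built. A hedged sketch of calling the new class directly, outside the service; it assumes LangchainReranker from the module below is importable, and uses "cpu" as a stand-in for embedding_device():

    from langchain_core.documents import Document

    docs = [Document(page_content="passage one"),
            Document(page_content="passage two")]
    reranker = LangchainReranker(top_n=1,
                                 device="cpu",  # stand-in for embedding_device()
                                 max_length=1024,
                                 model_name_or_path="BAAI/bge-reranker-large")
    docs = reranker.compress_documents(documents=docs, query="some user query")
    for d in docs:
        print(d.metadata["relevance_score"], d.page_content)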

File 3 of 3

@@ -1,35 +1,60 @@
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+from typing import Any, List, Optional
 from sentence_transformers import CrossEncoder
 from typing import Optional, Sequence
 from langchain_core.documents import Document
 from langchain.callbacks.manager import Callbacks
 from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
+from llama_index.bridge.pydantic import Field,PrivateAttr

 class LangchainReranker(BaseDocumentCompressor):
     """Document compressor that uses `Cohere Rerank API`."""
+    model_name_or_path:str = Field()
+    _model: Any = PrivateAttr()
+    top_n:int = Field()
+    device:str = Field()
+    max_length:int = Field()
+    batch_size: int = Field()
+    # show_progress_bar: bool = None
+    num_workers: int = Field()
+    # activation_fct = None
+    # apply_softmax = False

     def __init__(self,
+                 model_name_or_path:str,
                  top_n:int=3,
                  device:str="cuda",
                  max_length:int=1024,
                  batch_size: int = 32,
-                 show_progress_bar: bool = None,
+                 # show_progress_bar: bool = None,
                  num_workers: int = 0,
-                 activation_fct = None,
-                 apply_softmax = False,
-                 model_name_or_path:str="BAAI/bge-reraker-large"
+                 # activation_fct = None,
+                 # apply_softmax = False,
                  ):
-        self.top_n=top_n
-        self.model_name_or_path=model_name_or_path
-        self.device=device
-        self.max_length=max_length
-        self.batch_size=batch_size
-        self.show_progress_bar=show_progress_bar
-        self.num_workers=num_workers
-        self.activation_fct=activation_fct
-        self.apply_softmax=apply_softmax
-        self.model = CrossEncoder(model_name=model_name_or_path,max_length=1024,device=device)
+        # self.top_n=top_n
+        # self.model_name_or_path=model_name_or_path
+        # self.device=device
+        # self.max_length=max_length
+        # self.batch_size=batch_size
+        # self.show_progress_bar=show_progress_bar
+        # self.num_workers=num_workers
+        # self.activation_fct=activation_fct
+        # self.apply_softmax=apply_softmax
+        self._model = CrossEncoder(model_name=model_name_or_path,max_length=1024,device=device)
+        super().__init__(
+            top_n=top_n,
+            model_name_or_path=model_name_or_path,
+            device=device,
+            max_length=max_length,
+            batch_size=batch_size,
+            # show_progress_bar=show_progress_bar,
+            num_workers=num_workers,
+            # activation_fct=activation_fct,
+            # apply_softmax=apply_softmax
+        )

     def compress_documents(
         self,
@@ -53,13 +78,14 @@ class LangchainReranker(BaseDocumentCompressor):
         doc_list = list(documents)
         _docs = [d.page_content for d in doc_list]
         sentence_pairs = [[query,_doc] for _doc in _docs]
-        results = self.model.predict(sentences=sentence_pairs,
+        results = self._model.predict(sentences=sentence_pairs,
                                      batch_size=self.batch_size,
-                                     show_progress_bar=self.show_progress_bar,
+                                     # show_progress_bar=self.show_progress_bar,
                                      num_workers=self.num_workers,
-                                     activation_fct=self.activation_fct,
-                                     apply_softmax=self.apply_softmax,
-                                     convert_to_tensor=True)
+                                     # activation_fct=self.activation_fct,
+                                     # apply_softmax=self.apply_softmax,
+                                     convert_to_tensor=True
+                                     )

         top_k = self.top_n if self.top_n < len(results) else len(results)
         values, indices = results.topk(top_k)
@@ -68,4 +94,23 @@ class LangchainReranker(BaseDocumentCompressor):
             doc = doc_list[index]
             doc.metadata["relevance_score"] = value
             final_results.append(doc)
         return final_results
+
+if __name__ == "__main__":
+    from configs import (LLM_MODELS,
+                         VECTOR_SEARCH_TOP_K,
+                         SCORE_THRESHOLD,
+                         TEMPERATURE,
+                         USE_RERANKER,
+                         RERANKER_MODEL,
+                         RERANKER_MAX_LENGTH,
+                         MODEL_PATH)
+    from server.utils import embedding_device
+    if USE_RERANKER:
+        reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
+        print("-----------------model path------------------")
+        print(reranker_model_path)
+        reranker_model = LangchainReranker(top_n=3,
+                                           device=embedding_device(),
+                                           max_length=RERANKER_MAX_LENGTH,
+                                           model_name_or_path=reranker_model_path
+                                           )