新增特性:reranker对向量召回文本进行重排

This commit is contained in:
hzg0601 2023-12-21 19:05:11 +08:00
parent 5891f94c88
commit 129c765a74
3 changed files with 75 additions and 26 deletions

View File

@ -12,9 +12,9 @@ EMBEDDING_MODEL = "bge-large-zh"
EMBEDDING_DEVICE = "auto"
# 选用的reranker模型
RERANKER_MDOEL = "bge-reranker-large"
RERANKER_MODEL = "bge-reranker-large"
# 是否启用reranker模型
USE_RERANKER = True
USE_RERANKER = False
RERANKER_MAX_LENGTH = 1024
# 如果需要在 EMBEDDING_MODEL 中增加自定义的关键字时配置
EMBEDDING_KEYWORD_FILE = "keywords.txt"
@ -243,7 +243,7 @@ MODEL_PATH = {
"Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat",
},
"reranker":{
"bge-reranker-large":BAAI/bge-reranker-large",
"bge-reranker-large":"BAAI/bge-reranker-large",
"bge-reranker-base":"BAAI/bge-reranker-base",
#TODO 增加在线reranker如cohere
}

View File

@ -87,14 +87,18 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
# 加入reranker
if USE_RERANKER:
reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
print("-----------------model path------------------")
print(reranker_model_path)
reranker_model = LangchainReranker(top_n=top_k,
device=embedding_device(),
max_length=RERANKER_MAX_LENGTH,
model_name_or_path=reranker_model_path
)
print(docs)
docs = reranker_model.compress_documents(documents=docs,
query=query)
print("---------after rerank------------------")
print(docs)
context = "\n".join([doc.page_content for doc in docs])
if len(docs) == 0: # 如果没有找到相关文档使用empty模板

View File

@ -1,35 +1,60 @@
import os
import sys

# Make the project root importable when this module is run as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))

from typing import Any, Optional, Sequence

from sentence_transformers import CrossEncoder
from langchain_core.documents import Document
from langchain.callbacks.manager import Callbacks
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from llama_index.bridge.pydantic import Field, PrivateAttr


class LangchainReranker(BaseDocumentCompressor):
    """Document compressor that reranks retrieved documents with a local
    CrossEncoder model (e.g. BAAI/bge-reranker-large)."""

    # Declared as pydantic fields so super().__init__() validates and stores them.
    model_name_or_path: str = Field()
    _model: Any = PrivateAttr()  # loaded CrossEncoder; private, skips validation
    top_n: int = Field()
    device: str = Field()
    max_length: int = Field()
    batch_size: int = Field()
    num_workers: int = Field()

    def __init__(self,
                 model_name_or_path: str,
                 top_n: int = 3,
                 device: str = "cuda",
                 max_length: int = 1024,
                 batch_size: int = 32,
                 num_workers: int = 0,
                 ):
        """Load the CrossEncoder and initialize the pydantic fields.

        Args:
            model_name_or_path: HF repo id or local path of the reranker model.
            top_n: number of documents to keep after reranking.
            device: torch device string, e.g. "cuda" or "cpu".
            max_length: maximum token length fed to the CrossEncoder.
            batch_size: prediction batch size.
            num_workers: dataloader workers used during prediction.
        """
        # Fix: the original passed a hard-coded max_length=1024 here, silently
        # ignoring the `max_length` argument; forward the parameter instead.
        self._model = CrossEncoder(model_name=model_name_or_path,
                                   max_length=max_length,
                                   device=device)
        super().__init__(
            top_n=top_n,
            model_name_or_path=model_name_or_path,
            device=device,
            max_length=max_length,
            batch_size=batch_size,
            num_workers=num_workers,
        )
def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
    """Rerank `documents` against `query`; return the top_n highest-scoring docs.

    Each returned document gets the reranker score attached under
    metadata["relevance_score"].

    NOTE(review): the diff view hid part of this method (the signature tail and
    the result-assembly loop header); those lines are reconstructed from the
    visible code and the BaseDocumentCompressor interface — confirm against the
    full file.
    """
    if len(documents) == 0:  # empty retrieval result: nothing to rerank
        return []
    doc_list = list(documents)
    _docs = [d.page_content for d in doc_list]
    # CrossEncoder scores (query, passage) pairs jointly.
    sentence_pairs = [[query, _doc] for _doc in _docs]
    results = self._model.predict(sentences=sentence_pairs,
                                  batch_size=self.batch_size,
                                  num_workers=self.num_workers,
                                  convert_to_tensor=True)
    # Keep at most top_n documents (fewer if the candidate list is shorter).
    top_k = self.top_n if self.top_n < len(results) else len(results)
    values, indices = results.topk(top_k)
    final_results = []
    for value, index in zip(values, indices):
        doc = doc_list[index]
        doc.metadata["relevance_score"] = value
        final_results.append(doc)
    return final_results
if __name__ == "__main__":
    # Manual smoke test: load the configured reranker model.
    # Requires the project's `configs` package and `server.utils`.
    from configs import (LLM_MODELS,
                         VECTOR_SEARCH_TOP_K,
                         SCORE_THRESHOLD,
                         TEMPERATURE,
                         USE_RERANKER,
                         RERANKER_MODEL,
                         RERANKER_MAX_LENGTH,
                         MODEL_PATH)
    from server.utils import embedding_device

    if USE_RERANKER:
        # Fall back to the public BGE large reranker when the configured model
        # has no entry in MODEL_PATH["reranker"].
        reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,
                                                         "BAAI/bge-reranker-large")
        print("-----------------model path------------------")
        print(reranker_model_path)
        reranker_model = LangchainReranker(top_n=3,
                                           device=embedding_device(),
                                           max_length=RERANKER_MAX_LENGTH,
                                           model_name_or_path=reranker_model_path,
                                           )