New feature: re-rank vector-retrieved text with a reranker
parent 5891f94c88
commit 129c765a74
@@ -12,9 +12,9 @@ EMBEDDING_MODEL = "bge-large-zh"
 EMBEDDING_DEVICE = "auto"
 
 # Selected reranker model
-RERANKER_MDOEL = "bge-reranker-large"
+RERANKER_MODEL = "bge-reranker-large"
 # Whether to enable the reranker
-USE_RERANKER = True
+USE_RERANKER = False
 RERANKER_MAX_LENGTH = 1024
 # Configure this when custom keywords need to be added to EMBEDDING_MODEL
 EMBEDDING_KEYWORD_FILE = "keywords.txt"
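These settings gate the rerank step in the knowledge-base chat handler later in this commit. A minimal sketch of the intended consumption, assuming a configs package that re-exports these names (the config file's path is not shown in this capture):

from typing import Optional
from configs import USE_RERANKER, RERANKER_MODEL, RERANKER_MAX_LENGTH, MODEL_PATH

def resolve_reranker_path() -> Optional[str]:
    """Return the reranker checkpoint to load, or None when reranking is off."""
    if not USE_RERANKER:
        return None
    return MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")

Defaulting USE_RERANKER to False keeps the extra cross-encoder forward pass opt-in.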
@@ -243,7 +243,7 @@ MODEL_PATH = {
         "Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat",
     },
     "reranker":{
-        "bge-reranker-large":BAAI/bge-reranker-large",
+        "bge-reranker-large":"BAAI/bge-reranker-large",
         "bge-reranker-base":"BAAI/bge-reranker-base",
         # TODO: add online rerankers such as Cohere
     }
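The functional fix in this hunk is the missing opening quote before BAAI/bge-reranker-large, which made importing the old config a SyntaxError. The registry maps a short model id to a Hugging Face repo id (a local directory path would also work), and lookups in this commit use dict.get with a fallback, e.g.:

reranker_paths = {
    "bge-reranker-large": "BAAI/bge-reranker-large",
    "bge-reranker-base": "BAAI/bge-reranker-base",
}
# A misspelled or unregistered id degrades to the large checkpoint
# instead of raising a KeyError at request time.
print(reranker_paths.get("bge-reranker-typo", "BAAI/bge-reranker-large"))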
@@ -87,14 +87,18 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
     # Apply the reranker
     if USE_RERANKER:
         reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
+        print("-----------------model path------------------")
+        print(reranker_model_path)
         reranker_model = LangchainReranker(top_n=top_k,
                                            device=embedding_device(),
                                            max_length=RERANKER_MAX_LENGTH,
                                            model_name_or_path=reranker_model_path
                                            )
-
+        print(docs)
         docs = reranker_model.compress_documents(documents=docs,
                                                  query=query)
+        print("---------after rerank------------------")
+        print(docs)
     context = "\n".join([doc.page_content for doc in docs])
 
     if len(docs) == 0:  # Use the empty template if no relevant documents were found
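Condensed, the retrieve-then-rerank flow this handler implements looks like the sketch below; docs and top_k come from the surrounding vector search, and the import path for LangchainReranker is an assumption, since the capture does not show file names:

from typing import List
from langchain_core.documents import Document
from configs import USE_RERANKER, RERANKER_MODEL, RERANKER_MAX_LENGTH, MODEL_PATH
from server.utils import embedding_device
from server.reranker.reranker import LangchainReranker  # hypothetical path, not shown in the diff

def rerank_if_enabled(query: str, docs: List[Document], top_k: int) -> List[Document]:
    """Re-score vector-search hits with the cross-encoder and keep the best top_k."""
    if not USE_RERANKER:
        return docs
    model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
    reranker = LangchainReranker(top_n=top_k,
                                 device=embedding_device(),
                                 max_length=RERANKER_MAX_LENGTH,
                                 model_name_or_path=model_path)
    # compress_documents scores every (query, doc) pair and returns the top_n
    # documents with the score stored in doc.metadata["relevance_score"].
    return reranker.compress_documents(documents=docs, query=query)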
@@ -1,35 +1,60 @@
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+from typing import Any, List, Optional
 from sentence_transformers import CrossEncoder
 from typing import Optional, Sequence
 from langchain_core.documents import Document
 from langchain.callbacks.manager import Callbacks
 from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
+from llama_index.bridge.pydantic import Field,PrivateAttr
 
 class LangchainReranker(BaseDocumentCompressor):
     """Document compressor that uses `Cohere Rerank API`."""
+    model_name_or_path:str = Field()
+    _model: Any = PrivateAttr()
+    top_n:int = Field()
+    device:str = Field()
+    max_length:int = Field()
+    batch_size: int = Field()
+    # show_progress_bar: bool = None
+    num_workers: int = Field()
+    # activation_fct = None
+    # apply_softmax = False
 
     def __init__(self,
+                 model_name_or_path:str,
                  top_n:int=3,
                  device:str="cuda",
                  max_length:int=1024,
                  batch_size: int = 32,
-                 show_progress_bar: bool = None,
+                 # show_progress_bar: bool = None,
                  num_workers: int = 0,
-                 activation_fct = None,
-                 apply_softmax = False,
-                 model_name_or_path:str="BAAI/bge-reraker-large"
+                 # activation_fct = None,
+                 # apply_softmax = False,
                  ):
-        self.top_n=top_n
-        self.model_name_or_path=model_name_or_path
+        # self.top_n=top_n
+        # self.model_name_or_path=model_name_or_path
+        # self.device=device
+        # self.max_length=max_length
+        # self.batch_size=batch_size
+        # self.show_progress_bar=show_progress_bar
+        # self.num_workers=num_workers
+        # self.activation_fct=activation_fct
+        # self.apply_softmax=apply_softmax
 
-        self.device=device
-        self.max_length=max_length
-        self.batch_size=batch_size
-        self.show_progress_bar=show_progress_bar
-        self.num_workers=num_workers
-        self.activation_fct=activation_fct
-        self.apply_softmax=apply_softmax
-
-        self.model = CrossEncoder(model_name=model_name_or_path,max_length=1024,device=device)
+        self._model = CrossEncoder(model_name=model_name_or_path,max_length=1024,device=device)
+        super().__init__(
+            top_n=top_n,
+            model_name_or_path=model_name_or_path,
+            device=device,
+            max_length=max_length,
+            batch_size=batch_size,
+            # show_progress_bar=show_progress_bar,
+            num_workers=num_workers,
+            # activation_fct=activation_fct,
+            # apply_softmax=apply_softmax
+        )
 
     def compress_documents(
         self,
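BaseDocumentCompressor is a pydantic model, which explains the shape of this rewrite: plain self.x = ... assignments on undeclared attributes (the deleted lines) are rejected, so the constructor arguments are declared as class-level Fields and routed through super().__init__(), while the CrossEncoder, which pydantic cannot validate or serialize, lives in a PrivateAttr. A minimal standalone sketch of the same pattern, using pydantic directly rather than the llama_index.bridge.pydantic import (that bridge module is a thin re-export of pydantic symbols):

from typing import Any
from pydantic import BaseModel, Field, PrivateAttr

class Scorer(BaseModel):
    # Declared fields participate in validation and serialization.
    model_name_or_path: str = Field()
    top_n: int = Field(default=3)
    # Private attrs bypass validation: the right place for a live model object.
    _model: Any = PrivateAttr()

    def __init__(self, model_name_or_path: str, top_n: int = 3):
        super().__init__(model_name_or_path=model_name_or_path, top_n=top_n)
        self._model = object()  # stand-in for CrossEncoder(model_name_or_path, ...)

s = Scorer(model_name_or_path="BAAI/bge-reranker-large")
print(s.top_n, s.model_name_or_path)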
@@ -53,13 +78,14 @@ class LangchainReranker(BaseDocumentCompressor):
         doc_list = list(documents)
         _docs = [d.page_content for d in doc_list]
         sentence_pairs = [[query,_doc] for _doc in _docs]
-        results = self.model.predict(sentences=sentence_pairs,
+        results = self._model.predict(sentences=sentence_pairs,
                                      batch_size=self.batch_size,
-                                     show_progress_bar=self.show_progress_bar,
+                                     # show_progress_bar=self.show_progress_bar,
                                      num_workers=self.num_workers,
-                                     activation_fct=self.activation_fct,
-                                     apply_softmax=self.apply_softmax,
-                                     convert_to_tensor=True)
+                                     # activation_fct=self.activation_fct,
+                                     # apply_softmax=self.apply_softmax,
+                                     convert_to_tensor=True
+                                     )
         top_k = self.top_n if self.top_n < len(results) else len(results)
 
         values, indices = results.topk(top_k)
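With convert_to_tensor=True, CrossEncoder.predict returns a torch tensor holding one relevance score per sentence pair, which is what makes the results.topk(top_k) call work. A tiny self-contained illustration with made-up scores:

import torch

scores = torch.tensor([0.12, 0.87, 0.45])  # one score per (query, doc) pair
top_n = 2
top_k = top_n if top_n < len(scores) else len(scores)  # clamp, as in compress_documents
values, indices = scores.topk(top_k)
print(values)   # tensor([0.8700, 0.4500])
print(indices)  # tensor([1, 2])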
@@ -68,4 +94,23 @@ class LangchainReranker(BaseDocumentCompressor):
             doc = doc_list[index]
             doc.metadata["relevance_score"] = value
             final_results.append(doc)
         return final_results
+if __name__ == "__main__":
+    from configs import (LLM_MODELS,
+                         VECTOR_SEARCH_TOP_K,
+                         SCORE_THRESHOLD,
+                         TEMPERATURE,
+                         USE_RERANKER,
+                         RERANKER_MODEL,
+                         RERANKER_MAX_LENGTH,
+                         MODEL_PATH)
+    from server.utils import embedding_device
+    if USE_RERANKER:
+        reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
+        print("-----------------model path------------------")
+        print(reranker_model_path)
+        reranker_model = LangchainReranker(top_n=3,
+                                           device=embedding_device(),
+                                           max_length=RERANKER_MAX_LENGTH,
+                                           model_name_or_path=reranker_model_path
+                                           )
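The __main__ smoke test stops after constructing the model. A hedged extension that would actually exercise it on toy documents, assuming the block above ran so reranker_model exists (the texts are made up):

from langchain_core.documents import Document

docs = [Document(page_content=t) for t in (
    "A reranker re-scores retrieved passages with a cross-encoder.",
    "Tomatoes are botanically classified as a fruit.",
    "bge-reranker-large scores query-passage relevance for zh/en text.",
)]
reranked = reranker_model.compress_documents(documents=docs,
                                             query="what does a reranker do?")
for doc in reranked:
    print(float(doc.metadata["relevance_score"]), doc.page_content)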