Merge pull request #2435 from chatchat-space/reranker
New feature: re-rank retrieved passages with a Reranker model
Commit: d77f778e0d
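This PR adds an optional reranking step to knowledge-base chat: documents returned by vector search are re-scored against the query with a cross-encoder reranker (bge-reranker-large or bge-reranker-base, loaded through sentence-transformers' CrossEncoder), and only the highest-scoring ones are kept as prompt context. A minimal sketch of the underlying idea (illustrative only, not project code; the model id and sample strings are assumptions):

    from sentence_transformers import CrossEncoder

    query = "How do I enable the reranker?"
    passages = ["USE_RERANKER toggles the feature.",        # hypothetical retrieved passages
                "EMBEDDING_DEVICE selects the device."]

    model = CrossEncoder("BAAI/bge-reranker-large", max_length=1024)   # same model family as this PR
    scores = model.predict([[query, p] for p in passages])             # one relevance score per pair
    reranked = [p for _, p in sorted(zip(scores, passages), key=lambda t: t[0], reverse=True)]
    print(reranked[0])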
@@ -11,6 +11,11 @@ EMBEDDING_MODEL = "bge-large-zh"
 # Device on which the Embedding model runs. "auto" detects it automatically; it can also be set manually to one of "cuda", "mps", "cpu".
 EMBEDDING_DEVICE = "auto"

+# Reranker model to use
+RERANKER_MODEL = "bge-reranker-large"
+# Whether to enable the reranker model
+USE_RERANKER = False
+RERANKER_MAX_LENGTH = 1024
 # Configure this when custom keywords need to be added to EMBEDDING_MODEL
 EMBEDDING_KEYWORD_FILE = "keywords.txt"
 EMBEDDING_MODEL_OUTPUT_PATH = "output"
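To turn the feature on, the new switches above are edited in place (a sketch; the values shown are illustrative):

    USE_RERANKER = True                      # enable reranking of retrieved documents
    RERANKER_MODEL = "bge-reranker-base"     # must be a key of MODEL_PATH["reranker"] (added below)
    RERANKER_MAX_LENGTH = 1024               # max sequence length handed to the CrossEncoder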
@@ -19,8 +24,9 @@ EMBEDDING_MODEL_OUTPUT_PATH = "output"
 # The first model in the list will be the default model for the API and WebUI.
 # Here we use two currently mainstream offline models; chatglm3-6b is the model loaded by default.
 # If you are short on GPU memory, you can use Qwen-1_8B-Chat, which needs only 3.8 GB of VRAM in FP16.
-# chatglm3-6b printing the role tag <|user|> and answering its own questions is caused by fschat==0.2.33 not correctly adapting chatglm3's chat template.
-# To fix this, the fschat source code must be modified; see the project wiki -> FAQ -> Q20 for detailed steps.
+
+# For chatglm3-6b printing the role tag <|user|> and answering its own questions, see the project wiki -> FAQ -> Q20.
+
 LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"]  # "Qwen-1_8B-Chat",

 # Name of the AgentLM model (optional; once specified it pins the model used by the Chain after entering the Agent; if unspecified, LLM_MODELS[0] is used)
@@ -236,6 +242,11 @@ MODEL_PATH = {

         "Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat",
     },
+    "reranker": {
+        "bge-reranker-large": "BAAI/bge-reranker-large",
+        "bge-reranker-base": "BAAI/bge-reranker-base",
+        # TODO: add online rerankers such as cohere
+    }
 }

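The hunks above appear to come from the model configuration (file names are not preserved in this extract). MODEL_PATH["reranker"] maps a reranker name to a local path or Hugging Face repo id; the chat endpoint below reads it with a fallback, roughly:

    # Sketch of how the new entries are consumed (mirrors the chat hunk further down):
    reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")

The remaining hunks wire the reranker into the knowledge-base chat endpoint and add the LangchainReranker implementation itself.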
@@ -1,7 +1,14 @@
 from fastapi import Body, Request
 from sse_starlette.sse import EventSourceResponse
 from fastapi.concurrency import run_in_threadpool
-from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
+from configs import (LLM_MODELS,
+                     VECTOR_SEARCH_TOP_K,
+                     SCORE_THRESHOLD,
+                     TEMPERATURE,
+                     USE_RERANKER,
+                     RERANKER_MODEL,
+                     RERANKER_MAX_LENGTH,
+                     MODEL_PATH)
 from server.utils import wrap_done, get_ChatOpenAI
 from server.utils import BaseResponse, get_prompt_template
 from langchain.chains import LLMChain

@@ -14,8 +21,8 @@ from server.knowledge_base.kb_service.base import KBServiceFactory
 import json
 from urllib.parse import urlencode
 from server.knowledge_base.kb_doc_api import search_docs
-
-
+from server.reranker.reranker import LangchainReranker
+from server.utils import embedding_device
 async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                               knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                               top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
@@ -76,7 +83,24 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                                        knowledge_base_name=knowledge_base_name,
                                        top_k=top_k,
                                        score_threshold=score_threshold)
+
+        # Apply the reranker
+        if USE_RERANKER:
+            reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
+            print("-----------------model path------------------")
+            print(reranker_model_path)
+            reranker_model = LangchainReranker(top_n=top_k,
+                                               device=embedding_device(),
+                                               max_length=RERANKER_MAX_LENGTH,
+                                               model_name_or_path=reranker_model_path
+                                               )
+            print(docs)
+            docs = reranker_model.compress_documents(documents=docs,
+                                                     query=query)
+            print("---------after rerank------------------")
+            print(docs)
         context = "\n".join([doc.page_content for doc in docs])
+
         if len(docs) == 0:  # If no relevant documents were found, use the empty prompt template
             prompt_template = get_prompt_template("knowledge_base_chat", "empty")
         else:
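When USE_RERANKER is on, reranking runs after search_docs and before the prompt context is assembled: compress_documents (defined in server/reranker/reranker.py, shown in the next diff) returns at most top_n documents, each with its cross-encoder score stored in metadata. A small inspection sketch, assuming docs and query as in the hunk above:

    # Each reranked Document carries its score (a 0-dim tensor) in metadata["relevance_score"].
    for doc in docs:
        print(float(doc.metadata["relevance_score"]), doc.page_content[:80])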
@@ -1,19 +1,116 @@
-from langchain.retrievers.document_compressors import CohereRerank
-from llama_index.postprocessor import SentenceTransformerRerank
-from sentence_transformers import SentenceTransformer, CrossEncoder
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+from typing import Any, List, Optional
+from sentence_transformers import CrossEncoder
+from typing import Optional, Sequence
+from langchain_core.documents import Document
+from langchain.callbacks.manager import Callbacks
+from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
+from llama_index.bridge.pydantic import Field, PrivateAttr
+
-model_path = "/root/autodl-tmp/models/bge-reranker-large/"
-instruction = "为这个句子生成表示以用于检索相关文章:"
-reranker = SentenceTransformerRerank(
-    top_n=5,
-    model="local:" + model_path,
-)
+class LangchainReranker(BaseDocumentCompressor):
+    """Document compressor that uses `Cohere Rerank API`."""
+    model_name_or_path: str = Field()
+    _model: Any = PrivateAttr()
+    top_n: int = Field()
+    device: str = Field()
+    max_length: int = Field()
+    batch_size: int = Field()
+    # show_progress_bar: bool = None
+    num_workers: int = Field()
+    # activation_fct = None
+    # apply_softmax = False
+
+    def __init__(self,
+                 model_name_or_path: str,
+                 top_n: int = 3,
+                 device: str = "cuda",
+                 max_length: int = 1024,
+                 batch_size: int = 32,
+                 # show_progress_bar: bool = None,
+                 num_workers: int = 0,
+                 # activation_fct = None,
+                 # apply_softmax = False,
+                 ):
+        # self.top_n = top_n
+        # self.model_name_or_path = model_name_or_path
+        # self.device = device
+        # self.max_length = max_length
+        # self.batch_size = batch_size
+        # self.show_progress_bar = show_progress_bar
+        # self.num_workers = num_workers
+        # self.activation_fct = activation_fct
+        # self.apply_softmax = apply_softmax
+
-reranker_model = SentenceTransformer(model_name_or_path=model_path, device="cuda")
+        self._model = CrossEncoder(model_name=model_name_or_path, max_length=1024, device=device)
+        super().__init__(
+            top_n=top_n,
+            model_name_or_path=model_name_or_path,
+            device=device,
+            max_length=max_length,
+            batch_size=batch_size,
+            # show_progress_bar=show_progress_bar,
+            num_workers=num_workers,
+            # activation_fct=activation_fct,
+            # apply_softmax=apply_softmax
+        )
+
-reranker_ce = CrossEncoder(model_name=model_path, device="cuda", max_length=1024)
+    def compress_documents(
+            self,
+            documents: Sequence[Document],
+            query: str,
+            callbacks: Optional[Callbacks] = None,
+    ) -> Sequence[Document]:
+        """
+        Compress documents using Cohere's rerank API.
+
-reranker_ce.predict([[], []])
-print("Load reranker")
+        Args:
+            documents: A sequence of documents to compress.
+            query: The query to use for compressing the documents.
+            callbacks: Callbacks to run during the compression process.
+
+        Returns:
+            A sequence of compressed documents.
+        """
+        if len(documents) == 0:  # to avoid empty api call
+            return []
+        doc_list = list(documents)
+        _docs = [d.page_content for d in doc_list]
+        sentence_pairs = [[query, _doc] for _doc in _docs]
+        results = self._model.predict(sentences=sentence_pairs,
+                                      batch_size=self.batch_size,
+                                      # show_progress_bar=self.show_progress_bar,
+                                      num_workers=self.num_workers,
+                                      # activation_fct=self.activation_fct,
+                                      # apply_softmax=self.apply_softmax,
+                                      convert_to_tensor=True
+                                      )
+        top_k = self.top_n if self.top_n < len(results) else len(results)
+
+        values, indices = results.topk(top_k)
+        final_results = []
+        for value, index in zip(values, indices):
+            doc = doc_list[index]
+            doc.metadata["relevance_score"] = value
+            final_results.append(doc)
+        return final_results
+
+
+if __name__ == "__main__":
+    from configs import (LLM_MODELS,
+                         VECTOR_SEARCH_TOP_K,
+                         SCORE_THRESHOLD,
+                         TEMPERATURE,
+                         USE_RERANKER,
+                         RERANKER_MODEL,
+                         RERANKER_MAX_LENGTH,
+                         MODEL_PATH)
+    from server.utils import embedding_device
+    if USE_RERANKER:
+        reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
+        print("-----------------model path------------------")
+        print(reranker_model_path)
+        reranker_model = LangchainReranker(top_n=3,
+                                           device=embedding_device(),
+                                           max_length=RERANKER_MAX_LENGTH,
+                                           model_name_or_path=reranker_model_path
+                                           )
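A self-contained sketch for trying LangchainReranker outside the chat endpoint (the model id and sample texts are assumptions; substitute a local bge-reranker checkpoint or the repo id from MODEL_PATH above):

    from langchain_core.documents import Document
    from server.reranker.reranker import LangchainReranker

    docs = [Document(page_content="BGE reranker is a cross-encoder that scores query-passage pairs."),
            Document(page_content="EMBEDDING_DEVICE selects the device used by the embedding model.")]
    reranker = LangchainReranker(top_n=1,
                                 device="cpu",              # or embedding_device()
                                 max_length=1024,
                                 model_name_or_path="BAAI/bge-reranker-large")
    top = reranker.compress_documents(documents=docs, query="What does a reranker do?")
    print(float(top[0].metadata["relevance_score"]), top[0].page_content)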