temporarily add reranker

hzg0601 2023-12-21 16:04:15 +08:00
parent 60510ff2f0
commit 5891f94c88
3 changed files with 101 additions and 18 deletions

View File

@@ -11,6 +11,11 @@ EMBEDDING_MODEL = "bge-large-zh"
# Device for running the Embedding model. "auto" detects the device automatically; it can also be set manually to one of "cuda", "mps", or "cpu".
EMBEDDING_DEVICE = "auto"
# Selected reranker model
RERANKER_MODEL = "bge-reranker-large"
# Whether to enable the reranker model
USE_RERANKER = True
RERANKER_MAX_LENGTH = 1024
# Configure this if you need to add custom keywords to EMBEDDING_MODEL
EMBEDDING_KEYWORD_FILE = "keywords.txt"
EMBEDDING_MODEL_OUTPUT_PATH = "output"
@@ -19,8 +24,9 @@ EMBEDDING_MODEL_OUTPUT_PATH = "output"
# The first model in the list will be the default model for the API and the WEBUI.
# Here we use two currently mainstream offline models, with chatglm3-6b loaded by default.
# If you are short on GPU memory, you can use Qwen-1_8B-Chat, which needs only 3.8 GB of VRAM at FP16.
# For chatglm3-6b emitting the role tag <|user|> and answering its own questions, see the project wiki -> FAQ -> Q20.
LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"]  # "Qwen-1_8B-Chat",
# Name of the Agent LM model (optional; if set, it pins the model used by the chains after entering the Agent; if unset, LLM_MODELS[0] is used)
@@ -236,6 +242,11 @@ MODEL_PATH = {
"Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat", "Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat",
}, },
"reranker":{
"bge-reranker-large":BAAI/bge-reranker-large",
"bge-reranker-base":"BAAI/bge-reranker-base",
#TODO 增加在线reranker如cohere
}
} }
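
A minimal sketch of how these new settings are meant to be resolved at runtime, mirroring the lookup this commit adds in knowledge_base_chat.py below (the fallback repo id is taken directly from that code):

    from configs import MODEL_PATH, RERANKER_MODEL, USE_RERANKER

    if USE_RERANKER:
        # Fall back to the HuggingFace repo id when the configured name
        # has no entry in MODEL_PATH["reranker"]
        reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
        print(f"Reranker weights resolve to: {reranker_model_path}")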

View File

@@ -1,7 +1,14 @@
from fastapi import Body, Request
from sse_starlette.sse import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS,
                     VECTOR_SEARCH_TOP_K,
                     SCORE_THRESHOLD,
                     TEMPERATURE,
                     USE_RERANKER,
                     RERANKER_MODEL,
                     RERANKER_MAX_LENGTH,
                     MODEL_PATH)
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
@@ -14,8 +21,8 @@ from server.knowledge_base.kb_service.base import KBServiceFactory
import json
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
from server.reranker.reranker import LangchainReranker
from server.utils import embedding_device
async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                              knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                              top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
@@ -76,7 +83,20 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                                           knowledge_base_name=knowledge_base_name,
                                           top_k=top_k,
                                           score_threshold=score_threshold)
        # Apply the reranker to re-order the retrieved documents
        if USE_RERANKER:
            reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
            reranker_model = LangchainReranker(top_n=top_k,
                                               device=embedding_device(),
                                               max_length=RERANKER_MAX_LENGTH,
                                               model_name_or_path=reranker_model_path)
            docs = reranker_model.compress_documents(documents=docs, query=query)
context = "\n".join([doc.page_content for doc in docs]) context = "\n".join([doc.page_content for doc in docs])
if len(docs) == 0: # 如果没有找到相关文档使用empty模板 if len(docs) == 0: # 如果没有找到相关文档使用empty模板
prompt_template = get_prompt_template("knowledge_base_chat", "empty") prompt_template = get_prompt_template("knowledge_base_chat", "empty")
else: else:
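
A hypothetical client-side check of this change (the mount path /chat/knowledge_base_chat and port 7861 are assumptions based on the project's defaults); when USE_RERANKER is True, reranking happens server-side before the prompt is built:

    import requests

    resp = requests.post("http://127.0.0.1:7861/chat/knowledge_base_chat",
                         json={"query": "你好", "knowledge_base_name": "samples", "top_k": 3},
                         stream=True)
    for line in resp.iter_lines():
        if line:
            print(line.decode("utf-8"))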

View File

@@ -1,19 +1,71 @@
from typing import Any, Optional, Sequence

from sentence_transformers import CrossEncoder
from langchain_core.documents import Document
from langchain.callbacks.manager import Callbacks
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor


class LangchainReranker(BaseDocumentCompressor):
    """Document compressor that reranks documents with a local CrossEncoder
    model (e.g. bge-reranker), mirroring the interface of langchain's CohereRerank."""

    # BaseDocumentCompressor is a pydantic model, so all attributes must be
    # declared as fields and set through super().__init__; plain `self.x = ...`
    # assignments of undeclared attributes would raise at runtime.
    model_name_or_path: str = "BAAI/bge-reranker-large"
    top_n: int = 3
    device: str = "cuda"
    max_length: int = 1024
    batch_size: int = 32
    show_progress_bar: Optional[bool] = None
    num_workers: int = 0
    activation_fct: Any = None
    apply_softmax: bool = False
    model: Any = None

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.model = CrossEncoder(model_name=self.model_name_or_path,
                                  max_length=self.max_length,
                                  device=self.device)

    def compress_documents(
            self,
            documents: Sequence[Document],
            query: str,
            callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Rerank documents against the query with the CrossEncoder.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.

        Returns:
            A sequence of at most top_n documents, highest relevance first.
        """
        if len(documents) == 0:  # avoid an empty model call
            return []
        doc_list = list(documents)
        sentence_pairs = [[query, d.page_content] for d in doc_list]
        results = self.model.predict(sentences=sentence_pairs,
                                     batch_size=self.batch_size,
                                     show_progress_bar=self.show_progress_bar,
                                     num_workers=self.num_workers,
                                     activation_fct=self.activation_fct,
                                     apply_softmax=self.apply_softmax,
                                     convert_to_tensor=True)
        top_k = min(self.top_n, len(results))
        values, indices = results.topk(top_k)
        final_results = []
        for value, index in zip(values, indices):
            doc = doc_list[int(index)]
            doc.metadata["relevance_score"] = float(value)
            final_results.append(doc)
        return final_results
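
A minimal usage sketch for sanity-testing the compressor on its own (the toy documents and the smaller bge-reranker-base checkpoint are assumptions; any local reranker path from MODEL_PATH works the same way):

    from langchain_core.documents import Document

    docs = [Document(page_content=t) for t in [
        "Paris is the capital of France.",
        "The Eiffel Tower is in Paris.",
        "Berlin is the capital of Germany.",
    ]]
    reranker = LangchainReranker(top_n=2, device="cpu", max_length=512,
                                 model_name_or_path="BAAI/bge-reranker-base")
    for doc in reranker.compress_documents(documents=docs, query="What is the capital of France?"):
        print(round(doc.metadata["relevance_score"], 3), doc.page_content)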