From 5891f94c887096234368368a638dfb1daddd5746 Mon Sep 17 00:00:00 2001 From: hzg0601 Date: Thu, 21 Dec 2023 16:04:15 +0800 Subject: [PATCH] temporarily add reranker --- configs/model_config.py.example | 15 +++++- server/chat/knowledge_base_chat.py | 26 ++++++++-- server/reranker/reranker.py | 78 +++++++++++++++++++++++++----- 3 files changed, 101 insertions(+), 18 deletions(-) diff --git a/configs/model_config.py.example b/configs/model_config.py.example index 659788c..6af317b 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -11,6 +11,11 @@ EMBEDDING_MODEL = "bge-large-zh" # Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。 EMBEDDING_DEVICE = "auto" +# 选用的reranker模型 +RERANKER_MDOEL = "bge-reranker-large" +# 是否启用reranker模型 +USE_RERANKER = True +RERANKER_MAX_LENGTH = 1024 # 如果需要在 EMBEDDING_MODEL 中增加自定义的关键字时配置 EMBEDDING_KEYWORD_FILE = "keywords.txt" EMBEDDING_MODEL_OUTPUT_PATH = "output" @@ -19,8 +24,9 @@ EMBEDDING_MODEL_OUTPUT_PATH = "output" # 列表中第一个模型将作为 API 和 WEBUI 的默认模型。 # 在这里,我们使用目前主流的两个离线模型,其中,chatglm3-6b 为默认加载模型。 # 如果你的显存不足,可使用 Qwen-1_8B-Chat, 该模型 FP16 仅需 3.8G显存。 -# chatglm3-6b输出角色标签<|user|>及自问自答的问题是由于fschat=0.2.33并未正确适配chatglm3的对话模板 -# 如需修正该问题,需修改fschat的源码,详细步骤见项目wiki->常见问题->Q20. + +# chatglm3-6b输出角色标签<|user|>及自问自答的问题详见项目wiki->常见问题->Q20. + LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"] # "Qwen-1_8B-Chat", # AgentLM模型的名称 (可以不指定,指定之后就锁定进入Agent之后的Chain的模型,不指定就是LLM_MODELS[0]) @@ -236,6 +242,11 @@ MODEL_PATH = { "Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat", }, + "reranker":{ + "bge-reranker-large":BAAI/bge-reranker-large", + "bge-reranker-base":"BAAI/bge-reranker-base", + #TODO 增加在线reranker,如cohere + } } diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 162e24d..f56f0d0 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -1,7 +1,14 @@ from fastapi import Body, Request from sse_starlette.sse import EventSourceResponse from fastapi.concurrency import run_in_threadpool -from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE) +from configs import (LLM_MODELS, + VECTOR_SEARCH_TOP_K, + SCORE_THRESHOLD, + TEMPERATURE, + USE_RERANKER, + RERANKER_MODEL, + RERANKER_MAX_LENGTH, + MODEL_PATH) from server.utils import wrap_done, get_ChatOpenAI from server.utils import BaseResponse, get_prompt_template from langchain.chains import LLMChain @@ -14,8 +21,8 @@ from server.knowledge_base.kb_service.base import KBServiceFactory import json from urllib.parse import urlencode from server.knowledge_base.kb_doc_api import search_docs - - +from server.reranker.reranker import LangchainReranker +from server.utils import embedding_device async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]), knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), @@ -76,7 +83,20 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", knowledge_base_name=knowledge_base_name, top_k=top_k, score_threshold=score_threshold) + + # 加入reranker + if USE_RERANKER: + reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large") + reranker_model = LangchainReranker(top_n=top_k, + device=embedding_device(), + max_length=RERANKER_MAX_LENGTH, + model_name_or_path=reranker_model_path + ) + docs = reranker_model.compress_documents(documents=docs, + query=query) + context = "\n".join([doc.page_content for doc in docs]) + if len(docs) == 0: # 如果没有找到相关文档,使用empty模板 prompt_template = get_prompt_template("knowledge_base_chat", "empty") else: diff --git a/server/reranker/reranker.py b/server/reranker/reranker.py index 5bc702d..fa4f84b 100644 --- a/server/reranker/reranker.py +++ b/server/reranker/reranker.py @@ -1,19 +1,71 @@ -from langchain.retrievers.document_compressors import CohereRerank -from llama_index.postprocessor import SentenceTransformerRerank -from sentence_transformers import SentenceTransformer,CrossEncoder +from sentence_transformers import CrossEncoder +from typing import Optional, Sequence +from langchain_core.documents import Document +from langchain.callbacks.manager import Callbacks +from langchain.retrievers.document_compressors.base import BaseDocumentCompressor -model_path = "/root/autodl-tmp/models/bge-reranker-large/" -instruction = "为这个句子生成表示以用于检索相关文章:" -reranker = SentenceTransformerRerank( - top_n=5, - model="local:"+model_path, -) +class LangchainReranker(BaseDocumentCompressor): + """Document compressor that uses `Cohere Rerank API`.""" -reranker_model = SentenceTransformer(model_name_or_path=model_path,device="cuda") + def __init__(self, + top_n:int=3, + device:str="cuda", + max_length:int=1024, + batch_size: int = 32, + show_progress_bar: bool = None, + num_workers: int = 0, + activation_fct = None, + apply_softmax = False, + model_name_or_path:str="BAAI/bge-reraker-large" + ): + self.top_n=top_n + self.model_name_or_path=model_name_or_path -reranker_ce = CrossEncoder(model_name=model_path,device="cuda",max_length=1024) + self.device=device + self.max_length=max_length + self.batch_size=batch_size + self.show_progress_bar=show_progress_bar + self.num_workers=num_workers + self.activation_fct=activation_fct + self.apply_softmax=apply_softmax -reranker_ce.predict([[],[]]) + self.model = CrossEncoder(model_name=model_name_or_path,max_length=1024,device=device) -print("Load reranker") + def compress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Optional[Callbacks] = None, + ) -> Sequence[Document]: + """ + Compress documents using Cohere's rerank API. + Args: + documents: A sequence of documents to compress. + query: The query to use for compressing the documents. + callbacks: Callbacks to run during the compression process. + + Returns: + A sequence of compressed documents. + """ + if len(documents) == 0: # to avoid empty api call + return [] + doc_list = list(documents) + _docs = [d.page_content for d in doc_list] + sentence_pairs = [[query,_doc] for _doc in _docs] + results = self.model.predict(sentences=sentence_pairs, + batch_size=self.batch_size, + show_progress_bar=self.show_progress_bar, + num_workers=self.num_workers, + activation_fct=self.activation_fct, + apply_softmax=self.apply_softmax, + convert_to_tensor=True) + top_k = self.top_n if self.top_n < len(results) else len(results) + + values, indices = results.topk(top_k) + final_results = [] + for value, index in zip(values,indices): + doc = doc_list[index] + doc.metadata["relevance_score"] = value + final_results.append(doc) + return final_results \ No newline at end of file