diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 6af317b..8e25ea5 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -12,9 +12,9 @@
 EMBEDDING_MODEL = "bge-large-zh"
 EMBEDDING_DEVICE = "auto"
 # Reranker model to use
-RERANKER_MDOEL = "bge-reranker-large"
+RERANKER_MODEL = "bge-reranker-large"
 # Whether to enable the reranker model
-USE_RERANKER = True
+USE_RERANKER = False
 RERANKER_MAX_LENGTH = 1024
 # Configure when custom keywords need to be added to EMBEDDING_MODEL
 EMBEDDING_KEYWORD_FILE = "keywords.txt"
@@ -243,7 +243,7 @@ MODEL_PATH = {
         "Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat",
     },
     "reranker":{
-        "bge-reranker-large":BAAI/bge-reranker-large",
+        "bge-reranker-large":"BAAI/bge-reranker-large",
         "bge-reranker-base":"BAAI/bge-reranker-base",
         # TODO: add online rerankers such as Cohere
     }
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index f56f0d0..60956b4 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -87,14 +87,18 @@ async def knowledge_base_chat(query: str = Body(..., description="User input",
         # Apply the reranker
         if USE_RERANKER:
             reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
+            print("-----------------model path------------------")
+            print(reranker_model_path)
             reranker_model = LangchainReranker(top_n=top_k,
                                                device=embedding_device(),
                                                max_length=RERANKER_MAX_LENGTH,
                                                model_name_or_path=reranker_model_path
                                                )
+            print(docs)
             docs = reranker_model.compress_documents(documents=docs, query=query)
-
+            print("---------after rerank------------------")
+            print(docs)
         context = "\n".join([doc.page_content for doc in docs])
         if len(docs) == 0:  # If no relevant documents are found, use the "empty" template
diff --git a/server/reranker/reranker.py b/server/reranker/reranker.py
index fa4f84b..93ce3ba 100644
--- a/server/reranker/reranker.py
+++ b/server/reranker/reranker.py
@@ -1,35 +1,60 @@
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+from typing import Any, List, Optional
 from sentence_transformers import CrossEncoder
 from typing import Optional, Sequence
 from langchain_core.documents import Document
 from langchain.callbacks.manager import Callbacks
 from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
+from llama_index.bridge.pydantic import Field, PrivateAttr
 
 
 class LangchainReranker(BaseDocumentCompressor):
     """Document compressor that uses `Cohere Rerank API`."""
-
+    model_name_or_path: str = Field()
+    _model: Any = PrivateAttr()
+    top_n: int = Field()
+    device: str = Field()
+    max_length: int = Field()
+    batch_size: int = Field()
+    # show_progress_bar: bool = None
+    num_workers: int = Field()
+    # activation_fct = None
+    # apply_softmax = False
+
     def __init__(self,
+                 model_name_or_path: str,
                  top_n:int=3,
                  device:str="cuda",
                  max_length:int=1024,
                  batch_size: int = 32,
-                 show_progress_bar: bool = None,
+                 # show_progress_bar: bool = None,
                  num_workers: int = 0,
-                 activation_fct = None,
-                 apply_softmax = False,
-                 model_name_or_path:str="BAAI/bge-reraker-large"
+                 # activation_fct = None,
+                 # apply_softmax = False,
                  ):
-        self.top_n=top_n
-        self.model_name_or_path=model_name_or_path
+        # self.top_n=top_n
+        # self.model_name_or_path=model_name_or_path
+        # self.device=device
+        # self.max_length=max_length
+        # self.batch_size=batch_size
+        # self.show_progress_bar=show_progress_bar
+        # self.num_workers=num_workers
+        # self.activation_fct=activation_fct
+        # self.apply_softmax=apply_softmax
-        self.device=device
-        self.max_length=max_length
-        self.batch_size=batch_size
-        self.show_progress_bar=show_progress_bar
-        self.num_workers=num_workers
-        self.activation_fct=activation_fct
-        self.apply_softmax=apply_softmax
-
-        self.model = CrossEncoder(model_name=model_name_or_path,max_length=1024,device=device)
+        self._model = CrossEncoder(model_name=model_name_or_path,max_length=1024,device=device)
+        super().__init__(
+            top_n=top_n,
+            model_name_or_path=model_name_or_path,
+            device=device,
+            max_length=max_length,
+            batch_size=batch_size,
+            # show_progress_bar=show_progress_bar,
+            num_workers=num_workers,
+            # activation_fct=activation_fct,
+            # apply_softmax=apply_softmax
+        )
 
     def compress_documents(
         self,
@@ -53,13 +78,14 @@ class LangchainReranker(BaseDocumentCompressor):
         doc_list = list(documents)
         _docs = [d.page_content for d in doc_list]
         sentence_pairs = [[query,_doc] for _doc in _docs]
-        results = self.model.predict(sentences=sentence_pairs,
+        results = self._model.predict(sentences=sentence_pairs,
                                      batch_size=self.batch_size,
-                                     show_progress_bar=self.show_progress_bar,
+                                     # show_progress_bar=self.show_progress_bar,
                                      num_workers=self.num_workers,
-                                     activation_fct=self.activation_fct,
-                                     apply_softmax=self.apply_softmax,
-                                     convert_to_tensor=True)
+                                     # activation_fct=self.activation_fct,
+                                     # apply_softmax=self.apply_softmax,
+                                     convert_to_tensor=True
+                                     )
         top_k = self.top_n if self.top_n < len(results) else len(results)
 
         values, indices = results.topk(top_k)
@@ -68,4 +94,23 @@
             doc = doc_list[index]
             doc.metadata["relevance_score"] = value
             final_results.append(doc)
-        return final_results
\ No newline at end of file
+        return final_results
+if __name__ == "__main__":
+    from configs import (LLM_MODELS,
+                         VECTOR_SEARCH_TOP_K,
+                         SCORE_THRESHOLD,
+                         TEMPERATURE,
+                         USE_RERANKER,
+                         RERANKER_MODEL,
+                         RERANKER_MAX_LENGTH,
+                         MODEL_PATH)
+    from server.utils import embedding_device
+    if USE_RERANKER:
+        reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
+        print("-----------------model path------------------")
+        print(reranker_model_path)
+        reranker_model = LangchainReranker(top_n=3,
+                                           device=embedding_device(),
+                                           max_length=RERANKER_MAX_LENGTH,
+                                           model_name_or_path=reranker_model_path
+                                           )
\ No newline at end of file
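
Usage note: the __main__ block above only instantiates the reranker. Below is a minimal, self-contained sketch of how the LangchainReranker added in this diff can be exercised end to end; the sample query and documents are invented for illustration, and bge-reranker-base is chosen only to keep the model download small.

    # Sketch only: assumes server/reranker/reranker.py from this diff is importable
    # and that sentence_transformers and langchain_core are installed.
    from langchain_core.documents import Document
    from server.reranker.reranker import LangchainReranker

    # Hypothetical corpus; in knowledge_base_chat.py these come from the
    # vector-store retrieval step that runs before reranking.
    docs = [
        Document(page_content="BGE rerankers are cross-encoders released by BAAI."),
        Document(page_content="Today's weather: sunny with a light breeze."),
        Document(page_content="A cross-encoder scores a (query, passage) pair jointly."),
    ]

    reranker = LangchainReranker(
        model_name_or_path="BAAI/bge-reranker-base",  # any key in MODEL_PATH["reranker"]
        top_n=2,
        device="cpu",  # "cuda" if a GPU is available
        max_length=1024,
        batch_size=32,
        num_workers=0,
    )

    # compress_documents keeps the top_n documents and stamps each with
    # metadata["relevance_score"] taken from the cross-encoder output.
    for doc in reranker.compress_documents(documents=docs, query="What is a cross-encoder?"):
        print(doc.metadata["relevance_score"], doc.page_content)

This mirrors the call path in knowledge_base_chat.py, where compress_documents runs on the retriever's hits before the context string is assembled.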