# Langchain-Chatchat/server/reranker/reranker.py

import os
import sys

# Make the project root importable when this file is run directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))

from typing import Any, Optional, Sequence

from sentence_transformers import CrossEncoder
from langchain_core.documents import Document
from langchain.callbacks.manager import Callbacks
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from llama_index.bridge.pydantic import Field, PrivateAttr


class LangchainReranker(BaseDocumentCompressor):
    """Document compressor that reranks documents with a local
    `sentence-transformers` CrossEncoder model."""

    model_name_or_path: str = Field()
    _model: Any = PrivateAttr()
    top_n: int = Field()
    device: str = Field()
    max_length: int = Field()
    batch_size: int = Field()
    # show_progress_bar: bool = None
    num_workers: int = Field()
    # activation_fct = None
    # apply_softmax = False

    def __init__(self,
                 model_name_or_path: str,
                 top_n: int = 3,
                 device: str = "cuda",
                 max_length: int = 1024,
                 batch_size: int = 32,
                 # show_progress_bar: bool = None,
                 num_workers: int = 0,
                 # activation_fct = None,
                 # apply_softmax = False,
                 ):
        # Build the local cross-encoder; pass the configured max_length
        # through so the constructor argument actually takes effect.
        self._model = CrossEncoder(model_name=model_name_or_path, max_length=max_length, device=device)
        super().__init__(
            top_n=top_n,
            model_name_or_path=model_name_or_path,
            device=device,
            max_length=max_length,
            batch_size=batch_size,
            # show_progress_bar=show_progress_bar,
            num_workers=num_workers,
            # activation_fct=activation_fct,
            # apply_softmax=apply_softmax
        )

    def compress_documents(
            self,
            documents: Sequence[Document],
            query: str,
            callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Rerank documents against the query with the local CrossEncoder model.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.
        Returns:
            The top_n documents, ordered by descending relevance score.
        """
        if len(documents) == 0:  # avoid calling the model with no input
            return []
        doc_list = list(documents)
        _docs = [d.page_content for d in doc_list]
        # Score each (query, document) pair with the cross-encoder.
        sentence_pairs = [[query, _doc] for _doc in _docs]
        results = self._model.predict(sentences=sentence_pairs,
                                      batch_size=self.batch_size,
                                      # show_progress_bar=self.show_progress_bar,
                                      num_workers=self.num_workers,
                                      # activation_fct=self.activation_fct,
                                      # apply_softmax=self.apply_softmax,
                                      convert_to_tensor=True
                                      )
        # Keep the top_n highest-scoring documents (or all of them, if fewer).
        top_k = min(self.top_n, len(results))
        values, indices = results.topk(top_k)
        final_results = []
        for value, index in zip(values, indices):
            doc = doc_list[index]
            # Store the score as a plain float rather than a 0-dim tensor,
            # so the metadata stays serializable.
            doc.metadata["relevance_score"] = float(value)
            final_results.append(doc)
        return final_results
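

# A hedged integration sketch, not part of the original file: because
# LangchainReranker is a BaseDocumentCompressor, it can plug into LangChain's
# ContextualCompressionRetriever. `vector_retriever` below is a hypothetical
# retriever you would build elsewhere (e.g. from a FAISS vector store).
#
#     from langchain.retrievers import ContextualCompressionRetriever
#
#     compression_retriever = ContextualCompressionRetriever(
#         base_compressor=LangchainReranker(model_name_or_path="BAAI/bge-reranker-large"),
#         base_retriever=vector_retriever,
#     )
#     docs = compression_retriever.get_relevant_documents("your query")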


if __name__ == "__main__":
    # Smoke test: load the configured reranker model if reranking is enabled.
    from configs import (USE_RERANKER,
                         RERANKER_MODEL,
                         RERANKER_MAX_LENGTH,
                         MODEL_PATH)
    from server.utils import embedding_device

    if USE_RERANKER:
        reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
        print("-----------------model path------------------")
        print(reranker_model_path)
        reranker_model = LangchainReranker(top_n=3,
                                           device=embedding_device(),
                                           max_length=RERANKER_MAX_LENGTH,
                                           model_name_or_path=reranker_model_path
                                           )
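        # A minimal usage sketch; the sample query and documents below are
        # hypothetical, not part of the original file. It exercises the
        # compressor end to end: wrap raw strings as Documents, rerank them
        # against a query, then read back the attached scores.
        sample_docs = [
            Document(page_content="LangChain is a framework for building LLM applications."),
            Document(page_content="The weather in Beijing is sunny today."),
            Document(page_content="A reranker scores query-document pairs with a cross-encoder."),
        ]
        reranked = reranker_model.compress_documents(documents=sample_docs,
                                                     query="What does a reranker do?")
        for doc in reranked:
            print(f"{doc.metadata['relevance_score']:.4f}  {doc.page_content}")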