Langchain-Chatchat/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py

49 lines
1.6 KiB
Python

from __future__ import annotations
from langchain.retrievers import EnsembleRetriever
from langchain.vectorstores import VectorStore
from langchain_community.retrievers import BM25Retriever
from langchain_core.retrievers import BaseRetriever
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
class EnsembleRetrieverService(BaseRetrieverService):
    """Retriever service that fuses keyword (BM25) and vector-similarity search.

    Instances are normally built with :meth:`from_vectorstore`, which combines a
    BM25 retriever and the vector store's similarity retriever over the same
    documents in a LangChain ``EnsembleRetriever``.
    """

    def do_init(
        self,
        retriever: BaseRetriever | None = None,
        top_k: int = 5,
    ):
        """Store the wrapped retriever and the per-query result cap.

        Args:
            retriever: LangChain retriever queried by :meth:`get_relevant_documents`.
            top_k: Maximum number of documents returned per query.
        """
        self.vs = None
        self.top_k = top_k
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        score_threshold: int | float,
    ) -> EnsembleRetrieverService:
        """Build a service combining BM25 and vector search over ``vectorstore``.

        Args:
            vectorstore: Store to search. Its documents are read from
                ``vectorstore.docstore._dict``, so the store must expose an
                in-memory docstore of that shape (as FAISS does); other stores
                will raise ``AttributeError`` here.
            top_k: Number of documents each sub-retriever fetches; also the cap
                applied to the fused result.
            score_threshold: Minimum similarity score for the vector retriever.

        Returns:
            An ``EnsembleRetrieverService``. If the store holds no documents, the
            service wraps only the vector retriever, because a BM25 index cannot
            be built from an empty corpus.
        """
        vector_retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"score_threshold": score_threshold, "k": top_k},
        )

        docs = list(vectorstore.docstore._dict.values())
        if not docs:
            # BM25Retriever.from_documents([]) builds a rank_bm25 index that
            # divides by the corpus size and raises ZeroDivisionError. An empty
            # knowledge base is a valid state, so degrade to vector search only.
            return EnsembleRetrieverService(retriever=vector_retriever, top_k=top_k)

        # NOTE: cutword's Cutter was considered as the tokenizer but depends on
        # torch; jieba is used instead. Imported here so it is only needed when
        # a BM25 index is actually built.
        import jieba

        bm25_retriever = BM25Retriever.from_documents(
            docs,
            # jieba's search-engine segmentation mode for Chinese text.
            preprocess_func=jieba.lcut_for_search,
        )
        bm25_retriever.k = top_k

        # Keyword and vector rankings are weighted equally when fused.
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, vector_retriever], weights=[0.5, 0.5]
        )
        return EnsembleRetrieverService(retriever=ensemble_retriever, top_k=top_k)

    def get_relevant_documents(self, query: str):
        """Return at most ``self.top_k`` documents relevant to ``query``.

        The wrapped retriever's results are truncated to ``self.top_k``, since
        the fused ensemble output can contain more than ``top_k`` documents.
        """
        # `invoke` is the non-deprecated replacement for
        # BaseRetriever.get_relevant_documents.
        return self.retriever.invoke(query)[: self.top_k]