fix typos
This commit is contained in:
parent
349de9b955
commit
719e2713ed
|
|
@ -25,7 +25,7 @@ faiss-cpu~=1.7.4 # `conda install faiss-gpu -c conda-forge` if you want to accel
|
||||||
accelerate==0.24.1
|
accelerate==0.24.1
|
||||||
spacy~=3.7.2
|
spacy~=3.7.2
|
||||||
PyMuPDF~=1.23.8
|
PyMuPDF~=1.23.8
|
||||||
rapidocr_onnxruntime~=1.3.8
|
rapidocr_onnxruntime==1.3.8
|
||||||
requests>=2.31.0
|
requests>=2.31.0
|
||||||
pathlib>=1.0.1
|
pathlib>=1.0.1
|
||||||
pytest>=7.4.3
|
pytest>=7.4.3
|
||||||
|
|
@ -35,16 +35,17 @@ markdownify>=0.11.6
|
||||||
tiktoken~=0.5.2
|
tiktoken~=0.5.2
|
||||||
tqdm>=4.66.1
|
tqdm>=4.66.1
|
||||||
websockets>=12.0
|
websockets>=12.0
|
||||||
numpy~=1.26.2
|
numpy~=1.24.4
|
||||||
pandas~=2.1.4
|
pandas~=2.0.3
|
||||||
einops>=0.7.0
|
einops>=0.7.0
|
||||||
transformers_stream_generator==0.0.4
|
transformers_stream_generator==0.0.4
|
||||||
vllm==0.2.6; sys_platform == "linux"
|
vllm==0.2.6; sys_platform == "linux"
|
||||||
httpx[brotli,http2,socks]~=0.25.2
|
httpx[brotli,http2,socks]==0.25.2
|
||||||
|
llama-index
|
||||||
|
|
||||||
# optional document loaders
|
# optional document loaders
|
||||||
|
|
||||||
rapidocr_paddle[gpu]>=1.3.0.post5 # gpu acceleration for ocr of pdf and image files
|
# rapidocr_paddle[gpu]>=1.3.0.post5 # gpu acceleration for ocr of pdf and image files
|
||||||
jq>=1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
|
jq>=1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
|
||||||
# html2text # for .enex files
|
# html2text # for .enex files
|
||||||
beautifulsoup4~=4.12.2 # for .mhtml files
|
beautifulsoup4~=4.12.2 # for .mhtml files
|
||||||
|
|
@ -52,7 +53,7 @@ pysrt~=1.1.2
|
||||||
|
|
||||||
# Online api libs dependencies
|
# Online api libs dependencies
|
||||||
|
|
||||||
zhipuai>=1.0.7,<=2.0.0 # zhipu
|
zhipuai>=1.0.7, <=2.0.0 # zhipu
|
||||||
dashscope>=1.13.6 # qwen
|
dashscope>=1.13.6 # qwen
|
||||||
# volcengine>=1.0.119 # fangzhou
|
# volcengine>=1.0.119 # fangzhou
|
||||||
|
|
||||||
|
|
@ -75,5 +76,4 @@ streamlit-option-menu>=0.3.6
|
||||||
streamlit-chatbox==1.1.11
|
streamlit-chatbox==1.1.11
|
||||||
streamlit-modal>=0.1.0
|
streamlit-modal>=0.1.0
|
||||||
streamlit-aggrid>=0.3.4.post3
|
streamlit-aggrid>=0.3.4.post3
|
||||||
httpx[brotli,http2,socks]>=0.25.2
|
|
||||||
watchdog>=3.0.0
|
watchdog>=3.0.0
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
from sentence_transformers import CrossEncoder
|
from sentence_transformers import CrossEncoder
|
||||||
|
|
@ -7,26 +8,28 @@ from typing import Optional, Sequence
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain.callbacks.manager import Callbacks
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
||||||
from llama_index.bridge.pydantic import Field,PrivateAttr
|
from llama_index.bridge.pydantic import Field, PrivateAttr
|
||||||
|
|
||||||
|
|
||||||
class LangchainReranker(BaseDocumentCompressor):
|
class LangchainReranker(BaseDocumentCompressor):
|
||||||
"""Document compressor that uses `Cohere Rerank API`."""
|
"""Document compressor that uses `Cohere Rerank API`."""
|
||||||
model_name_or_path:str = Field()
|
model_name_or_path: str = Field()
|
||||||
_model: Any = PrivateAttr()
|
_model: Any = PrivateAttr()
|
||||||
top_n:int= Field()
|
top_n: int = Field()
|
||||||
device:str=Field()
|
device: str = Field()
|
||||||
max_length:int=Field()
|
max_length: int = Field()
|
||||||
batch_size: int = Field()
|
batch_size: int = Field()
|
||||||
# show_progress_bar: bool = None
|
# show_progress_bar: bool = None
|
||||||
num_workers: int = Field()
|
num_workers: int = Field()
|
||||||
|
|
||||||
# activation_fct = None
|
# activation_fct = None
|
||||||
# apply_softmax = False
|
# apply_softmax = False
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_name_or_path:str,
|
model_name_or_path: str,
|
||||||
top_n:int=3,
|
top_n: int = 3,
|
||||||
device:str="cuda",
|
device: str = "cuda",
|
||||||
max_length:int=1024,
|
max_length: int = 1024,
|
||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
# show_progress_bar: bool = None,
|
# show_progress_bar: bool = None,
|
||||||
num_workers: int = 0,
|
num_workers: int = 0,
|
||||||
|
|
@ -43,7 +46,7 @@ class LangchainReranker(BaseDocumentCompressor):
|
||||||
# self.activation_fct=activation_fct
|
# self.activation_fct=activation_fct
|
||||||
# self.apply_softmax=apply_softmax
|
# self.apply_softmax=apply_softmax
|
||||||
|
|
||||||
self._model = CrossEncoder(model_name=model_name_or_path,max_length=1024,device=device)
|
self._model = CrossEncoder(model_name=model_name_or_path, max_length=1024, device=device)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
top_n=top_n,
|
top_n=top_n,
|
||||||
model_name_or_path=model_name_or_path,
|
model_name_or_path=model_name_or_path,
|
||||||
|
|
@ -77,7 +80,7 @@ class LangchainReranker(BaseDocumentCompressor):
|
||||||
return []
|
return []
|
||||||
doc_list = list(documents)
|
doc_list = list(documents)
|
||||||
_docs = [d.page_content for d in doc_list]
|
_docs = [d.page_content for d in doc_list]
|
||||||
sentence_pairs = [[query,_doc] for _doc in _docs]
|
sentence_pairs = [[query, _doc] for _doc in _docs]
|
||||||
results = self._model.predict(sentences=sentence_pairs,
|
results = self._model.predict(sentences=sentence_pairs,
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
# show_progress_bar=self.show_progress_bar,
|
# show_progress_bar=self.show_progress_bar,
|
||||||
|
|
@ -90,11 +93,13 @@ class LangchainReranker(BaseDocumentCompressor):
|
||||||
|
|
||||||
values, indices = results.topk(top_k)
|
values, indices = results.topk(top_k)
|
||||||
final_results = []
|
final_results = []
|
||||||
for value, index in zip(values,indices):
|
for value, index in zip(values, indices):
|
||||||
doc = doc_list[index]
|
doc = doc_list[index]
|
||||||
doc.metadata["relevance_score"] = value
|
doc.metadata["relevance_score"] = value
|
||||||
final_results.append(doc)
|
final_results.append(doc)
|
||||||
return final_results
|
return final_results
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from configs import (LLM_MODELS,
|
from configs import (LLM_MODELS,
|
||||||
VECTOR_SEARCH_TOP_K,
|
VECTOR_SEARCH_TOP_K,
|
||||||
|
|
@ -105,8 +110,9 @@ if __name__ == "__main__":
|
||||||
RERANKER_MAX_LENGTH,
|
RERANKER_MAX_LENGTH,
|
||||||
MODEL_PATH)
|
MODEL_PATH)
|
||||||
from server.utils import embedding_device
|
from server.utils import embedding_device
|
||||||
|
|
||||||
if USE_RERANKER:
|
if USE_RERANKER:
|
||||||
reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
|
reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
|
||||||
print("-----------------model path------------------")
|
print("-----------------model path------------------")
|
||||||
print(reranker_model_path)
|
print(reranker_model_path)
|
||||||
reranker_model = LangchainReranker(top_n=3,
|
reranker_model = LangchainReranker(top_n=3,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue