Split KnowledgeFile.file2text into three parts: file2docs, docs2texts, and file2text. With the existing interface unchanged, this adds:
1. support for chunk_size and chunk_overlap parameters; 2. support for a custom text_splitter; 3. support for custom docs. Fix: CSV files no longer go through the text_splitter.
This commit is contained in:
parent 8475a5dfd3
commit 93b133f9ac
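For orientation, a minimal sketch of how the refactored interface might be called after this change (the module path and chunk values here are assumptions; the method names and the "samples"/"test.txt" example follow the diff below):

from server.knowledge_base.utils import KnowledgeFile  # assumed module path
from langchain.text_splitter import RecursiveCharacterTextSplitter

kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")

# 1. Load raw documents without splitting.
docs = kb_file.file2docs()

# 2. Split with explicit chunking parameters; the file2text signature stays compatible.
texts = kb_file.file2text(chunk_size=250, chunk_overlap=50)

# 3. Or pass pre-built docs and a custom splitter; per the fix, CSV files skip splitting.
splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=50)
texts = kb_file.docs2texts(docs=docs, text_splitter=splitter)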
@@ -14,9 +14,11 @@ import importlib
 from text_splitter import zh_title_enhance
 import langchain.document_loaders
 from langchain.docstore.document import Document
+from langchain.text_splitter import TextSplitter
 from pathlib import Path
 import json
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ThreadPoolExecutor
+from server.utils import run_in_thread_pool
 from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
 
 
@@ -186,75 +188,110 @@ class KnowledgeFile:
             raise ValueError(f"暂未支持的文件格式 {self.ext}")
         self.filepath = get_file_path(knowledge_base_name, filename)
+        self.docs = None
+        self.splited_docs = None
         self.document_loader_name = get_LoaderClass(self.ext)
 
         # TODO: 增加依据文件格式匹配text_splitter
         self.text_splitter_name = None
 
-    def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE, refresh: bool = False):
-        if self.docs is not None and not refresh:
-            return self.docs
-
-        print(f"{self.document_loader_name} used for {self.filepath}")
-        try:
-            if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
-                document_loaders_module = importlib.import_module('document_loaders')
-            else:
-                document_loaders_module = importlib.import_module('langchain.document_loaders')
-            DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
-        except Exception as e:
-            print(e)
-            document_loaders_module = importlib.import_module('langchain.document_loaders')
-            DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
-        if self.document_loader_name == "UnstructuredFileLoader":
-            loader = DocumentLoader(self.filepath, autodetect_encoding=True)
-        elif self.document_loader_name == "CSVLoader":
-            loader = DocumentLoader(self.filepath, encoding="utf-8")
-        elif self.document_loader_name == "JSONLoader":
-            loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
-        elif self.document_loader_name == "CustomJSONLoader":
-            loader = DocumentLoader(self.filepath, text_content=False)
-        elif self.document_loader_name == "UnstructuredMarkdownLoader":
-            loader = DocumentLoader(self.filepath, mode="elements")
-        elif self.document_loader_name == "UnstructuredHTMLLoader":
-            loader = DocumentLoader(self.filepath, mode="elements")
-        else:
-            loader = DocumentLoader(self.filepath)
-
-        if self.ext in ".csv":
-            docs = loader.load()
-        else:
+    def file2docs(self, refresh: bool=False):
+        if self.docs is None or refresh:
+            print(f"{self.document_loader_name} used for {self.filepath}")
             try:
-                if self.text_splitter_name is None:
-                    text_splitter_module = importlib.import_module('langchain.text_splitter')
-                    TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter")
-                    text_splitter = TextSplitter(
-                        pipeline="zh_core_web_sm",
-                        chunk_size=CHUNK_SIZE,
-                        chunk_overlap=OVERLAP_SIZE,
-                    )
-                    self.text_splitter_name = "SpacyTextSplitter"
+                if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
+                    document_loaders_module = importlib.import_module('document_loaders')
                 else:
-                    text_splitter_module = importlib.import_module('langchain.text_splitter')
-                    TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
-                    text_splitter = TextSplitter(
-                        chunk_size=CHUNK_SIZE,
-                        chunk_overlap=OVERLAP_SIZE)
+                    document_loaders_module = importlib.import_module('langchain.document_loaders')
+                DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
             except Exception as e:
                 print(e)
-                text_splitter_module = importlib.import_module('langchain.text_splitter')
-                TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
-                text_splitter = TextSplitter(
-                    chunk_size=CHUNK_SIZE,
-                    chunk_overlap=OVERLAP_SIZE,
-                )
+                document_loaders_module = importlib.import_module('langchain.document_loaders')
+                DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
+            if self.document_loader_name == "UnstructuredFileLoader":
+                loader = DocumentLoader(self.filepath, autodetect_encoding=True)
+            elif self.document_loader_name == "CSVLoader":
+                loader = DocumentLoader(self.filepath, encoding="utf-8")
+            elif self.document_loader_name == "JSONLoader":
+                loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
+            elif self.document_loader_name == "CustomJSONLoader":
+                loader = DocumentLoader(self.filepath, text_content=False)
+            elif self.document_loader_name == "UnstructuredMarkdownLoader":
+                loader = DocumentLoader(self.filepath, mode="elements")
+            elif self.document_loader_name == "UnstructuredHTMLLoader":
+                loader = DocumentLoader(self.filepath, mode="elements")
+            else:
+                loader = DocumentLoader(self.filepath)
+            self.docs = loader.load()
+        return self.docs
 
-            docs = loader.load_and_split(text_splitter)
-        print(docs[0])
+    def make_text_splitter(
+            self,
+            chunk_size: int = CHUNK_SIZE,
+            chunk_overlap: int = OVERLAP_SIZE,
+    ):
+        try:
+            if self.text_splitter_name is None:
+                text_splitter_module = importlib.import_module('langchain.text_splitter')
+                TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter")
+                text_splitter = TextSplitter(
+                    pipeline="zh_core_web_sm",
+                    chunk_size=chunk_size,
+                    chunk_overlap=chunk_overlap,
+                )
+                self.text_splitter_name = "SpacyTextSplitter"
+            else:
+                text_splitter_module = importlib.import_module('langchain.text_splitter')
+                TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
+                text_splitter = TextSplitter(
+                    chunk_size=chunk_size,
+                    chunk_overlap=chunk_overlap)
+        except Exception as e:
+            print(e)
+            text_splitter_module = importlib.import_module('langchain.text_splitter')
+            TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
+            text_splitter = TextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+            )
+        return text_splitter
+
+    def docs2texts(
+            self,
+            docs: List[Document] = None,
+            using_zh_title_enhance=ZH_TITLE_ENHANCE,
+            refresh: bool = False,
+            chunk_size: int = CHUNK_SIZE,
+            chunk_overlap: int = OVERLAP_SIZE,
+            text_splitter: TextSplitter = None,
+    ):
+        docs = docs or self.file2docs(refresh=refresh)
+
+        if self.ext not in [".csv"]:
+            if text_splitter is None:
+                text_splitter = self.make_text_splitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+            docs = text_splitter.split_documents(docs)
+
+        print(f"文档切分示例:{docs[0]}")
         if using_zh_title_enhance:
             docs = zh_title_enhance(docs)
-        self.docs = docs
-        return docs
+        self.splited_docs = docs
+        return self.splited_docs
 
+    def file2text(
+            self,
+            using_zh_title_enhance=ZH_TITLE_ENHANCE,
+            refresh: bool = False,
+            chunk_size: int = CHUNK_SIZE,
+            chunk_overlap: int = OVERLAP_SIZE,
+            text_splitter: TextSplitter = None,
+    ):
+        if self.splited_docs is None or refresh:
+            self.splited_docs = self.docs2texts(using_zh_title_enhance=using_zh_title_enhance,
+                                                refresh=refresh,
+                                                chunk_size=chunk_size,
+                                                chunk_overlap=chunk_overlap,
+                                                text_splitter=text_splitter)
+        return self.splited_docs
+
     def get_mtime(self):
         return os.path.getmtime(self.filepath)
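As a hedged illustration of the splitter selection added above: make_text_splitter appears to default to SpacyTextSplitter with the zh_core_web_sm pipeline, to honour a pre-set text_splitter_name, and to fall back to RecursiveCharacterTextSplitter if construction fails. A usage sketch (chunk values are assumptions; the splitter name mirrors the commented line in the __main__ block below):

kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")

# Default path: SpacyTextSplitter(pipeline="zh_core_web_sm"); if spaCy or the
# pipeline is unavailable, the except branch falls back to RecursiveCharacterTextSplitter.
splitter = kb_file.make_text_splitter(chunk_size=250, chunk_overlap=50)

# Pre-selecting a langchain splitter by name routes through the else branch instead.
kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
splitter = kb_file.make_text_splitter(chunk_size=250, chunk_overlap=50)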
@@ -263,36 +300,15 @@ class KnowledgeFile:
         return os.path.getsize(self.filepath)
 
 
-def run_in_thread_pool(
-        func: Callable,
-        params: List[Dict] = [],
-        pool: ThreadPoolExecutor = None,
-) -> Generator:
-    '''
-    在线程池中批量运行任务,并将运行结果以生成器的形式返回。
-    请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。
-    '''
-    tasks = []
-    if pool is None:
-        pool = ThreadPoolExecutor()
-
-    for kwargs in params:
-        thread = pool.submit(func, **kwargs)
-        tasks.append(thread)
-
-    for obj in as_completed(tasks):
-        yield obj.result()
-
-
 def files2docs_in_thread(
-        files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
+        files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], # 如果是Tuple,形式为(filename, kb_name)
         pool: ThreadPoolExecutor = None,
 ) -> Generator:
     '''
-    利用多线程批量将文件转化成langchain Document.
+    利用多线程批量将磁盘文件转化成langchain Document.
+    生成器返回值为{(kb_name, file_name): docs}
     '''
-    def task(*, file: KnowledgeFile, **kwargs) -> Dict[Tuple[str, str], List[Document]]:
+    def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
         try:
             return True, (file.kb_name, file.filename, file.file2text(**kwargs))
         except Exception as e:
@@ -302,14 +318,26 @@ def files2docs_in_thread(
     for i, file in enumerate(files):
+        kwargs = {}
         if isinstance(file, tuple) and len(file) >= 2:
-            files[i] = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
+            file = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
         elif isinstance(file, dict):
             filename = file.pop("filename")
             kb_name = file.pop("kb_name")
-            files[i] = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
+            kwargs = file
+            file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
+        kwargs["file"] = file
         kwargs_list.append(kwargs)
 
-    for result in run_in_thread_pool(func=task, params=kwargs_list, pool=pool):
+    for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
         yield result
 
 
 if __name__ == "__main__":
     from pprint import pprint
 
     kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
+    # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
+    docs = kb_file.file2docs()
+    pprint(docs[-1])
+
     docs = kb_file.file2text()
     pprint(docs[-1])
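A hedged sketch of driving the threaded conversion with the tuple and dict forms shown above: extra dict keys appear to be forwarded to file2text as keyword arguments. The per-file chunk values are assumptions, and the failure payload is a guess since the except branch is truncated in the hunk above:

from server.knowledge_base.utils import files2docs_in_thread  # assumed module path

files = [
    ("test.txt", "samples"),  # tuple form: (filename, kb_name)
    {"filename": "test.txt", "kb_name": "samples",
     "chunk_size": 200, "chunk_overlap": 20},  # dict form; extra keys reach file2text(**kwargs)
]

for status, result in files2docs_in_thread(files=files):
    if status:
        kb_name, file_name, docs = result
        print(kb_name, file_name, len(docs))
    else:
        print("conversion failed:", result)  # failure shape assumed; except branch not shown here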
@@ -9,7 +9,11 @@ from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDIN
 from configs.server_config import FSCHAT_MODEL_WORKERS
 import os
 from server import model_workers
-from typing import Literal, Optional, Any
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Literal, Optional, Callable, Generator, Dict, Any
 
 
+thread_pool = ThreadPoolExecutor()
+
+
 class BaseResponse(BaseModel):
@@ -305,3 +309,24 @@ def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "
     if device not in ["cuda", "mps", "cpu"]:
         device = detect_device()
     return device
+
+
+def run_in_thread_pool(
+        func: Callable,
+        params: List[Dict] = [],
+        pool: ThreadPoolExecutor = None,
+) -> Generator:
+    '''
+    在线程池中批量运行任务,并将运行结果以生成器的形式返回。
+    请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。
+    '''
+    tasks = []
+    pool = pool or thread_pool
+
+    for kwargs in params:
+        thread = pool.submit(func, **kwargs)
+        tasks.append(thread)
+
+    for obj in as_completed(tasks):
+        yield obj.result()
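As a hedged illustration of the helper added here: tasks are submitted to the shared module-level thread_pool unless a pool is passed in, and results are yielded in completion order rather than submission order. The square function below is a made-up example; task functions must be thread-safe and take keyword arguments only, as the docstring requires:

from server.utils import run_in_thread_pool

def square(*, x: int) -> int:
    return x * x

params = [{"x": i} for i in range(5)]
for result in run_in_thread_pool(square, params=params):
    print(result)  # completion order, not submission order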