Split KnowledgeFile's file2text into three parts — file2docs, docs2texts, and file2text — keeping the external interface unchanged, in order to:
1. support chunk_size and chunk_overlap parameters; 2. support a custom text_splitter; 3. support custom docs. Fix: CSV files do not go through the text_splitter.
parent 8475a5dfd3
commit 93b133f9ac
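As a rough usage sketch of the reorganized interface (method and parameter names are taken from the diff below; the sample filename and the RecursiveCharacterTextSplitter choice are only illustrations, not part of this commit):

from langchain.text_splitter import RecursiveCharacterTextSplitter

kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")

# load raw Documents only, without splitting
docs = kb_file.file2docs()

# load and split with custom chunk parameters (same external interface as before)
texts = kb_file.file2text(chunk_size=500, chunk_overlap=100)

# or pass pre-loaded docs and a custom text_splitter explicitly
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
texts = kb_file.docs2texts(docs=docs, text_splitter=splitter)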
@@ -14,9 +14,11 @@ import importlib
 from text_splitter import zh_title_enhance
 import langchain.document_loaders
 from langchain.docstore.document import Document
+from langchain.text_splitter import TextSplitter
 from pathlib import Path
 import json
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ThreadPoolExecutor
+from server.utils import run_in_thread_pool
 from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
 
 
@@ -186,75 +188,110 @@ class KnowledgeFile:
             raise ValueError(f"暂未支持的文件格式 {self.ext}")
         self.filepath = get_file_path(knowledge_base_name, filename)
         self.docs = None
+        self.splited_docs = None
         self.document_loader_name = get_LoaderClass(self.ext)
 
         # TODO: 增加依据文件格式匹配text_splitter
         self.text_splitter_name = None
 
-    def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE, refresh: bool = False):
-        if self.docs is not None and not refresh:
-            return self.docs
-
-        print(f"{self.document_loader_name} used for {self.filepath}")
-        try:
-            if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
-                document_loaders_module = importlib.import_module('document_loaders')
-            else:
-                document_loaders_module = importlib.import_module('langchain.document_loaders')
-            DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
-        except Exception as e:
-            print(e)
-            document_loaders_module = importlib.import_module('langchain.document_loaders')
-            DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
-        if self.document_loader_name == "UnstructuredFileLoader":
-            loader = DocumentLoader(self.filepath, autodetect_encoding=True)
-        elif self.document_loader_name == "CSVLoader":
-            loader = DocumentLoader(self.filepath, encoding="utf-8")
-        elif self.document_loader_name == "JSONLoader":
-            loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
-        elif self.document_loader_name == "CustomJSONLoader":
-            loader = DocumentLoader(self.filepath, text_content=False)
-        elif self.document_loader_name == "UnstructuredMarkdownLoader":
-            loader = DocumentLoader(self.filepath, mode="elements")
-        elif self.document_loader_name == "UnstructuredHTMLLoader":
-            loader = DocumentLoader(self.filepath, mode="elements")
-        else:
-            loader = DocumentLoader(self.filepath)
-
-        if self.ext in ".csv":
-            docs = loader.load()
-        else:
-            try:
-                if self.text_splitter_name is None:
-                    text_splitter_module = importlib.import_module('langchain.text_splitter')
-                    TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter")
-                    text_splitter = TextSplitter(
-                        pipeline="zh_core_web_sm",
-                        chunk_size=CHUNK_SIZE,
-                        chunk_overlap=OVERLAP_SIZE,
-                    )
-                    self.text_splitter_name = "SpacyTextSplitter"
-                else:
-                    text_splitter_module = importlib.import_module('langchain.text_splitter')
-                    TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
-                    text_splitter = TextSplitter(
-                        chunk_size=CHUNK_SIZE,
-                        chunk_overlap=OVERLAP_SIZE)
-            except Exception as e:
-                print(e)
-                text_splitter_module = importlib.import_module('langchain.text_splitter')
-                TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
-                text_splitter = TextSplitter(
-                    chunk_size=CHUNK_SIZE,
-                    chunk_overlap=OVERLAP_SIZE,
-                )
-
-            docs = loader.load_and_split(text_splitter)
-        print(docs[0])
+    def file2docs(self, refresh: bool=False):
+        if self.docs is None or refresh:
+            print(f"{self.document_loader_name} used for {self.filepath}")
+            try:
+                if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
+                    document_loaders_module = importlib.import_module('document_loaders')
+                else:
+                    document_loaders_module = importlib.import_module('langchain.document_loaders')
+                DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
+            except Exception as e:
+                print(e)
+                document_loaders_module = importlib.import_module('langchain.document_loaders')
+                DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
+            if self.document_loader_name == "UnstructuredFileLoader":
+                loader = DocumentLoader(self.filepath, autodetect_encoding=True)
+            elif self.document_loader_name == "CSVLoader":
+                loader = DocumentLoader(self.filepath, encoding="utf-8")
+            elif self.document_loader_name == "JSONLoader":
+                loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
+            elif self.document_loader_name == "CustomJSONLoader":
+                loader = DocumentLoader(self.filepath, text_content=False)
+            elif self.document_loader_name == "UnstructuredMarkdownLoader":
+                loader = DocumentLoader(self.filepath, mode="elements")
+            elif self.document_loader_name == "UnstructuredHTMLLoader":
+                loader = DocumentLoader(self.filepath, mode="elements")
+            else:
+                loader = DocumentLoader(self.filepath)
+            self.docs = loader.load()
+        return self.docs
+
+    def make_text_splitter(
+            self,
+            chunk_size: int = CHUNK_SIZE,
+            chunk_overlap: int = OVERLAP_SIZE,
+    ):
+        try:
+            if self.text_splitter_name is None:
+                text_splitter_module = importlib.import_module('langchain.text_splitter')
+                TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter")
+                text_splitter = TextSplitter(
+                    pipeline="zh_core_web_sm",
+                    chunk_size=chunk_size,
+                    chunk_overlap=chunk_overlap,
+                )
+                self.text_splitter_name = "SpacyTextSplitter"
+            else:
+                text_splitter_module = importlib.import_module('langchain.text_splitter')
+                TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
+                text_splitter = TextSplitter(
+                    chunk_size=chunk_size,
+                    chunk_overlap=chunk_overlap)
+        except Exception as e:
+            print(e)
+            text_splitter_module = importlib.import_module('langchain.text_splitter')
+            TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
+            text_splitter = TextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+            )
+        return text_splitter
+
+    def docs2texts(
+            self,
+            docs: List[Document] = None,
+            using_zh_title_enhance=ZH_TITLE_ENHANCE,
+            refresh: bool = False,
+            chunk_size: int = CHUNK_SIZE,
+            chunk_overlap: int = OVERLAP_SIZE,
+            text_splitter: TextSplitter = None,
+    ):
+        docs = docs or self.file2docs(refresh=refresh)
+
+        if self.ext not in [".csv"]:
+            if text_splitter is None:
+                text_splitter = self.make_text_splitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+            docs = text_splitter.split_documents(docs)
+
+        print(f"文档切分示例:{docs[0]}")
         if using_zh_title_enhance:
             docs = zh_title_enhance(docs)
-        self.docs = docs
-        return docs
+        self.splited_docs = docs
+        return self.splited_docs
+
+    def file2text(
+            self,
+            using_zh_title_enhance=ZH_TITLE_ENHANCE,
+            refresh: bool = False,
+            chunk_size: int = CHUNK_SIZE,
+            chunk_overlap: int = OVERLAP_SIZE,
+            text_splitter: TextSplitter = None,
+    ):
+        if self.splited_docs is None or refresh:
+            self.splited_docs = self.docs2texts(using_zh_title_enhance=using_zh_title_enhance,
+                                                refresh=refresh,
+                                                chunk_size=chunk_size,
+                                                chunk_overlap=chunk_overlap,
+                                                text_splitter=text_splitter)
+        return self.splited_docs
 
     def get_mtime(self):
         return os.path.getmtime(self.filepath)
@@ -263,36 +300,15 @@ class KnowledgeFile:
         return os.path.getsize(self.filepath)
 
 
-def run_in_thread_pool(
-        func: Callable,
-        params: List[Dict] = [],
-        pool: ThreadPoolExecutor = None,
-) -> Generator:
-    '''
-    在线程池中批量运行任务,并将运行结果以生成器的形式返回。
-    请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。
-    '''
-    tasks = []
-    if pool is None:
-        pool = ThreadPoolExecutor()
-
-    for kwargs in params:
-        thread = pool.submit(func, **kwargs)
-        tasks.append(thread)
-
-    for obj in as_completed(tasks):
-        yield obj.result()
-
-
 def files2docs_in_thread(
-        files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
+        files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],  # 如果是Tuple,形式为(filename, kb_name)
         pool: ThreadPoolExecutor = None,
 ) -> Generator:
     '''
-    利用多线程批量将文件转化成langchain Document.
+    利用多线程批量将磁盘文件转化成langchain Document.
     生成器返回值为{(kb_name, file_name): docs}
     '''
-    def task(*, file: KnowledgeFile, **kwargs) -> Dict[Tuple[str, str], List[Document]]:
+    def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
         try:
             return True, (file.kb_name, file.filename, file.file2text(**kwargs))
         except Exception as e:
@@ -302,14 +318,26 @@ def files2docs_in_thread(
     for i, file in enumerate(files):
         kwargs = {}
         if isinstance(file, tuple) and len(file) >= 2:
-            files[i] = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
+            file = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
         elif isinstance(file, dict):
             filename = file.pop("filename")
             kb_name = file.pop("kb_name")
-            files[i] = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
             kwargs = file
+            file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
         kwargs["file"] = file
         kwargs_list.append(kwargs)
 
-    for result in run_in_thread_pool(func=task, params=kwargs_list, pool=pool):
+    for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
         yield result
+
+
+if __name__ == "__main__":
+    from pprint import pprint
+
+    kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
+    # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
+    docs = kb_file.file2docs()
+    pprint(docs[-1])
+
+    docs = kb_file.file2text()
+    pprint(docs[-1])
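A hedged sketch of consuming the reworked files2docs_in_thread generator defined above — the success tuple (kb_name, filename, docs) follows the diff, but what the except branch yields is not shown in this hunk, so only a boolean status is assumed for the failure case:

files = [("test.txt", "samples"), ("readme.md", "samples")]
for status, result in files2docs_in_thread(files=files):
    if status:
        kb_name, file_name, docs = result
        print(kb_name, file_name, len(docs))
    else:
        # the error payload shape is not visible in this diff
        print("failed:", result)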
@@ -9,7 +9,11 @@ from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE
 from configs.server_config import FSCHAT_MODEL_WORKERS
 import os
 from server import model_workers
-from typing import Literal, Optional, Any
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Literal, Optional, Callable, Generator, Dict, Any
 
 
+thread_pool = ThreadPoolExecutor()
+
+
 class BaseResponse(BaseModel):
@@ -305,3 +309,24 @@ def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
     if device not in ["cuda", "mps", "cpu"]:
         device = detect_device()
     return device
+
+
+def run_in_thread_pool(
+        func: Callable,
+        params: List[Dict] = [],
+        pool: ThreadPoolExecutor = None,
+) -> Generator:
+    '''
+    在线程池中批量运行任务,并将运行结果以生成器的形式返回。
+    请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。
+    '''
+    tasks = []
+    pool = pool or thread_pool
+
+    for kwargs in params:
+        thread = pool.submit(func, **kwargs)
+        tasks.append(thread)
+
+    for obj in as_completed(tasks):
+        yield obj.result()
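For reference, a minimal sketch of calling the relocated run_in_thread_pool helper; the add worker and its params list are hypothetical, and, as the docstring requires, the task function takes keyword arguments only:

def add(*, x: int, y: int) -> int:
    return x + y

params = [{"x": i, "y": i * 2} for i in range(5)]
for result in run_in_thread_pool(func=add, params=params):
    print(result)  # results arrive in completion order, not submission order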