update textsplitter and pdfloader

This commit is contained in:
imClumsyPanda 2023-08-09 23:36:28 +08:00
parent 222689ed5b
commit 02b9d57072
1 changed files with 25 additions and 4 deletions

View File

@ -1,6 +1,6 @@
import os
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config)
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, CHUNK_SIZE, OVERLAP_SIZE)
from functools import lru_cache
import sys
from text_splitter import zh_title_enhance
@ -45,6 +45,8 @@ LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg
'.doc', '.docx', '.epub', '.odt', '.pdf',
'.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv'
"CSVLoader": [".csv"],
"PyPDFLoader": [".pdf"],
}
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
@ -70,15 +72,34 @@ class KnowledgeFile:
self.document_loader_name = get_LoaderClass(self.ext)
# TODO: 增加依据文件格式匹配text_splitter
self.text_splitter_name = "CharacterTextSplitter"
self.text_splitter_name = None
def file2text(self, using_zh_title_enhance):
DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name)
loader = DocumentLoader(self.filepath)
# TODO: 增加依据文件格式匹配text_splitter
try:
if self.text_splitter_name is None:
TextSplitter = getattr(sys.modules['langchain.text_splitter'], "SpacyTextSplitter")
text_splitter = TextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
)
else:
TextSplitter = getattr(sys.modules['langchain.text_splitter'], self.text_splitter_name)
text_splitter = TextSplitter(chunk_size=250, chunk_overlap=200)
text_splitter = TextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=50)
except Exception as e:
print(e)
TextSplitter = getattr(sys.modules['langchain.text_splitter'], "RecursiveCharacterTextSplitter")
text_splitter = TextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
)
docs = loader.load_and_split(text_splitter)
if using_zh_title_enhance:
docs = zh_title_enhance(docs)