update textsplitter and pdfloader
This commit is contained in:
parent
222689ed5b
commit
02b9d57072
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config)
|
||||
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, CHUNK_SIZE, OVERLAP_SIZE)
|
||||
from functools import lru_cache
|
||||
import sys
|
||||
from text_splitter import zh_title_enhance
|
||||
|
|
@ -45,6 +45,8 @@ LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg
|
|||
'.doc', '.docx', '.epub', '.odt', '.pdf',
|
||||
'.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv'
|
||||
"CSVLoader": [".csv"],
|
||||
"PyPDFLoader": [".pdf"],
|
||||
|
||||
}
|
||||
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
|
||||
|
||||
|
|
@ -70,15 +72,34 @@ class KnowledgeFile:
|
|||
self.document_loader_name = get_LoaderClass(self.ext)
|
||||
|
||||
# TODO: 增加依据文件格式匹配text_splitter
|
||||
self.text_splitter_name = "CharacterTextSplitter"
|
||||
self.text_splitter_name = None
|
||||
|
||||
def file2text(self, using_zh_title_enhance):
|
||||
DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name)
|
||||
loader = DocumentLoader(self.filepath)
|
||||
|
||||
# TODO: 增加依据文件格式匹配text_splitter
|
||||
try:
|
||||
if self.text_splitter_name is None:
|
||||
TextSplitter = getattr(sys.modules['langchain.text_splitter'], "SpacyTextSplitter")
|
||||
text_splitter = TextSplitter(
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=CHUNK_SIZE,
|
||||
chunk_overlap=OVERLAP_SIZE,
|
||||
)
|
||||
else:
|
||||
TextSplitter = getattr(sys.modules['langchain.text_splitter'], self.text_splitter_name)
|
||||
text_splitter = TextSplitter(chunk_size=250, chunk_overlap=200)
|
||||
text_splitter = TextSplitter(
|
||||
chunk_size=CHUNK_SIZE,
|
||||
chunk_overlap=50)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
TextSplitter = getattr(sys.modules['langchain.text_splitter'], "RecursiveCharacterTextSplitter")
|
||||
text_splitter = TextSplitter(
|
||||
chunk_size=CHUNK_SIZE,
|
||||
chunk_overlap=OVERLAP_SIZE,
|
||||
)
|
||||
|
||||
docs = loader.load_and_split(text_splitter)
|
||||
if using_zh_title_enhance:
|
||||
docs = zh_title_enhance(docs)
|
||||
|
|
|
|||
Loading…
Reference in New Issue