update textsplitter and pdfloader
This commit is contained in:
parent
222689ed5b
commit
02b9d57072
|
|
@ -1,6 +1,6 @@
|
||||||
import os
|
import os
|
||||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||||
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config)
|
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, CHUNK_SIZE, OVERLAP_SIZE)
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
import sys
|
import sys
|
||||||
from text_splitter import zh_title_enhance
|
from text_splitter import zh_title_enhance
|
||||||
|
|
@ -45,6 +45,8 @@ LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg
|
||||||
'.doc', '.docx', '.epub', '.odt', '.pdf',
|
'.doc', '.docx', '.epub', '.odt', '.pdf',
|
||||||
'.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv'
|
'.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv'
|
||||||
"CSVLoader": [".csv"],
|
"CSVLoader": [".csv"],
|
||||||
|
"PyPDFLoader": [".pdf"],
|
||||||
|
|
||||||
}
|
}
|
||||||
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
|
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
|
||||||
|
|
||||||
|
|
@ -70,15 +72,34 @@ class KnowledgeFile:
|
||||||
self.document_loader_name = get_LoaderClass(self.ext)
|
self.document_loader_name = get_LoaderClass(self.ext)
|
||||||
|
|
||||||
# TODO: 增加依据文件格式匹配text_splitter
|
# TODO: 增加依据文件格式匹配text_splitter
|
||||||
self.text_splitter_name = "CharacterTextSplitter"
|
self.text_splitter_name = None
|
||||||
|
|
||||||
def file2text(self, using_zh_title_enhance):
|
def file2text(self, using_zh_title_enhance):
|
||||||
DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name)
|
DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name)
|
||||||
loader = DocumentLoader(self.filepath)
|
loader = DocumentLoader(self.filepath)
|
||||||
|
|
||||||
# TODO: 增加依据文件格式匹配text_splitter
|
# TODO: 增加依据文件格式匹配text_splitter
|
||||||
|
try:
|
||||||
|
if self.text_splitter_name is None:
|
||||||
|
TextSplitter = getattr(sys.modules['langchain.text_splitter'], "SpacyTextSplitter")
|
||||||
|
text_splitter = TextSplitter(
|
||||||
|
pipeline="zh_core_web_sm",
|
||||||
|
chunk_size=CHUNK_SIZE,
|
||||||
|
chunk_overlap=OVERLAP_SIZE,
|
||||||
|
)
|
||||||
|
else:
|
||||||
TextSplitter = getattr(sys.modules['langchain.text_splitter'], self.text_splitter_name)
|
TextSplitter = getattr(sys.modules['langchain.text_splitter'], self.text_splitter_name)
|
||||||
text_splitter = TextSplitter(chunk_size=250, chunk_overlap=200)
|
text_splitter = TextSplitter(
|
||||||
|
chunk_size=CHUNK_SIZE,
|
||||||
|
chunk_overlap=50)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
TextSplitter = getattr(sys.modules['langchain.text_splitter'], "RecursiveCharacterTextSplitter")
|
||||||
|
text_splitter = TextSplitter(
|
||||||
|
chunk_size=CHUNK_SIZE,
|
||||||
|
chunk_overlap=OVERLAP_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
docs = loader.load_and_split(text_splitter)
|
docs = loader.load_and_split(text_splitter)
|
||||||
if using_zh_title_enhance:
|
if using_zh_title_enhance:
|
||||||
docs = zh_title_enhance(docs)
|
docs = zh_title_enhance(docs)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue