update local_doc_qa.py
This commit is contained in:
parent
6d1523728b
commit
2681728329
|
|
@@ -1,6 +1,6 @@
|
|||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from langchain.vectorstores import FAISS
|
||||
from langchain.document_loaders import UnstructuredFileLoader
|
||||
from langchain.document_loaders import UnstructuredFileLoader, TextLoader
|
||||
from configs.model_config import *
|
||||
import datetime
|
||||
from textsplitter import ChineseTextSplitter
|
||||
|
|
@@ -10,8 +10,7 @@ import numpy as np
|
|||
from utils import torch_gc
|
||||
from tqdm import tqdm
|
||||
from pypinyin import lazy_pinyin
|
||||
from loader import UnstructuredPaddleImageLoader
|
||||
from loader import UnstructuredPaddlePDFLoader
|
||||
from loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader
|
||||
from models.base import (BaseAnswer,
|
||||
AnswerResult,
|
||||
AnswerResultStream,
|
||||
|
|
@@ -21,14 +20,14 @@ from models.loader import LoaderCheckPoint
|
|||
import models.shared as shared
|
||||
|
||||
|
||||
|
||||
def load_file(filepath, sentence_size=SENTENCE_SIZE):
|
||||
if filepath.lower().endswith(".md"):
|
||||
loader = UnstructuredFileLoader(filepath, mode="elements")
|
||||
docs = loader.load()
|
||||
elif filepath.lower().endswith(".txt"):
|
||||
loader = UnstructuredFileLoader(filepath, mode="elements")
|
||||
docs = loader.load()
|
||||
loader = TextLoader(filepath, autodetect_encoding=True)
|
||||
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
|
||||
docs = loader.load_and_split(textsplitter)
|
||||
elif filepath.lower().endswith(".pdf"):
|
||||
loader = UnstructuredPaddlePDFLoader(filepath)
|
||||
textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
|
||||
|
|
|
|||
Loading…
Reference in New Issue