diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index 25016f7..c96ede5 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -1,6 +1,6 @@ from langchain.embeddings.huggingface import HuggingFaceEmbeddings from langchain.vectorstores import FAISS -from langchain.document_loaders import UnstructuredFileLoader, TextLoader +from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLoader from configs.model_config import * import datetime from textsplitter import ChineseTextSplitter @@ -74,6 +74,9 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE): loader = UnstructuredPaddleImageLoader(filepath, mode="elements") textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size) docs = loader.load_and_split(text_splitter=textsplitter) + elif filepath.lower().endswith(".csv"): + loader = CSVLoader(filepath) + docs = loader.load() else: loader = UnstructuredFileLoader(filepath, mode="elements") textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)