From e8a37ff4c7d35e1a4e84b3e13b4388a3fdb41975 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Sat, 20 May 2023 01:24:35 +0800 Subject: [PATCH] update loader.py --- chains/local_doc_qa.py | 15 +++++++++------ loader/pdf_loader.py | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index 5ba2f7e..9514978 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -26,6 +26,9 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE): if filepath.lower().endswith(".md"): loader = UnstructuredFileLoader(filepath, mode="elements") docs = loader.load() + elif filepath.lower().endswith(".txt"): + loader = UnstructuredFileLoader(filepath, mode="elements") + docs = loader.load() elif filepath.lower().endswith(".pdf"): loader = UnstructuredPaddlePDFLoader(filepath) textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size) @@ -47,13 +50,13 @@ def write_check_file(filepath, docs): if not os.path.exists(folder_path): os.makedirs(folder_path) fp = os.path.join(folder_path, 'load_file.txt') - fout = open(fp, 'a') - fout.write("filepath=%s,len=%s" % (filepath, len(docs))) - fout.write('\n') - for i in docs: - fout.write(str(i)) + with open(fp, 'a+', encoding='utf-8') as fout: + fout.write("filepath=%s,len=%s" % (filepath, len(docs))) fout.write('\n') - fout.close() + for i in docs: + fout.write(str(i)) + fout.write('\n') + fout.close() def generate_prompt(related_docs: List[str], query: str, diff --git a/loader/pdf_loader.py b/loader/pdf_loader.py index cb972a9..3414121 100644 --- a/loader/pdf_loader.py +++ b/loader/pdf_loader.py @@ -19,7 +19,7 @@ class UnstructuredPaddlePDFLoader(UnstructuredFileLoader): ocr = PaddleOCR(lang="ch", use_gpu=False, show_log=False) doc = fitz.open(filepath) txt_file_path = os.path.join(full_dir_path, "%s.txt" % (filename)) - img_name = os.path.join(full_dir_path, '.tmp.png') + img_name = os.path.join(full_dir_path, 'tmp.png') with open(txt_file_path, 'w', encoding='utf-8') as fout: for i in range(doc.page_count):