diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index 66d9d29..a454931 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -10,6 +10,8 @@ import numpy as np
 from utils import torch_gc
 from tqdm import tqdm
 from pypinyin import lazy_pinyin
+from loader import UnstructuredPaddleImageLoader
+from loader import UnstructuredPaddlePDFLoader
 
 DEVICE_ = EMBEDDING_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None
@@ -21,16 +23,31 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE):
         loader = UnstructuredFileLoader(filepath, mode="elements")
         docs = loader.load()
     elif filepath.lower().endswith(".pdf"):
-        loader = UnstructuredFileLoader(filepath, strategy="fast")
+        loader = UnstructuredPaddlePDFLoader(filepath)
         textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
         docs = loader.load_and_split(textsplitter)
+    elif filepath.lower().endswith(".jpg") or filepath.lower().endswith(".png"):
+        loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
+        textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
+        docs = loader.load_and_split(text_splitter=textsplitter)
     else:
         loader = UnstructuredFileLoader(filepath, mode="elements")
         textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
         docs = loader.load_and_split(text_splitter=textsplitter)
+    write_check_file(filepath, docs)
     return docs
 
 
+def write_check_file(filepath, docs):
+    fout = open('load_file.txt', 'a')
+    fout.write("filepath=%s,len=%s" % (filepath, len(docs)))
+    fout.write('\n')
+    for i in docs:
+        fout.write(str(i))
+        fout.write('\n')
+    fout.close()
+
+
 def generate_prompt(related_docs: List[str], query: str,
                     prompt_template=PROMPT_TEMPLATE) -> str:
     context = "\n".join([doc.page_content for doc in related_docs])
@@ -212,7 +229,7 @@ class LocalDocQA:
         if not vs_path or not one_title or not one_conent:
             logger.info("知识库添加错误,请确认知识库名字、标题、内容是否正确!")
             return None, [one_title]
-        docs = [Document(page_content=one_conent+"\n", metadata={"source": one_title})]
+        docs = [Document(page_content=one_conent + "\n", metadata={"source": one_title})]
         if not one_content_segmentation:
             text_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
             docs = text_splitter.split_documents(docs)
diff --git a/docs/test.pdf b/docs/test.pdf
new file mode 100644
index 0000000..3a137ad
Binary files /dev/null and b/docs/test.pdf differ
diff --git a/img/test.jpg b/img/test.jpg
new file mode 100644
index 0000000..70c199b
Binary files /dev/null and b/img/test.jpg differ
diff --git a/loader/__init__.py b/loader/__init__.py
new file mode 100644
index 0000000..e9a7aea
--- /dev/null
+++ b/loader/__init__.py
@@ -0,0 +1,2 @@
+from .image_loader import UnstructuredPaddleImageLoader
+from .pdf_loader import UnstructuredPaddlePDFLoader
diff --git a/loader/image_loader.py b/loader/image_loader.py
new file mode 100644
index 0000000..5a1552a
--- /dev/null
+++ b/loader/image_loader.py
@@ -0,0 +1,28 @@
+"""Loader that loads image files."""
+from typing import List
+
+from langchain.document_loaders.unstructured import UnstructuredFileLoader
+from paddleocr import PaddleOCR
+import os
+
+
+class UnstructuredPaddleImageLoader(UnstructuredFileLoader):
+    """Loader that uses unstructured to load image files, such as PNGs and JPGs."""
+
+    def _get_elements(self) -> List:
+        def image_ocr_txt(filepath, dir_path="tmp_files"):
+            if not os.path.exists(dir_path):
+                os.makedirs(dir_path)
+            filename = os.path.split(filepath)[-1]
+            ocr = PaddleOCR(lang="ch", use_gpu=False, show_log=False)
+            result = ocr.ocr(img=filepath)
+
+            ocr_result = [i[1][0] for line in result for i in line]
+            txt_file_path = os.path.join(dir_path, "%s.txt" % (filename))
+            with open(txt_file_path, 'w', encoding='utf-8') as fout:
+                fout.write("\n".join(ocr_result))
+            return txt_file_path
+
+        txt_file_path = image_ocr_txt(self.file_path)
+        from unstructured.partition.text import partition_text
+        return partition_text(filename=txt_file_path, **self.unstructured_kwargs)
diff --git a/loader/pdf_loader.py b/loader/pdf_loader.py
new file mode 100644
index 0000000..cc623d9
--- /dev/null
+++ b/loader/pdf_loader.py
@@ -0,0 +1,44 @@
+"""Loader that loads PDF files."""
+from typing import List
+
+from langchain.document_loaders.unstructured import UnstructuredFileLoader
+from paddleocr import PaddleOCR
+import os
+import fitz
+
+
+class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):
+    """Loader that uses unstructured and PaddleOCR to load PDF files."""
+
+    def _get_elements(self) -> List:
+        def pdf_ocr_txt(filepath, dir_path="tmp_files"):
+            if not os.path.exists(dir_path):
+                os.makedirs(dir_path)
+            filename = os.path.split(filepath)[-1]
+            ocr = PaddleOCR(lang="ch", use_gpu=False, show_log=False)
+            doc = fitz.open(filepath)
+            txt_file_path = os.path.join(dir_path, "%s.txt" % (filename))
+            img_name = './img/.tmp.png'
+            with open(txt_file_path, 'w', encoding='utf-8') as fout:
+
+                for i in range(doc.page_count):
+                    page = doc[i]
+                    text = page.get_text("")
+                    fout.write(text)
+                    fout.write("\n")
+
+                    img_list = page.get_images()
+                    for img in img_list:
+                        pix = fitz.Pixmap(doc, img[0])
+
+                        pix.save(img_name)
+
+                        result = ocr.ocr(img_name)
+                        ocr_result = [i[1][0] for line in result for i in line]
+                        fout.write("\n".join(ocr_result))
+            os.remove(img_name)
+            return txt_file_path
+
+        txt_file_path = pdf_ocr_txt(self.file_path)
+        from unstructured.partition.text import partition_text
+        return partition_text(filename=txt_file_path, **self.unstructured_kwargs)
diff --git a/requirements.txt b/requirements.txt
index d7b2e4e..31d44e0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,6 @@
+pymupdf
+paddlepaddle==2.4.2
+paddleocr
 langchain==0.0.146
 transformers==4.27.1
 unstructured[local-inference]
diff --git a/test_image.py b/test_image.py
new file mode 100644
index 0000000..ed60890
--- /dev/null
+++ b/test_image.py
@@ -0,0 +1,12 @@
+from configs.model_config import *
+import nltk
+
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
+
+filepath = "./img/test.jpg"
+from loader import UnstructuredPaddleImageLoader
+
+loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
+docs = loader.load()
+for doc in docs:
+    print(doc)
diff --git a/test_pdf.py b/test_pdf.py
new file mode 100644
index 0000000..32dcb34
--- /dev/null
+++ b/test_pdf.py
@@ -0,0 +1,12 @@
+from configs.model_config import *
+import nltk
+
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
+
+filepath = "docs/test.pdf"
+from loader import UnstructuredPaddlePDFLoader
+
+loader = UnstructuredPaddlePDFLoader(filepath, mode="elements")
+docs = loader.load()
+for doc in docs:
+    print(doc)
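
Reviewer note (not part of the diff): test_image.py and test_pdf.py drive the new loaders directly; a minimal sketch for exercising the change end to end is to go through load_file() instead, so the extension-based dispatch and the write_check_file() dump are covered as well. It assumes the sample files added above and the same configs/NLTK setup used by the test scripts.

from configs.model_config import *
import nltk

nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

from chains.local_doc_qa import load_file

# ".pdf" should route to UnstructuredPaddlePDFLoader and ".jpg"/".png" to
# UnstructuredPaddleImageLoader; each call also appends a summary of the
# split documents to load_file.txt via write_check_file().
for path in ["docs/test.pdf", "./img/test.jpg"]:
    docs = load_file(path)
    print(path, len(docs))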