diff --git a/document_loaders/__init__.py b/document_loaders/__init__.py new file mode 100644 index 0000000..a4d6b28 --- /dev/null +++ b/document_loaders/__init__.py @@ -0,0 +1,2 @@ +from .mypdfloader import RapidOCRPDFLoader +from .myimgloader import RapidOCRLoader \ No newline at end of file diff --git a/document_loaders/myimgloader.py b/document_loaders/myimgloader.py new file mode 100644 index 0000000..8648192 --- /dev/null +++ b/document_loaders/myimgloader.py @@ -0,0 +1,25 @@ +from typing import List +from langchain.document_loaders.unstructured import UnstructuredFileLoader + + +class RapidOCRLoader(UnstructuredFileLoader): + def _get_elements(self) -> List: + def img2text(filepath): + from rapidocr_onnxruntime import RapidOCR + resp = "" + ocr = RapidOCR() + result, _ = ocr(filepath) + if result: + ocr_result = [line[1] for line in result] + resp += "\n".join(ocr_result) + return resp + + text = img2text(self.file_path) + from unstructured.partition.text import partition_text + return partition_text(text=text, **self.unstructured_kwargs) + + +if __name__ == "__main__": + loader = RapidOCRLoader(file_path="../tests/samples/ocr_test.jpg") + docs = loader.load() + print(docs) diff --git a/document_loaders/mypdfloader.py b/document_loaders/mypdfloader.py new file mode 100644 index 0000000..71e063d --- /dev/null +++ b/document_loaders/mypdfloader.py @@ -0,0 +1,37 @@ +from typing import List +from langchain.document_loaders.unstructured import UnstructuredFileLoader + + +class RapidOCRPDFLoader(UnstructuredFileLoader): + def _get_elements(self) -> List: + def pdf2text(filepath): + import fitz + from rapidocr_onnxruntime import RapidOCR + import numpy as np + ocr = RapidOCR() + doc = fitz.open(filepath) + resp = "" + for page in doc: + # TODO: 依据文本与图片顺序调整处理方式 + text = page.get_text("") + resp += text + "\n" + + img_list = page.get_images() + for img in img_list: + pix = fitz.Pixmap(doc, img[0]) + img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1) + result, _ = ocr(img_array) + if result: + ocr_result = [line[1] for line in result] + resp += "\n".join(ocr_result) + return resp + + text = pdf2text(self.file_path) + from unstructured.partition.text import partition_text + return partition_text(text=text, **self.unstructured_kwargs) + + +if __name__ == "__main__": + loader = RapidOCRPDFLoader(file_path="../tests/samples/ocr_test.pdf") + docs = loader.load() + print(docs) diff --git a/requirements.txt b/requirements.txt index e40f665..4271f3a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,8 @@ SQLAlchemy==2.0.19 faiss-cpu accelerate spacy +PyMuPDF==1.22.5 +rapidocr_onnxruntime>=1.3.1 # uncomment libs if you want to use corresponding vector store # pymilvus==2.1.3 # requires milvus==2.1.3 diff --git a/requirements_api.txt b/requirements_api.txt index 58dbc0c..bdecf3c 100644 --- a/requirements_api.txt +++ b/requirements_api.txt @@ -16,6 +16,8 @@ faiss-cpu nltk accelerate spacy +PyMuPDF==1.22.5 +rapidocr_onnxruntime>=1.3.1 # uncomment libs if you want to use corresponding vector store # pymilvus==2.1.3 # requires milvus==2.1.3 diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 8cab754..8582c9c 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -87,7 +87,8 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'], "UnstructuredMarkdownLoader": ['.md'], "CustomJSONLoader": [".json"], "CSVLoader": [".csv"], - "PyPDFLoader": [".pdf"], + "RapidOCRPDFLoader": [".pdf"], + "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'], "UnstructuredFileLoader": ['.eml', '.msg', '.rst', '.rtf', '.txt', '.xml', '.doc', '.docx', '.epub', '.odt', @@ -196,7 +197,10 @@ class KnowledgeFile: print(f"{self.document_loader_name} used for {self.filepath}") try: - document_loaders_module = importlib.import_module('langchain.document_loaders') + if self.document_loader_name in []: + document_loaders_module = importlib.import_module('document_loaders') + else: + document_loaders_module = importlib.import_module('langchain.document_loaders') DocumentLoader = getattr(document_loaders_module, self.document_loader_name) except Exception as e: print(e) diff --git a/tests/samples/ocr_test.jpg b/tests/samples/ocr_test.jpg new file mode 100644 index 0000000..70c199b Binary files /dev/null and b/tests/samples/ocr_test.jpg differ diff --git a/tests/samples/ocr_test.pdf b/tests/samples/ocr_test.pdf new file mode 100644 index 0000000..3a137ad Binary files /dev/null and b/tests/samples/ocr_test.pdf differ