diff --git a/README.md b/README.md new file mode 100644 index 0000000..6d3e648 --- /dev/null +++ b/README.md @@ -0,0 +1,160 @@ +# PDF文本提取工具py39 + +安装:python安装好后,可以直接pip install --no-index --find-links=./packages -r requirements.txt 再进行 pip install --no-index --find-links=./packages -r requirements2.txt 安装后忽略报错。环境配置成功 + + +## 项目概述 +这是一个基于Python的PDF文本提取工具,支持批量处理PDF文件,自动识别文档中的特定区域(如姓名、金额等),并进行OCR文字识别。该工具使用本地模型进行处理,无需联网,可方便地集成到桌面应用程序中,提供高效的文档处理能力。 + +## 文件结构 + +``` +项目根目录/ +├── pdf_to_table.py # 主程序 +├── requirements.txt # 依赖列表 +├── requirements2.txt # 依赖列表 +├── model/ # 模型文件夹(必须) +│ ├── layout/ # 版面检测模型 +│ └── best.pt # YOLO分类模型(固定名称) +│ └── PP-OCRv4_mobile_det # OCR检测模型 +│ └── PP-OCRv4_mobile_rec # OCR识别模型 +├── packages/ # 依赖包文件夹 +├── ocr.yaml # OCR配置文件(必须) +├── poppler/ # Poppler工具(可选,推荐) +│ └── bin/ # Poppler可执行文件 +└── output/ # 默认输出目录,初始没有,运行后会自动创建 +``` + +## 安装说明 + +### 1. Poppler安装 +Poppler用于PDF到图像的转换,程序会自动检测多个位置的Poppler安装: + +**方法1:放置在程序目录(推荐)** +1. 下载Poppler for Windows: https://github.com/oschwartz10612/poppler-windows/releases/ +2. 解压到程序根目录下的`poppler`文件夹(或`poppler/bin`子目录) + +**方法2:系统安装** +- Windows: 安装到任意位置,程序会自动检测常见路径(C:/poppler/bin、Program Files等) +- Linux/Mac: 使用包管理器安装(apt-get、brew等) + +**方法3:环境变量** +- 设置环境变量`POPPLER_PATH`指向Poppler的bin目录 + +程序会按以上优先级自动检测Poppler位置,无需手动配置。 + +### 2. 模型准备 +1. 确保版面检测模型位于`model/layout/`目录下 +2. 确保YOLO分类模型文件命名为`best.pt`并位于`model/`目录下 +3. 确保OCR配置文件`ocr.yaml`位于项目根目录下 + +## 使用说明 + +### 程序集成 + +```python +from pdf_to_table import process_pdf_to_table + +# 处理PDF文件夹 +stats = process_pdf_to_table( + pdf_folder="PDF文件夹路径", + output_folder="输出文件夹路径", + confidence_threshold=0.6 # 可选,默认0.6 +) + +# 处理结果统计 +print(f"目标图片数: {stats['target_images']}") +print(f"成功提取区域的图片数: {stats['successful_extractions']}") +print(f"name区域数: {stats['total_name_regions']}, money区域数: {stats['total_money_regions']}") +``` + +### 独立运行 +也可以直接运行脚本,使用默认参数: + +```bash +python pdf_to_table.py +``` + +默认会处理`input_pdfs`文件夹中的PDF文件,并将结果保存到`output`文件夹。 + +### 错误处理与日志 +```python +import logging + +# 配置日志 +logging.basicConfig(filename='pdf_process.log', level=logging.INFO) + +try: + stats = process_pdf_to_table(...) + # 处理完成后的操作 +except Exception as e: + logging.error(f"处理失败: {str(e)}") + # 错误处理 +``` + +## 输出结构 +处理完成后,输出文件夹结构如下: +``` +输出文件夹/ +├── images/ # 转换后的图片 +├── results/ # 处理结果 +│ ├── classification_results.csv # 分类结果 +│ ├── processing_results.csv # 处理结果 +│ └── [图片序号]/ # 每个图片的处理结果 +│ ├── [图片名]_all_regions.jpg # 区域可视化 +│ ├── [图片名]_name_[序号].jpg # 姓名区域图片 +│ ├── [图片名]_money_[序号].jpg # 金额区域图片 +│ ├── [图片名]_name_[序号]_texts.txt # 姓名区域文本 +│ ├── [图片名]_money_[序号]_texts.txt # 金额区域文本 +│ └── [图片名]_regions.txt # 合并后的文本结果 +└── rotated_images/ # 旋转后的图片 +``` + +## 注意事项 + +1. 文件路径与结构 + - 模型文件名称和位置固定(best.pt、layout目录) + - Poppler会自动检测,无需手动指定路径 + - 输入和输出路径可以使用相对或绝对路径 + - 打包应用时需要包含所有模型文件和配置文件 + +2. 本地模型使用 + - 程序使用本地模型,不需要网络连接 + - 确保模型文件完整且未损坏 + - 版面检测模型必须位于`model/layout/`目录 + - OCR配置文件必须位于根目录 + +3. 性能考虑 + - 处理大量PDF文件时可能需要较长时间 + - 建议在单独线程中运行处理任务 + - 显示进度条以提高用户体验 + +4. 集成建议 + - 将模型文件和配置文件打包到应用程序中 + - 处理前检查文件和模型是否存在 + - 使用try/except捕获异常并记录日志 + +## 常见问题 + +1. PDF转换失败 + - 确认Poppler正确安装(程序会自动检测多个位置) + - 检查PDF文件是否有效 + - 查看日志了解详细错误信息 + +2. 模型加载失败 + - 检查model目录下是否包含完整的模型文件 + - 确认模型文件名称为best.pt + - 确认ocr.yaml配置文件存在 + +3. 识别质量问题 + - 调整confidence_threshold值(默认0.6) + - 检查图片质量,必要时提高DPI + - 检查区域裁剪参数 + +4. 旋转处理问题 + - 确保YOLO模型支持的分类类别: no, yes, yes_90, yes_-90 + - yes_90表示需逆时针旋转90度,yes_-90表示需顺时针旋转90度 + +--- + +该工具支持四种分类:`no`(非目标文档), `yes`(正向文档), `yes_90`(需逆时针旋转90度), `yes_-90`(需顺时针旋转90度),并支持三种版面区域类型:`title`, `name`, `money`(仅提取name和money区域)。 \ No newline at end of file diff --git a/pdf_to_table.py b/pdf_to_table.py new file mode 100644 index 0000000..06c8bc7 --- /dev/null +++ b/pdf_to_table.py @@ -0,0 +1,985 @@ +import os +import warnings +import json +import logging +from tqdm import tqdm +from pdf2image import convert_from_path +from ultralytics import YOLO +import pandas as pd +from datetime import datetime +import paddlex as pdx +import cv2 +import numpy as np +import sys +from pathlib import Path +import requests +import zipfile +import shutil +import tempfile +import platform +import re +import hashlib +import pickle + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) +logger = logging.getLogger(__name__) + +# 减少PaddlePaddle的日志输出 +os.environ['FLAGS_call_stack_level'] = '2' + +# 忽略警告信息 +warnings.filterwarnings('ignore') + +# 全局变量用于存储已加载的模型 +_loaded_layout_model = None +_loaded_ocr_pipeline = None + +# 缓存目录 +CACHE_DIR = os.environ.get('OCR_CACHE_DIR', os.path.join(tempfile.gettempdir(), 'ocr_cache')) +# 是否启用缓存 +USE_CACHE = os.environ.get('OCR_USE_CACHE', '1') == '1' + +def get_cache_path(key): + """获取缓存文件路径""" + if not os.path.exists(CACHE_DIR): + os.makedirs(CACHE_DIR, exist_ok=True) + + # 使用MD5生成缓存文件名 + hash_obj = hashlib.md5(key.encode('utf-8')) + return os.path.join(CACHE_DIR, f"{hash_obj.hexdigest()}.cache") + +def get_from_cache(key): + """从缓存获取数据""" + if not USE_CACHE: + return None + + cache_path = get_cache_path(key) + if os.path.exists(cache_path): + try: + with open(cache_path, 'rb') as f: + return pickle.load(f) + except Exception as e: + logger.error(f"读取缓存失败: {str(e)}") + + return None + +def save_to_cache(key, data): + """保存数据到缓存""" + if not USE_CACHE: + return + + cache_path = get_cache_path(key) + try: + with open(cache_path, 'wb') as f: + pickle.dump(data, f) + except Exception as e: + logger.error(f"保存缓存失败: {str(e)}") + +def resource_path(relative_path): + """获取资源绝对路径,支持开发环境和PyInstaller打包后的环境""" + try: + # PyInstaller creates a temp folder and stores path in _MEIPASS + base_path = getattr(sys, '_MEIPASS', os.path.dirname(os.path.abspath(__file__))) + return os.path.join(base_path, relative_path) + except Exception as e: + logger.error(f"获取资源路径时出错: {str(e)}") + return os.path.join(os.path.dirname(os.path.abspath(__file__)), relative_path) + +def install_poppler_from_local(local_zip_path, install_dir=None): + """从本地ZIP文件安装Poppler""" + try: + if not os.path.exists(local_zip_path): + logger.error(f"本地Poppler ZIP文件不存在: {local_zip_path}") + return None + + # 如果没有指定安装目录,使用默认目录 + if install_dir is None: + install_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "poppler") + + logger.info(f"开始从本地文件安装Poppler到: {install_dir}") + + # 创建安装目录 + if not os.path.exists(install_dir): + os.makedirs(install_dir) + + # 创建临时目录 + temp_dir = tempfile.mkdtemp() + + try: + # 解压文件 + logger.info("开始解压文件...") + with zipfile.ZipFile(local_zip_path, 'r') as zip_ref: + zip_ref.extractall(temp_dir) + + # 移动文件到安装目录 + extracted_dir = os.path.join(temp_dir, os.path.splitext(os.path.basename(local_zip_path))[0]) + if not os.path.exists(extracted_dir): + # 尝试查找解压后的目录 + for item in os.listdir(temp_dir): + if item.startswith("poppler-"): + extracted_dir = os.path.join(temp_dir, item) + break + + if os.path.exists(extracted_dir): + for item in os.listdir(extracted_dir): + src = os.path.join(extracted_dir, item) + dst = os.path.join(install_dir, item) + if os.path.isdir(src): + shutil.copytree(src, dst, dirs_exist_ok=True) + else: + shutil.copy2(src, dst) + + logger.info(f"Poppler已安装到: {install_dir}") + return install_dir + else: + logger.error("无法找到解压后的Poppler目录") + return None + + finally: + # 清理临时文件 + shutil.rmtree(temp_dir) + + except Exception as e: + logger.error(f"从本地安装Poppler时出错: {str(e)}") + return None + +def setup_poppler_path(): + """配置Poppler路径""" + # 检查是否已经设置了POPPLER_PATH环境变量 + if 'POPPLER_PATH' in os.environ: + logger.info(f"使用已配置的Poppler路径: {os.environ['POPPLER_PATH']}") + return + + # 尝试常见安装路径 + possible_paths = [ + r"C:\Program Files\poppler\Library\bin", + r"C:\Program Files (x86)\poppler\Library\bin", + r"C:\poppler\Library\bin", + r"D:\Program Files\poppler\Library\bin", + r"D:\poppler\Library\bin", + os.path.join(os.path.dirname(os.path.abspath(__file__)), "poppler", "Library", "bin") + ] + + for path in possible_paths: + if os.path.exists(path): + os.environ["POPPLER_PATH"] = path + logger.info(f"找到并设置Poppler路径: {path}") + return + + # 尝试从本地安装 + local_zip_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "poppler.zip") + if os.path.exists(local_zip_path): + logger.info("找到本地Poppler安装包,开始安装...") + install_dir = install_poppler_from_local(local_zip_path) + + if install_dir: + bin_path = os.path.join(install_dir, "Library", "bin") + os.environ["POPPLER_PATH"] = bin_path + + # 添加到系统PATH + if bin_path not in os.environ.get("PATH", ""): + os.environ["PATH"] = bin_path + os.pathsep + os.environ.get("PATH", "") + + logger.info(f"Poppler已安装并配置到: {bin_path}") + return + + logger.warning("未找到Poppler路径,PDF转换可能失败。请确保安装了Poppler并设置正确的路径。") + logger.warning("警告: 未找到Poppler,这可能导致PDF转换失败") + +# 尝试设置Poppler路径 +setup_poppler_path() + +# 获取程序所在目录 +if getattr(sys, 'frozen', False): + # 打包环境 + BASE_DIR = getattr(sys, '_MEIPASS', os.path.dirname(sys.executable)) +else: + # 开发环境 + BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + +# 模型路径配置 +MODEL_DIR = resource_path(os.path.join(BASE_DIR, "model")) +YOLO_MODEL_PATH = resource_path(os.path.join(MODEL_DIR, "best.pt")) +LAYOUT_MODEL_DIR = resource_path(os.path.join(MODEL_DIR, "layout")) +OCR_CONFIG_PATH = resource_path(os.path.join(BASE_DIR, "ocr.yaml")) + +def pdf_to_images_bulk(input_folder, output_folder, dpi=300): + """ + 使用 pdf2image 高质量转换 PDF,每张图片按顺序编号 + + Args: + input_folder: PDF文件所在文件夹路径 + output_folder: 转换后图片保存的文件夹路径 + dpi: 分辨率,默认300 + + Returns: + 转换后的图片路径和对应的PDF来源信息列表 + """ + os.makedirs(output_folder, exist_ok=True) + pdf_files = [f for f in os.listdir(input_folder) if f.lower().endswith('.pdf')] + image_index = 1 # 统一编号 + image_info = [] # 存储图片信息 + + logger.info(f"开始处理 {len(pdf_files)} 个PDF文件...") + + for pdf_file in tqdm(sorted(pdf_files), desc="PDF转图片"): + pdf_path = os.path.join(input_folder, pdf_file) + + try: + # 检查环境变量中是否有Poppler路径 + poppler_path = os.environ.get("POPPLER_PATH") + if poppler_path and os.path.exists(poppler_path): + # 使用设置的Poppler路径 + logger.info(f"使用Poppler路径: {poppler_path}") + logger.info(f"处理PDF文件: {pdf_path}") + + try: + images = convert_from_path(pdf_path, dpi=dpi, poppler_path=poppler_path) + logger.info(f"成功转换PDF,共 {len(images)} 页") + except Exception as e: + logger.error(f"使用指定Poppler路径转换失败: {str(e)}") + # 尝试不指定路径 + logger.info("尝试不指定Poppler路径进行转换") + images = convert_from_path(pdf_path, dpi=dpi) + else: + # 尝试使用默认路径 + logger.warning("未找到Poppler路径,尝试使用系统默认路径") + images = convert_from_path(pdf_path, dpi=dpi) + logger.info("使用系统默认Poppler路径成功") + + for i, img in enumerate(images): + img_path = os.path.join(output_folder, f'{image_index}.png') + img.save(img_path, 'PNG') + + # 记录图片信息 + image_info.append({ + 'image_path': img_path, + 'pdf_file': pdf_file, + 'page_number': i+1, + 'image_index': image_index + }) + + image_index += 1 + except Exception as e: + logger.error(f"处理 {pdf_file} 时出错: {str(e)}") + # 打印更详细的错误信息 + import traceback + logger.error(f"详细错误: {traceback.format_exc()}") + + logger.info(f"PDF转图片完成,共转换 {len(image_info)} 张图片") + return image_info + +def classify_images(image_folder, confidence_threshold=0.6): + """ + 使用YOLO模型对图片进行分类 + + Args: + image_folder: 图片所在文件夹路径 + confidence_threshold: 置信度阈值 + + Returns: + 分类结果列表 + """ + # 检查缓存 + cache_key = f"classify_{image_folder}_{confidence_threshold}" + cached_result = get_from_cache(cache_key) + if cached_result is not None: + logger.info(f"使用缓存的分类结果,共 {len(cached_result)} 项") + return cached_result + + logger.info(f"加载YOLO模型: {YOLO_MODEL_PATH}") + model = YOLO(YOLO_MODEL_PATH) + + # 确保使用CPU推理 + model.to('cpu') + + image_files = [f for f in os.listdir(image_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))] + results = [] + + logger.info(f"开始分类 {len(image_files)} 张图片...") + + for image_file in tqdm(sorted(image_files, key=lambda x: int(os.path.splitext(x)[0])), desc="图片分类"): + image_path = os.path.join(image_folder, image_file) + try: + result = model.predict(image_path, verbose=False)[0] + + class_id = int(result.probs.top1) + class_name = result.names[class_id] + confidence = float(result.probs.top1conf) + + # 根据类别确定旋转角度和是否为目标图片 + rotation_angle = 0 + if class_name == 'yes_90': + rotation_angle = -90 # 逆时针旋转90度 + elif class_name == 'yes_-90': + rotation_angle = 90 # 顺时针旋转90度 + + # 确定是否为目标图片(现在有三种yes类型都是目标) + is_target = (class_name in ['yes', 'yes_90', 'yes_-90']) and confidence >= confidence_threshold + + results.append({ + 'image_path': image_path, + 'image_name': image_file, + 'class_id': class_id, + 'class_name': class_name, + 'confidence': confidence, + 'rotation_angle': rotation_angle, + 'is_target': is_target + }) + + except Exception as e: + logger.error(f"处理图片 {image_file} 时出错: {str(e)}") + continue + + logger.info(f"图片分类完成,共分类 {len(results)} 张图片") + + # 保存到缓存 + save_to_cache(cache_key, results) + + return results + +def rotate_image(image_path, rotation_angle, output_folder=None): + """ + 根据指定角度旋转图片 + + Args: + image_path: 原始图片路径 + rotation_angle: 旋转角度(正值表示顺时针,负值表示逆时针) + output_folder: 输出文件夹,如果为None则使用原图所在文件夹 + + Returns: + 旋转后的图片路径 + """ + # 如果不需要旋转,直接返回原路径 + if rotation_angle == 0: + return image_path + + # 确定输出路径 + image_name = os.path.basename(image_path) + base_name, ext = os.path.splitext(image_name) + + if output_folder is None: + output_folder = os.path.dirname(image_path) + os.makedirs(output_folder, exist_ok=True) + + # 构建旋转后图片的文件名 + rotated_image_path = os.path.join(output_folder, f"{base_name}_rotated{ext}") + + # 读取图片 + img = cv2.imread(image_path) + if img is None: + logger.error(f"无法读取图片: {image_path}") + return image_path + + # 获取图片尺寸 + height, width = img.shape[:2] + + # 计算旋转矩阵 + if rotation_angle == 90 or rotation_angle == -90: + # 对于90度的旋转,交换宽高 + if rotation_angle == 90: # 顺时针90度 + rotated_img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) + else: # 逆时针90度 + rotated_img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE) + else: + # 其他角度的旋转 + center = (width // 2, height // 2) + rotation_matrix = cv2.getRotationMatrix2D(center, rotation_angle, 1.0) + + # 计算旋转后的图像大小 + abs_cos = abs(rotation_matrix[0, 0]) + abs_sin = abs(rotation_matrix[0, 1]) + bound_w = int(height * abs_sin + width * abs_cos) + bound_h = int(height * abs_cos + width * abs_sin) + + # 调整旋转矩阵 + rotation_matrix[0, 2] += bound_w / 2 - center[0] + rotation_matrix[1, 2] += bound_h / 2 - center[1] + + # 执行旋转 + rotated_img = cv2.warpAffine(img, rotation_matrix, (bound_w, bound_h), + flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT, + borderValue=(255, 255, 255)) + + # 保存旋转后的图片 + cv2.imwrite(rotated_image_path, rotated_img) + logger.info(f"图片已旋转 {rotation_angle} 度并保存至: {rotated_image_path}") + + return rotated_image_path + +def detect_layout_and_ocr(image_path, layout_module, ocr_pipeline, output_dir=None): + """ + 对图片进行版面检测和OCR识别 + + Args: + image_path: 图片路径 + layout_module: 已加载的版面检测模型 + ocr_pipeline: 已加载的OCR产线 + output_dir: 输出目录,如果为None则自动创建 + + Returns: + 所有识别到的目标区域文本列表 + """ + # 检查缓存 + cache_key = f"layout_ocr_{image_path}" + cached_result = get_from_cache(cache_key) + if cached_result is not None and os.path.exists(cached_result.get('output_dir', '')): + logger.info(f"使用缓存的OCR结果: {os.path.basename(image_path)}") + return cached_result['region_texts'] + + # 创建输出目录 + if output_dir is None: + output_dir = os.path.join("results", "layout_ocr", os.path.splitext(os.path.basename(image_path))[0]) + os.makedirs(output_dir, exist_ok=True) + + # 获取图片文件名(不含扩展名) + img_name = os.path.splitext(os.path.basename(image_path))[0] + + # 执行版面检测 + results_generator = layout_module.predict(image_path, layout_nms=True) + + # 使用字典存储不同类别的区域 + all_regions = {'name': [], 'money': [], 'title': []} + target_regions = [] # 存储最终需要处理的区域(name和money) + result_count = 0 + + # 为每个结果创建目录 + for result_count, result in enumerate(results_generator, 1): + # 保存结果为JSON + json_path = os.path.join(output_dir, f"{img_name}_result_{result_count}.json") + result.save_to_json(json_path) + + # 从JSON读取检测结果 + try: + with open(json_path, 'r', encoding='utf-8') as f: + result_data = json.load(f) + + # 从result_data中提取目标区域 + # 直接从JSON根级别获取boxes字段 + if 'boxes' in result_data: + boxes = result_data['boxes'] + + # 遍历所有检测到的区域 + for box in boxes: + label = box.get('label') + score = box.get('score') + coordinate = box.get('coordinate') + + # 根据类别分类区域 + if label in all_regions: + all_regions[label].append({ + 'coordinate': coordinate, + 'score': score + }) + + # 如果是name或money区域,添加到目标区域列表 + if label in ['name', 'money']: + target_regions.append({ + 'label': label, + 'coordinate': coordinate, + 'score': score + }) + + # 记录统计 + for label, regions in all_regions.items(): + if regions: + logger.info(f"检测到 {len(regions)} 个 {label} 区域") + + except Exception as e: + logger.error(f"解析JSON结果时出错: {str(e)}") + + # 如果仍找不到目标区域,返回空列表 + if not target_regions: + logger.info(f"图片 {img_name} 未找到name或money区域") + return [] + + # 绘制和保存所有目标区域的可视化 + try: + img = cv2.imread(image_path) + if img is not None: + vis_img = img.copy() + + # 为不同类别设置不同颜色 + color_map = { + 'name': (0, 255, 0), # 绿色 + 'money': (0, 0, 255), # 红色 + 'title': (255, 0, 0) # 蓝色 + } + + # 绘制所有区域,包括title(用于参考) + for label, regions in all_regions.items(): + for idx, region_info in enumerate(regions): + coordinate = region_info['coordinate'] + x1, y1, x2, y2 = map(int, coordinate) + color = color_map.get(label, (128, 128, 128)) # 默认灰色 + cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2) + cv2.putText(vis_img, f"{label} {idx+1}", (x1, y1-10), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2) + + # 保存结果图 + vis_img_path = os.path.join(output_dir, f"{img_name}_all_regions.jpg") + cv2.imwrite(vis_img_path, vis_img) + except Exception as e: + logger.error(f"生成区域可视化时出错: {str(e)}") + + # 对每个目标区域进行OCR识别 + all_region_texts = [] # 存储所有目标区域的文本内容 + + for idx, region_info in enumerate(target_regions): + region_label = region_info['label'] + bbox = region_info['coordinate'] + logger.info(f"处理{region_label}区域 {idx+1}...") + region_texts = [] # 存储当前区域的所有文本 + + try: + # 读取原图并裁剪区域 + img = cv2.imread(image_path) + if img is None: + continue + + img_height, img_width = img.shape[:2] + + # 根据不同区域类型设置不同的裁剪边距 + # 只缩小左右距离,保持上下不变 + if region_label == 'name': + # name区域裁剪参数(缩小左右边距,保持上下不变) + left_padding = 10 # 左侧缩小像素 + right_padding = 10 # 右侧缩小像素 + top_padding = 0 # 上侧保持不变 + bottom_padding = -7 # 下侧保持不变 + elif region_label == 'money': + # money区域裁剪参数(更大程度缩小左右边距,保持上下不变) + left_padding = 20 # 左侧缩小更多像素 + right_padding = 20 # 右侧缩小更多像素 + top_padding = 0 # 上侧保持不变 + bottom_padding = 0 # 下侧保持不变 + else: + # 其他区域使用默认参数 + left_padding = -7 + right_padding = -7 + top_padding = -7 + bottom_padding = -7 + + x1, y1, x2, y2 = map(int, bbox) + + # 应用裁剪边距 + x1 = max(0, x1 + left_padding) + y1 = max(0, y1 + top_padding) + x2 = min(img_width, x2 - right_padding) + y2 = min(img_height, y2 - bottom_padding) + + region_img = img[y1:y2, x1:x2].copy() + region_img_path = os.path.join(output_dir, f"{img_name}_{region_label}_{idx+1}.jpg") + cv2.imwrite(region_img_path, region_img) + + # 执行OCR识别 + ocr_results = ocr_pipeline.predict(region_img_path) + + # 处理OCR结果 + for res_idx, ocr_res in enumerate(ocr_results, 1): + # 保存JSON结果 + json_path = os.path.join(output_dir, f"{img_name}_{region_label}_{idx+1}_ocr_{res_idx}.json") + ocr_res.save_to_json(save_path=json_path) + + # 读取JSON提取文本 + try: + with open(json_path, 'r', encoding='utf-8') as f: + json_data = json.load(f) + + # 检查是否包含rec_texts字段(顶层) + if 'rec_texts' in json_data and isinstance(json_data['rec_texts'], list): + rec_texts = json_data['rec_texts'] + region_texts.extend(rec_texts) + # 检查列表中的元素 + elif isinstance(json_data, list): + for item in json_data: + if isinstance(item, dict): + if 'rec_texts' in item: + # 如果是rec_texts数组 + if isinstance(item['rec_texts'], list): + region_texts.extend(item['rec_texts']) + else: + region_texts.append(item['rec_texts']) + elif 'text' in item: + # 兼容旧格式 + region_texts.append(item['text']) + + # 如果JSON直接为空或不包含需要的字段,尝试其他方法 + if not region_texts and hasattr(ocr_res, 'rec_texts'): + region_texts.extend(ocr_res.rec_texts) + + except Exception as e: + logger.error(f"读取OCR JSON结果时出错: {str(e)}") + + # 尝试直接从对象提取文本信息 + try: + if hasattr(ocr_res, 'rec_texts'): + region_texts.extend(ocr_res.rec_texts) + except: + logger.error("无法从对象中直接提取文本") + + # 添加区域标签和文本到总列表 + all_region_texts.append({ + 'label': region_label, + 'texts': region_texts + }) + + logger.info(f"{region_label}区域{idx+1}提取到{len(region_texts)}个文本") + + # 保存当前区域文本 + txt_path = os.path.join(output_dir, f"{img_name}_{region_label}_{idx+1}_texts.txt") + with open(txt_path, 'w', encoding='utf-8') as f: + for text in region_texts: + f.write(f"{text}\n") + + except Exception as e: + logger.error(f"处理{region_label}区域{idx+1}时出错: {str(e)}") + continue + + # 保存结果到缓存 + cache_data = { + 'region_texts': all_region_texts, + 'output_dir': output_dir + } + save_to_cache(cache_key, cache_data) + + return all_region_texts + +def merge_region_texts(all_region_texts, output_file): + """ + 合并所有区域内容为一个txt文件,采用更易于阅读的格式 + + Args: + all_region_texts: 区域文本列表,每个元素包含label和texts字段 + output_file: 输出文件路径 + """ + with open(output_file, 'w', encoding='utf-8') as f: + f.write(f"区域识别结果 - 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write("=" * 80 + "\n\n") + + # 按类别分组输出 + for category in ['name', 'money']: + category_regions = [r for r in all_region_texts if r['label'] == category] + + if category_regions: + f.write(f"{category.upper()} 区域 (共{len(category_regions)}个):\n") + f.write("-" * 50 + "\n\n") + + for region_idx, region_data in enumerate(category_regions, 1): + texts = region_data['texts'] + f.write(f"{category} {region_idx} (共{len(texts)}项):\n") + + # 写入区域内容,每个条目带编号 + for text_idx, text in enumerate(texts, 1): + f.write(f" {text_idx}. {text}\n") + + # 如果不是最后一个区域,添加分隔符 + if region_idx < len(category_regions): + f.write("\n" + "-" * 30 + "\n\n") + + # 在类别之间添加分隔符 + f.write("\n" + "=" * 50 + "\n\n") + + logger.info(f"所有区域文本已合并保存至: {output_file}") + return output_file + +def post_process_money_texts(money_texts): + """ + 处理money文本列表,合并被分割的金额 + + Args: + money_texts: 金额文本列表 + + Returns: + 处理后的金额文本列表 + """ + if not money_texts: + return money_texts + + # 创建新的结果列表 + processed_texts = [] + skip_next = False + + for i in range(len(money_texts)): + if skip_next: + skip_next = False + continue + + current_text = money_texts[i].strip() + + # 检查当前文本是否以小数点结尾或包含小数点后缺少完整的分位 + if i < len(money_texts) - 1: + next_text = money_texts[i+1].strip() + needs_merge = False + + # 情况1: 当前金额以小数点结尾 + if current_text.endswith('.'): + needs_merge = True + # 情况2: 当前金额包含小数点但后面不足两位数字 + elif '.' in current_text: + decimal_part = current_text.split('.')[-1] + if len(decimal_part) < 2: + needs_merge = True + # 情况3: 下一个值是纯数字且长度小于等于2 + is_next_decimal = next_text.isdigit() and len(next_text) <= 2 + # 情况4: 下一个值以0开头且是纯数字(可能是小数部分) + is_next_leading_zero = next_text.startswith('0') and next_text.replace('0', '').isdigit() + # 情况5: 当前文本是纯数字,下一个文本以小数点开头 + is_current_integer = current_text.isdigit() and next_text.startswith('.') + + if (needs_merge and (is_next_decimal or is_next_leading_zero)) or is_current_integer: + # 合并当前金额和下一个金额 + merged_value = current_text + next_text + processed_texts.append(merged_value) + logger.info(f"合并金额: {current_text} + {next_text} = {merged_value}") + skip_next = True + continue + + # 如果没有合并,直接添加当前金额 + processed_texts.append(current_text) + + # 进行金额格式化 + formatted_texts = [] + for money in processed_texts: + # 尝试清理和标准化金额格式 + cleaned_money = money.strip() + + # 移除非数字、小数点和逗号 + cleaned_money = re.sub(r'[^\d.,]', '', cleaned_money) + + # 检查是否包含小数点,如果没有可能需要添加 + if '.' not in cleaned_money and cleaned_money.isdigit(): + cleaned_money = cleaned_money + '.00' + logger.info(f"格式化金额: {money} -> {cleaned_money}") + # 确保小数点后有两位数字 + elif '.' in cleaned_money: + parts = cleaned_money.split('.') + if len(parts) == 2: + integer_part, decimal_part = parts + # 如果小数部分长度不足2,补零 + if len(decimal_part) < 2: + decimal_part = decimal_part.ljust(2, '0') + # 如果小数部分超过2位,截断 + elif len(decimal_part) > 2: + decimal_part = decimal_part[:2] + cleaned_money = integer_part + '.' + decimal_part + logger.info(f"标准化金额: {money} -> {cleaned_money}") + + # 尝试移除可能被误识别的文本前缀或后缀 + # 查找第一个数字的位置 + first_digit_pos = -1 + for pos, char in enumerate(cleaned_money): + if char.isdigit() or char == '.': + first_digit_pos = pos + break + + if first_digit_pos > 0: + cleaned_money = cleaned_money[first_digit_pos:] + logger.info(f"移除金额前缀: {money} -> {cleaned_money}") + + # 只保留有效的金额 + try: + float_value = float(cleaned_money.replace(',', '')) + formatted_texts.append(cleaned_money) + except ValueError: + logger.warning(f"无效金额格式,已忽略: {money}") + + return formatted_texts + +def get_layout_model(): + """ + 懒加载版面检测模型,只在首次调用时加载 + + Returns: + 加载好的版面检测模型 + """ + global _loaded_layout_model + + if _loaded_layout_model is None: + logger.info("首次加载版面检测模型") + try: + _loaded_layout_model = pdx.create_model('PP-DocLayout-L', LAYOUT_MODEL_DIR) + logger.info(f"使用本地版面检测模型: {LAYOUT_MODEL_DIR}") + except Exception as e: + logger.error(f"加载本地版面检测模型失败: {str(e)}") + raise + + return _loaded_layout_model + +def get_ocr_pipeline(): + """ + 懒加载OCR模型,只在首次调用时加载 + + Returns: + 加载好的OCR产线 + """ + global _loaded_ocr_pipeline + + if _loaded_ocr_pipeline is None: + logger.info("首次加载OCR产线") + try: + _loaded_ocr_pipeline = pdx.create_pipeline(pipeline=OCR_CONFIG_PATH) + logger.info("OCR产线加载成功") + except Exception as e: + logger.error(f"创建OCR产线失败: {str(e)}") + raise + + return _loaded_ocr_pipeline + +def process_pdf_to_table(pdf_folder, output_folder, confidence_threshold=0.6): + """ + 端到端处理PDF到区域文本的整个流程 + + Args: + pdf_folder: PDF文件所在文件夹 + output_folder: 输出文件夹 + confidence_threshold: 置信度阈值 + + Returns: + 处理结果统计信息 + """ + # 创建输出目录 + os.makedirs(output_folder, exist_ok=True) + images_folder = os.path.join(output_folder, "images") + results_folder = os.path.join(output_folder, "results") + rotated_folder = os.path.join(output_folder, "rotated_images") + os.makedirs(images_folder, exist_ok=True) + os.makedirs(results_folder, exist_ok=True) + os.makedirs(rotated_folder, exist_ok=True) + + # 1. PDF转图片 + logger.info("第1步: PDF转图片") + image_info = pdf_to_images_bulk(pdf_folder, images_folder) + + # 2. 图片分类 + logger.info("第2步: 图片分类") + classification_results = classify_images(images_folder, confidence_threshold) + + # 3. 筛选目标图片 + target_images = [result for result in classification_results if result['is_target']] + logger.info(f"筛选出 {len(target_images)} 张目标图片") + + # 如果没有目标图片,提前返回结果 + if not target_images: + logger.warning("未找到目标图片,跳过后续处理") + return { + 'total_images': len(image_info), + 'target_images': 0, + 'rotated_images': 0, + 'successful_extractions': 0, + 'total_name_regions': 0, + 'total_money_regions': 0, + 'total_text_elements': 0 + } + + # 保存分类结果 + classification_csv = os.path.join(results_folder, "classification_results.csv") + pd.DataFrame(classification_results).to_csv(classification_csv, index=False, encoding='utf-8-sig') + + # 4. 图像旋转处理 + logger.info("第3步: 图像旋转处理") + for idx, image_info in enumerate(target_images): + if image_info['rotation_angle'] != 0: + original_path = image_info['image_path'] + rotated_path = rotate_image( + original_path, + image_info['rotation_angle'], + rotated_folder + ) + image_info['original_path'] = original_path + image_info['image_path'] = rotated_path + + # 懒加载模型 - 只在需要时才加载 + # 获取版面检测模型和OCR产线 + layout_module = get_layout_model() + ocr_pipeline = get_ocr_pipeline() + + # 5. 对每张目标图片进行版面检测和OCR识别 + logger.info("第4步: 版面检测和OCR识别") + processing_results = [] + + for idx, image_info in enumerate(tqdm(target_images, desc="处理目标图片")): + image_path = image_info['image_path'] + image_name = image_info['image_name'] + rotation_angle = image_info.get('rotation_angle', 0) + + # 创建图片特定的输出目录 + image_output_dir = os.path.join(results_folder, os.path.splitext(image_name)[0]) + + # 执行版面检测和OCR + region_texts = detect_layout_and_ocr(image_path, layout_module, ocr_pipeline, image_output_dir) + + # 如果找到目标区域,合并并保存结果 + if region_texts: + # 后处理money文本,合并被分割的金额 + for i, region in enumerate(region_texts): + if region['label'] == 'money': + region_texts[i]['texts'] = post_process_money_texts(region['texts']) + + merged_txt_path = os.path.join(results_folder, f"{os.path.splitext(image_name)[0]}_regions.txt") + merge_region_texts(region_texts, merged_txt_path) + + # 统计各类别区域数量 + name_regions = sum(1 for r in region_texts if r['label'] == 'name') + money_regions = sum(1 for r in region_texts if r['label'] == 'money') + total_texts = sum(len(r['texts']) for r in region_texts) + + processing_results.append({ + 'image_path': image_path, + 'image_name': image_name, + 'rotation_angle': rotation_angle, + 'name_regions': name_regions, + 'money_regions': money_regions, + 'total_regions': len(region_texts), + 'total_texts': total_texts, + 'output_file': merged_txt_path + }) + else: + logger.info(f"图片 {image_name} 未找到目标区域") + processing_results.append({ + 'image_path': image_path, + 'image_name': image_name, + 'rotation_angle': rotation_angle, + 'name_regions': 0, + 'money_regions': 0, + 'total_regions': 0, + 'total_texts': 0, + 'output_file': None + }) + + # 保存处理结果 + processing_csv = os.path.join(results_folder, "processing_results.csv") + pd.DataFrame(processing_results).to_csv(processing_csv, index=False, encoding='utf-8-sig') + + # 统计信息 + stats = { + 'total_images': len(image_info), + 'target_images': len(target_images), + 'rotated_images': sum(1 for img in target_images if img.get('rotation_angle', 0) != 0), + 'successful_extractions': sum(1 for r in processing_results if r['total_regions'] > 0), + 'total_name_regions': sum(r['name_regions'] for r in processing_results), + 'total_money_regions': sum(r['money_regions'] for r in processing_results), + 'total_text_elements': sum(r['total_texts'] for r in processing_results) + } + + return stats + +if __name__ == "__main__": + # 示例使用 + try: + stats = process_pdf_to_table( + pdf_folder="E:\ocr_cpu\pdf", # 输入PDF文件夹 + output_folder="output", # 输出文件夹 + confidence_threshold=0.6 # 置信度阈值 + ) + print("处理完成!") + print(f"目标图片数: {stats['target_images']}") + print(f"成功提取区域的图片数: {stats['successful_extractions']}") + except Exception as e: + logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True) + sys.exit(1) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..6b29ee6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,157 @@ +aiofiles==24.1.0 +aiohappyeyeballs==2.6.1 +aiohttp==3.11.16 +aiosignal==1.3.2 +albucore==0.0.24 +albumentations==1.4.10 +annotated-types==0.7.0 +anyio==4.9.0 +astor==0.8.1 +async-timeout==4.0.3 +attrs==25.3.0 +backoff==2.2.1 +beautifulsoup4==4.13.3 +cachetools==5.5.2 +certifi==2025.1.31 +cffi==1.17.1 +chardet==5.2.0 +charset-normalizer==3.4.1 +chinese_calendar==1.10.0 +click==8.1.8 +colorama==0.4.6 +colorlog==6.9.0 +contourpy==1.3.0 +cryptography==44.0.2 +cssselect==1.3.0 +cssutils==2.11.1 +cycler==0.12.1 +dataclasses-json==0.6.7 +decorator==5.2.1 +decord==0.6.0 +distro==1.9.0 +emoji==2.14.1 +et_xmlfile==2.0.0 +eval_type_backport==0.2.2 +exceptiongroup==1.2.2 +faiss-cpu==1.8.0.post1 +filelock==3.18.0 +filetype==1.2.0 +fonttools==4.57.0 +frozenlist==1.5.0 +fsspec==2025.3.2 +ftfy==6.3.1 +GPUtil==1.4.0 +greenlet==3.1.1 +h11==0.14.0 +html5lib==1.1 +httpcore==1.0.7 +httpx==0.28.1 +huggingface-hub==0.30.2 +idna==3.10 +imageio==2.37.0 +imagesize==1.4.1 +importlib_resources==6.5.2 +jieba==0.42.1 +Jinja2==3.1.6 +jiter==0.9.0 +joblib==1.4.2 +jsonpatch==1.33 +jsonpointer==3.0.0 +kiwisolver==1.4.7 +langchain==0.2.17 +langchain-community==0.2.17 +langchain-core==0.2.43 +langchain-openai==0.1.25 +langchain-text-splitters==0.2.4 +langdetect==1.0.9 +langsmith==0.1.147 +lazy_loader==0.4 +lxml==5.3.2 +MarkupSafe==3.0.2 +marshmallow==3.26.1 +matplotlib==3.9.4 +more-itertools==10.6.0 +mpmath==1.3.0 +multidict==6.4.3 +mypy-extensions==1.0.0 +nest-asyncio==1.6.0 +networkx==3.2.1 +nltk==3.9.1 +numpy==1.24.4 +olefile==0.47 +openai==1.63.2 +openpyxl==3.1.5 +opt-einsum==3.3.0 +orjson==3.10.16 +packaging==24.2 +paddlepaddle==3.0.0rc0 +paddlex==3.0.0rc0 +pandas==2.2.3 +Parsley==1.3 +pdf2image==1.17.0 +pillow==11.2.1 +premailer==3.10.0 +prettytable==3.16.0 +propcache==0.3.1 +protobuf==6.30.2 +psutil==7.0.0 +py-cpuinfo==9.0.0 +pyclipper==1.3.0.post6 +pycocotools==2.0.8 +pycparser==2.22 +pydantic==2.11.3 +pydantic_core==2.33.1 +PyMuPDF==1.25.5 +pyparsing==3.2.3 +pypdf==5.4.0 +pyquaternion==0.9.9 +python-dateutil==2.9.0.post0 +python-iso639==2025.2.18 +python-magic==0.4.27 +python-oxmsg==0.0.2 +pytz==2025.2 +PyYAML==6.0.2 +RapidFuzz==3.13.0 +regex==2024.11.6 +requests==2.32.3 +requests-toolbelt==1.0.0 +ruamel.yaml==0.18.10 +ruamel.yaml.clib==0.2.12 +safetensors==0.5.3 +scikit-image==0.24.0 +scikit-learn==1.6.1 +scipy==1.13.1 +seaborn==0.13.2 +sentencepiece==0.2.0 +shapely==2.0.7 +simsimd==6.2.1 +six==1.17.0 +sniffio==1.3.1 +soundfile==0.13.1 +soupsieve==2.6 +SQLAlchemy==2.0.40 +stringzilla==3.12.4 +sympy==1.13.1 +tenacity==8.5.0 +threadpoolctl==3.6.0 +tifffile==2024.8.30 +tiktoken==0.9.0 +tokenizers==0.19.1 +torch==2.6.0 +torchvision==0.21.0 +tqdm==4.67.1 +transformers==4.40.0 +typing-inspect==0.9.0 +typing-inspection==0.4.0 +typing_extensions==4.13.0 +tzdata==2025.2 +ujson==5.10.0 +unstructured==0.17.2 +unstructured-client==0.32.3 +urllib3==2.4.0 +wcwidth==0.2.13 +webencodings==0.5.1 +wrapt==1.17.2 +yarl==1.19.0 +zipp==3.21.0 + diff --git a/requirements2.txt b/requirements2.txt new file mode 100644 index 0000000..de2413b --- /dev/null +++ b/requirements2.txt @@ -0,0 +1,5 @@ +opencv-contrib-python==4.10.0.84 +opencv-python==4.11.0.86 +opencv-python-headless==4.10.0.84 +ultralytics==8.3.108 +ultralytics-thop==2.0.14 \ No newline at end of file