payoff_OCR/pdf_to_table.py

import os
import warnings
import json
import logging
from tqdm import tqdm
from pdf2image import convert_from_path
from ultralytics import YOLO
import pandas as pd
from datetime import datetime
import paddlex as pdx
import cv2
import numpy as np
import sys
from pathlib import Path
import requests
import zipfile
import shutil
import tempfile
import platform
import re
import hashlib
import pickle
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Reduce PaddlePaddle's log output
os.environ['FLAGS_call_stack_level'] = '2'

# Suppress warnings
warnings.filterwarnings('ignore')

# Globals holding the lazily loaded models
_loaded_layout_model = None
_loaded_ocr_pipeline = None

# Cache directory
CACHE_DIR = os.environ.get('OCR_CACHE_DIR', os.path.join(tempfile.gettempdir(), 'ocr_cache'))
# Whether caching is enabled
USE_CACHE = os.environ.get('OCR_USE_CACHE', '1') == '1'
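# Both settings are read from the environment at import time; a hedged example
# of overriding them before launch (values are hypothetical):
#   OCR_CACHE_DIR=/data/ocr_cache OCR_USE_CACHE=0 python pdf_to_table.py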


def get_cache_path(key):
    """Return the cache file path for a given key."""
    if not os.path.exists(CACHE_DIR):
        os.makedirs(CACHE_DIR, exist_ok=True)
    # Hash the key with MD5 to build a stable cache file name
    hash_obj = hashlib.md5(key.encode('utf-8'))
    return os.path.join(CACHE_DIR, f"{hash_obj.hexdigest()}.cache")


def get_from_cache(key):
    """Fetch data from the cache; return None on a miss or read failure."""
    if not USE_CACHE:
        return None
    cache_path = get_cache_path(key)
    if os.path.exists(cache_path):
        try:
            with open(cache_path, 'rb') as f:
                return pickle.load(f)
        except Exception as e:
            logger.error(f"Failed to read cache: {str(e)}")
    return None


def save_to_cache(key, data):
    """Save data to the cache."""
    if not USE_CACHE:
        return
    cache_path = get_cache_path(key)
    try:
        with open(cache_path, 'wb') as f:
            pickle.dump(data, f)
    except Exception as e:
        logger.error(f"Failed to save cache: {str(e)}")


def resource_path(relative_path):
    """Resolve a resource path, supporting both development and PyInstaller bundles."""
    try:
        # PyInstaller unpacks into a temp folder and records it in _MEIPASS
        base_path = getattr(sys, '_MEIPASS', os.path.dirname(os.path.abspath(__file__)))
        return os.path.join(base_path, relative_path)
    except Exception as e:
        logger.error(f"Error resolving resource path: {str(e)}")
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), relative_path)
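# Illustrative: in development, resource_path('ocr.yaml') resolves next to this
# file; inside a PyInstaller bundle it resolves under sys._MEIPASS instead.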


def install_poppler_from_local(local_zip_path, install_dir=None):
    """Install Poppler from a local ZIP archive."""
    try:
        if not os.path.exists(local_zip_path):
            logger.error(f"Local Poppler ZIP file does not exist: {local_zip_path}")
            return None
        # Fall back to a default directory when none is given
        if install_dir is None:
            install_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "poppler")
        logger.info(f"Installing Poppler from local archive to: {install_dir}")
        # Create the installation directory
        if not os.path.exists(install_dir):
            os.makedirs(install_dir)
        # Create a temporary directory
        temp_dir = tempfile.mkdtemp()
        try:
            # Extract the archive
            logger.info("Extracting archive...")
            with zipfile.ZipFile(local_zip_path, 'r') as zip_ref:
                zip_ref.extractall(temp_dir)
            # Move the extracted files into the installation directory
            extracted_dir = os.path.join(temp_dir, os.path.splitext(os.path.basename(local_zip_path))[0])
            if not os.path.exists(extracted_dir):
                # Try to locate the extracted directory by name
                for item in os.listdir(temp_dir):
                    if item.startswith("poppler-"):
                        extracted_dir = os.path.join(temp_dir, item)
                        break
            if os.path.exists(extracted_dir):
                for item in os.listdir(extracted_dir):
                    src = os.path.join(extracted_dir, item)
                    dst = os.path.join(install_dir, item)
                    if os.path.isdir(src):
                        shutil.copytree(src, dst, dirs_exist_ok=True)
                    else:
                        shutil.copy2(src, dst)
                logger.info(f"Poppler installed to: {install_dir}")
                return install_dir
            else:
                logger.error("Could not find the extracted Poppler directory")
                return None
        finally:
            # Clean up temporary files
            shutil.rmtree(temp_dir)
    except Exception as e:
        logger.error(f"Error installing Poppler from local archive: {str(e)}")
        return None


def setup_poppler_path():
    """Configure the Poppler path."""
    # Respect an already-set POPPLER_PATH environment variable
    if 'POPPLER_PATH' in os.environ:
        logger.info(f"Using configured Poppler path: {os.environ['POPPLER_PATH']}")
        return
    # Probe common installation paths
    possible_paths = [
        r"C:\Program Files\poppler\Library\bin",
        r"C:\Program Files (x86)\poppler\Library\bin",
        r"C:\poppler\Library\bin",
        r"D:\Program Files\poppler\Library\bin",
        r"D:\poppler\Library\bin",
        os.path.join(os.path.dirname(os.path.abspath(__file__)), "poppler", "Library", "bin")
    ]
    for path in possible_paths:
        if os.path.exists(path):
            os.environ["POPPLER_PATH"] = path
            logger.info(f"Found and set Poppler path: {path}")
            return
    # Try installing from a local archive
    local_zip_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "poppler.zip")
    if os.path.exists(local_zip_path):
        logger.info("Found local Poppler archive, installing...")
        install_dir = install_poppler_from_local(local_zip_path)
        if install_dir:
            bin_path = os.path.join(install_dir, "Library", "bin")
            os.environ["POPPLER_PATH"] = bin_path
            # Prepend to the system PATH
            if bin_path not in os.environ.get("PATH", ""):
                os.environ["PATH"] = bin_path + os.pathsep + os.environ.get("PATH", "")
            logger.info(f"Poppler installed and configured at: {bin_path}")
            return
    logger.warning("Poppler path not found; PDF conversion may fail. Make sure Poppler is installed and its path is configured.")


# Try to configure the Poppler path at import time
setup_poppler_path()
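# A hedged usage note (not in the original): the probing above can be bypassed
# by exporting POPPLER_PATH before this module is imported, e.g. on Windows:
#   set POPPLER_PATH=C:\tools\poppler\Library\bin   (hypothetical location)
# setup_poppler_path() then returns early and keeps that value.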

# Resolve the application base directory
if getattr(sys, 'frozen', False):
    # Frozen (PyInstaller) bundle
    BASE_DIR = getattr(sys, '_MEIPASS', os.path.dirname(sys.executable))
else:
    # Development environment
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Model path configuration (BASE_DIR is already absolute, so resource_path
# returns these joined paths unchanged)
MODEL_DIR = resource_path(os.path.join(BASE_DIR, "model"))
YOLO_MODEL_PATH = resource_path(os.path.join(MODEL_DIR, "best.pt"))
LAYOUT_MODEL_DIR = resource_path(os.path.join(MODEL_DIR, "layout"))
OCR_CONFIG_PATH = resource_path(os.path.join(BASE_DIR, "ocr.yaml"))


def pdf_to_images_bulk(input_folder, output_folder, dpi=300):
    """
    Convert PDFs at high quality with pdf2image, numbering every page image sequentially.

    Args:
        input_folder: folder containing the PDF files
        output_folder: folder to save the converted images
        dpi: resolution (default 300)

    Returns:
        List of image paths with their source-PDF information
    """
    os.makedirs(output_folder, exist_ok=True)
    pdf_files = [f for f in os.listdir(input_folder) if f.lower().endswith('.pdf')]
    image_index = 1  # sequential numbering across all PDFs
    image_info = []  # per-image metadata
    logger.info(f"Processing {len(pdf_files)} PDF file(s)...")
    for pdf_file in tqdm(sorted(pdf_files), desc="PDF to images"):
        pdf_path = os.path.join(input_folder, pdf_file)
        try:
            # Check the environment for a Poppler path
            poppler_path = os.environ.get("POPPLER_PATH")
            if poppler_path and os.path.exists(poppler_path):
                # Use the configured Poppler path
                logger.info(f"Using Poppler path: {poppler_path}")
                logger.info(f"Processing PDF file: {pdf_path}")
                try:
                    images = convert_from_path(pdf_path, dpi=dpi, poppler_path=poppler_path)
                    logger.info(f"PDF converted successfully: {len(images)} page(s)")
                except Exception as e:
                    logger.error(f"Conversion with the configured Poppler path failed: {str(e)}")
                    # Retry without an explicit path
                    logger.info("Retrying conversion without an explicit Poppler path")
                    images = convert_from_path(pdf_path, dpi=dpi)
            else:
                # Fall back to whatever Poppler the system provides
                logger.warning("Poppler path not found; trying the system default")
                images = convert_from_path(pdf_path, dpi=dpi)
                logger.info("System default Poppler path worked")
            for i, img in enumerate(images):
                img_path = os.path.join(output_folder, f'{image_index}.png')
                img.save(img_path, 'PNG')
                # Record image metadata
                image_info.append({
                    'image_path': img_path,
                    'pdf_file': pdf_file,
                    'page_number': i + 1,
                    'image_index': image_index
                })
                image_index += 1
        except Exception as e:
            logger.error(f"Error processing {pdf_file}: {str(e)}")
            # Log the full traceback for debugging
            import traceback
            logger.error(f"Details: {traceback.format_exc()}")
    logger.info(f"PDF-to-image conversion finished: {len(image_info)} image(s)")
    return image_info
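

# Minimal usage sketch (hypothetical paths): each returned entry maps a page to
# its sequentially numbered PNG, e.g.
#   info = pdf_to_images_bulk('pdfs', 'output/images', dpi=300)
#   info[0] -> {'image_path': 'output/images/1.png', 'pdf_file': 'a.pdf',
#               'page_number': 1, 'image_index': 1}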


def classify_images(image_folder, confidence_threshold=0.6):
    """
    Classify images with the YOLO model.

    Args:
        image_folder: folder containing the images
        confidence_threshold: confidence threshold

    Returns:
        List of classification results
    """
    # Check the cache first
    cache_key = f"classify_{image_folder}_{confidence_threshold}"
    cached_result = get_from_cache(cache_key)
    if cached_result is not None:
        logger.info(f"Using cached classification results: {len(cached_result)} image(s)")
        return cached_result
    logger.info(f"Loading YOLO model: {YOLO_MODEL_PATH}")
    model = YOLO(YOLO_MODEL_PATH)
    # Force CPU inference
    model.to('cpu')
    image_files = [f for f in os.listdir(image_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    results = []
    logger.info(f"Classifying {len(image_files)} image(s)...")
    for image_file in tqdm(sorted(image_files, key=lambda x: int(os.path.splitext(x)[0])), desc="Classifying images"):
        image_path = os.path.join(image_folder, image_file)
        try:
            result = model.predict(image_path, verbose=False)[0]
            class_id = int(result.probs.top1)
            class_name = result.names[class_id]
            confidence = float(result.probs.top1conf)
            # Map the class to a rotation angle and a target flag
            rotation_angle = 0
            if class_name == 'yes_90':
                rotation_angle = -90  # rotate 90 degrees counterclockwise
            elif class_name == 'yes_-90':
                rotation_angle = 90  # rotate 90 degrees clockwise
            # All three 'yes' classes count as target pages
            is_target = (class_name in ['yes', 'yes_90', 'yes_-90']) and confidence >= confidence_threshold
            results.append({
                'image_path': image_path,
                'image_name': image_file,
                'class_id': class_id,
                'class_name': class_name,
                'confidence': confidence,
                'rotation_angle': rotation_angle,
                'is_target': is_target
            })
        except Exception as e:
            logger.error(f"Error processing image {image_file}: {str(e)}")
            continue
    logger.info(f"Classification finished: {len(results)} image(s)")
    # Save to the cache
    save_to_cache(cache_key, results)
    return results
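

# Shape of each result entry (illustrative values, not from a real run):
#   {'image_name': '3.png', 'class_name': 'yes_90', 'confidence': 0.97,
#    'rotation_angle': -90, 'is_target': True, ...}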


def rotate_image(image_path, rotation_angle, output_folder=None):
    """
    Rotate an image by the given angle.

    Args:
        image_path: path of the source image
        rotation_angle: rotation angle (positive = clockwise, negative = counterclockwise)
        output_folder: output folder; if None, the image's own folder is used

    Returns:
        Path of the rotated image
    """
    # No rotation needed: return the original path
    if rotation_angle == 0:
        return image_path
    # Resolve the output path
    image_name = os.path.basename(image_path)
    base_name, ext = os.path.splitext(image_name)
    if output_folder is None:
        output_folder = os.path.dirname(image_path)
    os.makedirs(output_folder, exist_ok=True)
    # Build the rotated image's file name
    rotated_image_path = os.path.join(output_folder, f"{base_name}_rotated{ext}")
    # Read the image
    img = cv2.imread(image_path)
    if img is None:
        logger.error(f"Cannot read image: {image_path}")
        return image_path
    # Image dimensions
    height, width = img.shape[:2]
    # Apply the rotation
    if rotation_angle == 90 or rotation_angle == -90:
        # For 90-degree rotations, cv2.rotate swaps width and height
        if rotation_angle == 90:  # 90 degrees clockwise
            rotated_img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        else:  # 90 degrees counterclockwise
            rotated_img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
    else:
        # Arbitrary-angle rotation
        center = (width // 2, height // 2)
        rotation_matrix = cv2.getRotationMatrix2D(center, rotation_angle, 1.0)
        # Compute the size of the rotated image
        abs_cos = abs(rotation_matrix[0, 0])
        abs_sin = abs(rotation_matrix[0, 1])
        bound_w = int(height * abs_sin + width * abs_cos)
        bound_h = int(height * abs_cos + width * abs_sin)
        # Shift the rotation matrix so the whole image stays in frame
        rotation_matrix[0, 2] += bound_w / 2 - center[0]
        rotation_matrix[1, 2] += bound_h / 2 - center[1]
        # Perform the rotation
        rotated_img = cv2.warpAffine(img, rotation_matrix, (bound_w, bound_h),
                                     flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT,
                                     borderValue=(255, 255, 255))
    # Save the rotated image
    cv2.imwrite(rotated_image_path, rotated_img)
    logger.info(f"Image rotated by {rotation_angle} degrees and saved to: {rotated_image_path}")
    return rotated_image_path


def detect_layout_and_ocr(image_path, layout_module, ocr_pipeline, output_dir=None):
    """
    Run layout detection and OCR on an image.

    Args:
        image_path: image path
        layout_module: loaded layout-detection model
        ocr_pipeline: loaded OCR pipeline
        output_dir: output directory; created automatically if None

    Returns:
        List of texts for all recognized target regions
    """
    # Check the cache first
    cache_key = f"layout_ocr_{image_path}"
    cached_result = get_from_cache(cache_key)
    if cached_result is not None and os.path.exists(cached_result.get('output_dir', '')):
        logger.info(f"Using cached OCR result: {os.path.basename(image_path)}")
        return cached_result['region_texts']
    # Create the output directory
    if output_dir is None:
        output_dir = os.path.join("results", "layout_ocr", os.path.splitext(os.path.basename(image_path))[0])
    os.makedirs(output_dir, exist_ok=True)
    # Image file name without extension
    img_name = os.path.splitext(os.path.basename(image_path))[0]
    # Run layout detection
    results_generator = layout_module.predict(image_path, layout_nms=True)
    # Group detected regions by category
    all_regions = {'name': [], 'money': [], 'title': []}
    target_regions = []  # regions to process further (name and money)
    result_count = 0
    # Save each detection result and collect target regions
    for result_count, result in enumerate(results_generator, 1):
        # Persist the result as JSON
        json_path = os.path.join(output_dir, f"{img_name}_result_{result_count}.json")
        result.save_to_json(json_path)
        # Read the detection result back from JSON
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                result_data = json.load(f)
            # Extract target regions from result_data;
            # the 'boxes' field lives at the JSON root
            if 'boxes' in result_data:
                boxes = result_data['boxes']
                # Walk every detected region
                for box in boxes:
                    label = box.get('label')
                    score = box.get('score')
                    coordinate = box.get('coordinate')
                    # Bucket the region by category
                    if label in all_regions:
                        all_regions[label].append({
                            'coordinate': coordinate,
                            'score': score
                        })
                    # name and money regions go on the target list
                    if label in ['name', 'money']:
                        target_regions.append({
                            'label': label,
                            'coordinate': coordinate,
                            'score': score
                        })
            # Log counts per category
            for label, regions in all_regions.items():
                if regions:
                    logger.info(f"Detected {len(regions)} {label} region(s)")
        except Exception as e:
            logger.error(f"Error parsing JSON result: {str(e)}")
    # Nothing to do if no target regions were found
    if not target_regions:
        logger.info(f"Image {img_name}: no name or money regions found")
        return []
    # Draw and save a visualization of all detected regions
    try:
        img = cv2.imread(image_path)
        if img is not None:
            vis_img = img.copy()
            # One color per category
            color_map = {
                'name': (0, 255, 0),   # green
                'money': (0, 0, 255),  # red
                'title': (255, 0, 0)   # blue
            }
            # Draw every region (title included, for reference)
            for label, regions in all_regions.items():
                for idx, region_info in enumerate(regions):
                    coordinate = region_info['coordinate']
                    x1, y1, x2, y2 = map(int, coordinate)
                    color = color_map.get(label, (128, 128, 128))  # gray by default
                    cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
                    cv2.putText(vis_img, f"{label} {idx+1}", (x1, y1 - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
            # Save the annotated image
            vis_img_path = os.path.join(output_dir, f"{img_name}_all_regions.jpg")
            cv2.imwrite(vis_img_path, vis_img)
    except Exception as e:
        logger.error(f"Error generating region visualization: {str(e)}")
    # Run OCR on each target region
    all_region_texts = []  # texts for all target regions
    for idx, region_info in enumerate(target_regions):
        region_label = region_info['label']
        bbox = region_info['coordinate']
        logger.info(f"Processing {region_label} region {idx+1}...")
        region_texts = []  # texts for the current region
        try:
            # Read the source image and crop the region
            img = cv2.imread(image_path)
            if img is None:
                continue
            img_height, img_width = img.shape[:2]
            # Per-category crop margins: positive values shrink the crop,
            # negative values expand it. Only the left/right margins are
            # tightened; top and bottom are mostly kept.
            if region_label == 'name':
                # name regions: shrink left/right slightly
                left_padding = 10    # shrink 10 px from the left
                right_padding = 10   # shrink 10 px from the right
                top_padding = 0      # keep the top edge
                bottom_padding = -7  # extend the bottom edge by 7 px
            elif region_label == 'money':
                # money regions: shrink left/right more aggressively
                left_padding = 20    # shrink 20 px from the left
                right_padding = 20   # shrink 20 px from the right
                top_padding = 0      # keep the top edge
                bottom_padding = 0   # keep the bottom edge
            else:
                # other regions: expand by 7 px on every side
                left_padding = -7
                right_padding = -7
                top_padding = -7
                bottom_padding = -7
            x1, y1, x2, y2 = map(int, bbox)
            # Apply the crop margins, clamped to the image bounds
            x1 = max(0, x1 + left_padding)
            y1 = max(0, y1 + top_padding)
            x2 = min(img_width, x2 - right_padding)
            y2 = min(img_height, y2 - bottom_padding)
            region_img = img[y1:y2, x1:x2].copy()
            region_img_path = os.path.join(output_dir, f"{img_name}_{region_label}_{idx+1}.jpg")
            cv2.imwrite(region_img_path, region_img)
            # Run OCR
            ocr_results = ocr_pipeline.predict(region_img_path)
            # Process the OCR results
            for res_idx, ocr_res in enumerate(ocr_results, 1):
                # Save the result as JSON
                json_path = os.path.join(output_dir, f"{img_name}_{region_label}_{idx+1}_ocr_{res_idx}.json")
                ocr_res.save_to_json(save_path=json_path)
                # Read the JSON and extract the texts
                try:
                    with open(json_path, 'r', encoding='utf-8') as f:
                        json_data = json.load(f)
                    # Top-level 'rec_texts' field
                    if 'rec_texts' in json_data and isinstance(json_data['rec_texts'], list):
                        rec_texts = json_data['rec_texts']
                        region_texts.extend(rec_texts)
                    # Otherwise look inside list items
                    elif isinstance(json_data, list):
                        for item in json_data:
                            if isinstance(item, dict):
                                if 'rec_texts' in item:
                                    # 'rec_texts' may be a list or a single value
                                    if isinstance(item['rec_texts'], list):
                                        region_texts.extend(item['rec_texts'])
                                    else:
                                        region_texts.append(item['rec_texts'])
                                elif 'text' in item:
                                    # Backward compatibility with the old format
                                    region_texts.append(item['text'])
                    # JSON empty or missing the expected fields: fall back to the object
                    if not region_texts and hasattr(ocr_res, 'rec_texts'):
                        region_texts.extend(ocr_res.rec_texts)
                except Exception as e:
                    logger.error(f"Error reading OCR JSON result: {str(e)}")
                    # Try extracting the texts straight from the object
                    try:
                        if hasattr(ocr_res, 'rec_texts'):
                            region_texts.extend(ocr_res.rec_texts)
                    except Exception:
                        logger.error("Could not extract texts from the object")
            # Append the region label and texts to the overall list
            all_region_texts.append({
                'label': region_label,
                'texts': region_texts
            })
            logger.info(f"{region_label} region {idx+1}: extracted {len(region_texts)} text(s)")
            # Save the current region's texts
            txt_path = os.path.join(output_dir, f"{img_name}_{region_label}_{idx+1}_texts.txt")
            with open(txt_path, 'w', encoding='utf-8') as f:
                for text in region_texts:
                    f.write(f"{text}\n")
        except Exception as e:
            logger.error(f"Error processing {region_label} region {idx+1}: {str(e)}")
            continue
    # Save the result to the cache
    cache_data = {
        'region_texts': all_region_texts,
        'output_dir': output_dir
    }
    save_to_cache(cache_key, cache_data)
    return all_region_texts
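

# Shape of the return value (illustrative, with hypothetical texts):
#   [{'label': 'name', 'texts': ['Zhang San', 'Li Si']},
#    {'label': 'money', 'texts': ['1,234.', '56']}]
# money texts may still be split at this point; post_process_money_texts merges them.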


def merge_region_texts(all_region_texts, output_file):
    """
    Merge all region texts into a single txt file, using a readable layout.

    Args:
        all_region_texts: list of region texts; each entry has 'label' and 'texts' fields
        output_file: output file path
    """
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(f"Region recognition results - generated at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write("=" * 80 + "\n\n")
        # Write the output grouped by category
        for category in ['name', 'money']:
            category_regions = [r for r in all_region_texts if r['label'] == category]
            if category_regions:
                f.write(f"{category.upper()} regions ({len(category_regions)} total):\n")
                f.write("-" * 50 + "\n\n")
                for region_idx, region_data in enumerate(category_regions, 1):
                    texts = region_data['texts']
                    f.write(f"{category} {region_idx} ({len(texts)} item(s)):\n")
                    # Write the region's entries, numbered
                    for text_idx, text in enumerate(texts, 1):
                        f.write(f"  {text_idx}. {text}\n")
                    # Separator between regions (except after the last one)
                    if region_idx < len(category_regions):
                        f.write("\n" + "-" * 30 + "\n\n")
                # Separator between categories
                f.write("\n" + "=" * 50 + "\n\n")
    logger.info(f"All region texts merged and saved to: {output_file}")
    return output_file


def post_process_money_texts(money_texts):
    """
    Post-process the money text list, merging amounts that OCR split apart.

    Args:
        money_texts: list of amount texts

    Returns:
        Processed list of amount texts
    """
    if not money_texts:
        return money_texts
    # Build a new result list
    processed_texts = []
    skip_next = False
    for i in range(len(money_texts)):
        if skip_next:
            skip_next = False
            continue
        current_text = money_texts[i].strip()
        # Check whether the current text ends in a decimal point or is
        # missing a complete two-digit decimal part
        if i < len(money_texts) - 1:
            next_text = money_texts[i + 1].strip()
            needs_merge = False
            # Case 1: the current amount ends with a decimal point
            if current_text.endswith('.'):
                needs_merge = True
            # Case 2: the current amount has a decimal point but fewer than two decimal digits
            elif '.' in current_text:
                decimal_part = current_text.split('.')[-1]
                if len(decimal_part) < 2:
                    needs_merge = True
            # Case 3: the next value is all digits and at most two characters long
            is_next_decimal = next_text.isdigit() and len(next_text) <= 2
            # Case 4: the next value starts with 0 and is all digits (likely a decimal part)
            is_next_leading_zero = next_text.startswith('0') and next_text.replace('0', '').isdigit()
            # Case 5: the current text is an integer and the next starts with a decimal point
            is_current_integer = current_text.isdigit() and next_text.startswith('.')
            if (needs_merge and (is_next_decimal or is_next_leading_zero)) or is_current_integer:
                # Merge the current amount with the next one
                merged_value = current_text + next_text
                processed_texts.append(merged_value)
                logger.info(f"Merged amounts: {current_text} + {next_text} = {merged_value}")
                skip_next = True
                continue
        # No merge: keep the current amount as-is
        processed_texts.append(current_text)
    # Normalize the amount formats
    formatted_texts = []
    for money in processed_texts:
        # Clean and standardize the amount
        cleaned_money = money.strip()
        # Strip everything except digits, decimal points and commas
        cleaned_money = re.sub(r'[^\d.,]', '', cleaned_money)
        # Append '.00' to integer amounts
        if '.' not in cleaned_money and cleaned_money.isdigit():
            cleaned_money = cleaned_money + '.00'
            logger.info(f"Formatted amount: {money} -> {cleaned_money}")
        # Ensure exactly two decimal digits
        elif '.' in cleaned_money:
            parts = cleaned_money.split('.')
            if len(parts) == 2:
                integer_part, decimal_part = parts
                # Pad short decimal parts with zeros
                if len(decimal_part) < 2:
                    decimal_part = decimal_part.ljust(2, '0')
                # Truncate decimal parts longer than two digits
                elif len(decimal_part) > 2:
                    decimal_part = decimal_part[:2]
                cleaned_money = integer_part + '.' + decimal_part
                logger.info(f"Normalized amount: {money} -> {cleaned_money}")
        # Strip any misrecognized text before the first digit
        first_digit_pos = -1
        for pos, char in enumerate(cleaned_money):
            if char.isdigit() or char == '.':
                first_digit_pos = pos
                break
        if first_digit_pos > 0:
            cleaned_money = cleaned_money[first_digit_pos:]
            logger.info(f"Stripped amount prefix: {money} -> {cleaned_money}")
        # Keep only values that parse as numbers
        try:
            float(cleaned_money.replace(',', ''))
            formatted_texts.append(cleaned_money)
        except ValueError:
            logger.warning(f"Invalid amount format, ignored: {money}")
    return formatted_texts
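

# Illustrative expectations (hypothetical inputs, not from a real run):
#   post_process_money_texts(['1,234.', '56'])  -> ['1,234.56']  # trailing dot merged
#   post_process_money_texts(['1,234.5', '6'])  -> ['1,234.56']  # short decimal merged
#   post_process_money_texts(['789'])           -> ['789.00']    # integer gains '.00'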


def get_layout_model():
    """
    Lazily load the layout-detection model on first use.

    Returns:
        Loaded layout-detection model
    """
    global _loaded_layout_model
    if _loaded_layout_model is None:
        logger.info("Loading layout-detection model for the first time")
        try:
            _loaded_layout_model = pdx.create_model('PP-DocLayout-L', LAYOUT_MODEL_DIR)
            logger.info(f"Using local layout-detection model: {LAYOUT_MODEL_DIR}")
        except Exception as e:
            logger.error(f"Failed to load the local layout-detection model: {str(e)}")
            raise
    return _loaded_layout_model


def get_ocr_pipeline():
    """
    Lazily load the OCR pipeline on first use.

    Returns:
        Loaded OCR pipeline
    """
    global _loaded_ocr_pipeline
    if _loaded_ocr_pipeline is None:
        logger.info("Loading OCR pipeline for the first time")
        try:
            _loaded_ocr_pipeline = pdx.create_pipeline(pipeline=OCR_CONFIG_PATH)
            logger.info("OCR pipeline loaded successfully")
        except Exception as e:
            logger.error(f"Failed to create the OCR pipeline: {str(e)}")
            raise
    return _loaded_ocr_pipeline


def process_pdf_to_table(pdf_folder, output_folder, confidence_threshold=0.6):
    """
    End-to-end pipeline from PDFs to region texts.

    Args:
        pdf_folder: folder containing the PDF files
        output_folder: output folder
        confidence_threshold: confidence threshold

    Returns:
        Statistics about the processing run
    """
    # Create the output directories
    os.makedirs(output_folder, exist_ok=True)
    images_folder = os.path.join(output_folder, "images")
    results_folder = os.path.join(output_folder, "results")
    rotated_folder = os.path.join(output_folder, "rotated_images")
    os.makedirs(images_folder, exist_ok=True)
    os.makedirs(results_folder, exist_ok=True)
    os.makedirs(rotated_folder, exist_ok=True)
    # Step 1: convert PDFs to images
    logger.info("Step 1: converting PDFs to images")
    image_info = pdf_to_images_bulk(pdf_folder, images_folder)
    # Step 2: classify the images
    logger.info("Step 2: classifying images")
    classification_results = classify_images(images_folder, confidence_threshold)
    # Filter the target images
    target_images = [result for result in classification_results if result['is_target']]
    logger.info(f"Selected {len(target_images)} target image(s)")
    # Return early if there are no target images
    if not target_images:
        logger.warning("No target images found, skipping the remaining steps")
        return {
            'total_images': len(image_info),
            'target_images': 0,
            'rotated_images': 0,
            'successful_extractions': 0,
            'total_name_regions': 0,
            'total_money_regions': 0,
            'total_text_elements': 0
        }
    # Save the classification results
    classification_csv = os.path.join(results_folder, "classification_results.csv")
    pd.DataFrame(classification_results).to_csv(classification_csv, index=False, encoding='utf-8-sig')
    # Step 3: rotate images that need it.
    # (The loop variable is named target_info so it does not shadow image_info,
    # which the final statistics still need.)
    logger.info("Step 3: rotating images")
    for target_info in target_images:
        if target_info['rotation_angle'] != 0:
            original_path = target_info['image_path']
            rotated_path = rotate_image(
                original_path,
                target_info['rotation_angle'],
                rotated_folder
            )
            target_info['original_path'] = original_path
            target_info['image_path'] = rotated_path
    # Lazily load the layout-detection model and the OCR pipeline,
    # only now that they are actually needed
    layout_module = get_layout_model()
    ocr_pipeline = get_ocr_pipeline()
    # Step 4: layout detection and OCR for each target image
    logger.info("Step 4: layout detection and OCR")
    processing_results = []
    for target_info in tqdm(target_images, desc="Processing target images"):
        image_path = target_info['image_path']
        image_name = target_info['image_name']
        rotation_angle = target_info.get('rotation_angle', 0)
        # Per-image output directory
        image_output_dir = os.path.join(results_folder, os.path.splitext(image_name)[0])
        # Run layout detection and OCR
        region_texts = detect_layout_and_ocr(image_path, layout_module, ocr_pipeline, image_output_dir)
        # If target regions were found, merge and save the results
        if region_texts:
            # Post-process money texts: merge amounts split by OCR
            for i, region in enumerate(region_texts):
                if region['label'] == 'money':
                    region_texts[i]['texts'] = post_process_money_texts(region['texts'])
            merged_txt_path = os.path.join(results_folder, f"{os.path.splitext(image_name)[0]}_regions.txt")
            merge_region_texts(region_texts, merged_txt_path)
            # Count regions per category
            name_regions = sum(1 for r in region_texts if r['label'] == 'name')
            money_regions = sum(1 for r in region_texts if r['label'] == 'money')
            total_texts = sum(len(r['texts']) for r in region_texts)
            processing_results.append({
                'image_path': image_path,
                'image_name': image_name,
                'rotation_angle': rotation_angle,
                'name_regions': name_regions,
                'money_regions': money_regions,
                'total_regions': len(region_texts),
                'total_texts': total_texts,
                'output_file': merged_txt_path
            })
        else:
            logger.info(f"Image {image_name}: no target regions found")
            processing_results.append({
                'image_path': image_path,
                'image_name': image_name,
                'rotation_angle': rotation_angle,
                'name_regions': 0,
                'money_regions': 0,
                'total_regions': 0,
                'total_texts': 0,
                'output_file': None
            })
    # Save the processing results
    processing_csv = os.path.join(results_folder, "processing_results.csv")
    pd.DataFrame(processing_results).to_csv(processing_csv, index=False, encoding='utf-8-sig')
    # Aggregate statistics
    stats = {
        'total_images': len(image_info),
        'target_images': len(target_images),
        'rotated_images': sum(1 for img in target_images if img.get('rotation_angle', 0) != 0),
        'successful_extractions': sum(1 for r in processing_results if r['total_regions'] > 0),
        'total_name_regions': sum(r['name_regions'] for r in processing_results),
        'total_money_regions': sum(r['money_regions'] for r in processing_results),
        'total_text_elements': sum(r['total_texts'] for r in processing_results)
    }
    return stats
if __name__ == "__main__":
# 示例使用
try:
stats = process_pdf_to_table(
pdf_folder="E:\ocr_cpu\pdf", # 输入PDF文件夹
output_folder="output", # 输出文件夹
confidence_threshold=0.6 # 置信度阈值
)
print("处理完成!")
print(f"目标图片数: {stats['target_images']}")
print(f"成功提取区域的图片数: {stats['successful_extractions']}")
except Exception as e:
logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
sys.exit(1)