import os
import warnings
import json
import logging
from tqdm import tqdm
from pdf2image import convert_from_path
from ultralytics import YOLO
import pandas as pd
from datetime import datetime
import paddlex as pdx
import cv2
import sys
import zipfile
import shutil
import tempfile
import re

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Reduce PaddlePaddle's log output
os.environ['FLAGS_call_stack_level'] = '2'

# Suppress warning messages
warnings.filterwarnings('ignore')

def resource_path(relative_path):
    """Return the absolute path of a resource, both in development and in a PyInstaller bundle."""
    try:
        # PyInstaller unpacks to a temp folder and stores its path in _MEIPASS
        base_path = getattr(sys, '_MEIPASS', os.path.dirname(os.path.abspath(__file__)))
        return os.path.join(base_path, relative_path)
    except Exception as e:
        logger.error(f"Error while resolving resource path: {str(e)}")
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), relative_path)

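# Usage sketch (illustrative path, matching the model layout configured further below):
#   weights = resource_path(os.path.join("model", "best.pt"))
#   # -> <project dir>/model/best.pt when running from source,
#   # -> <sys._MEIPASS>/model/best.pt inside a PyInstaller onefile build
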
def install_poppler_from_local(local_zip_path, install_dir=None):
    """Install Poppler from a local ZIP archive."""
    try:
        if not os.path.exists(local_zip_path):
            logger.error(f"Local Poppler ZIP file does not exist: {local_zip_path}")
            return None

        # Fall back to a default install directory when none is given
        if install_dir is None:
            install_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "poppler")

        logger.info(f"Installing Poppler from local archive into: {install_dir}")

        # Create the install directory
        if not os.path.exists(install_dir):
            os.makedirs(install_dir)

        # Create a temporary working directory
        temp_dir = tempfile.mkdtemp()

        try:
            # Extract the archive
            logger.info("Extracting archive...")
            with zipfile.ZipFile(local_zip_path, 'r') as zip_ref:
                zip_ref.extractall(temp_dir)

            # Move the extracted files into the install directory
            extracted_dir = os.path.join(temp_dir, os.path.splitext(os.path.basename(local_zip_path))[0])
            if not os.path.exists(extracted_dir):
                # Try to locate the extracted directory
                for item in os.listdir(temp_dir):
                    if item.startswith("poppler-"):
                        extracted_dir = os.path.join(temp_dir, item)
                        break

            if os.path.exists(extracted_dir):
                for item in os.listdir(extracted_dir):
                    src = os.path.join(extracted_dir, item)
                    dst = os.path.join(install_dir, item)
                    if os.path.isdir(src):
                        shutil.copytree(src, dst, dirs_exist_ok=True)
                    else:
                        shutil.copy2(src, dst)

                logger.info(f"Poppler installed to: {install_dir}")
                return install_dir
            else:
                logger.error("Could not find the extracted Poppler directory")
                return None

        finally:
            # Clean up temporary files
            shutil.rmtree(temp_dir)

    except Exception as e:
        logger.error(f"Error while installing Poppler from the local archive: {str(e)}")
        return None

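# Note on the expected archive layout (an assumption, based on the "poppler-*" lookup
# above and the "Library/bin" path used in setup_poppler_path below): the ZIP should
# contain a top-level poppler-<version> folder with a Library/bin subdirectory holding
# the Poppler binaries that pdf2image invokes.
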
def setup_poppler_path():
    """Configure the Poppler path."""
    # Respect an already configured POPPLER_PATH environment variable
    if 'POPPLER_PATH' in os.environ:
        logger.info(f"Using preconfigured Poppler path: {os.environ['POPPLER_PATH']}")
        return

    # Try common installation paths
    possible_paths = [
        r"C:\Program Files\poppler\Library\bin",
        r"C:\Program Files (x86)\poppler\Library\bin",
        r"C:\poppler\Library\bin",
        r"D:\Program Files\poppler\Library\bin",
        r"D:\poppler\Library\bin",
        os.path.join(os.path.dirname(os.path.abspath(__file__)), "poppler", "Library", "bin")
    ]

    for path in possible_paths:
        if os.path.exists(path):
            os.environ["POPPLER_PATH"] = path
            logger.info(f"Found and set Poppler path: {path}")
            return

    # Try installing from a local archive
    local_zip_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "poppler.zip")
    if os.path.exists(local_zip_path):
        logger.info("Found local Poppler archive, installing...")
        install_dir = install_poppler_from_local(local_zip_path)

        if install_dir:
            bin_path = os.path.join(install_dir, "Library", "bin")
            os.environ["POPPLER_PATH"] = bin_path

            # Prepend to the system PATH
            if bin_path not in os.environ.get("PATH", ""):
                os.environ["PATH"] = bin_path + os.pathsep + os.environ.get("PATH", "")

            logger.info(f"Poppler installed and configured at: {bin_path}")
            return

    logger.warning("Poppler not found; PDF conversion may fail. Please install Poppler and set POPPLER_PATH.")


# Try to configure the Poppler path
setup_poppler_path()

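# Manual override sketch (hypothetical path): setting POPPLER_PATH before this module
# runs makes setup_poppler_path() use it as-is and skip the search above.
#   os.environ["POPPLER_PATH"] = r"C:\tools\poppler\Library\bin"
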
# Determine the directory the program runs from
if getattr(sys, 'frozen', False):
    # Frozen (PyInstaller) environment
    BASE_DIR = getattr(sys, '_MEIPASS', os.path.dirname(sys.executable))
else:
    # Development environment
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Model path configuration
MODEL_DIR = resource_path(os.path.join(BASE_DIR, "model"))
YOLO_MODEL_PATH = resource_path(os.path.join(MODEL_DIR, "best.pt"))
LAYOUT_MODEL_DIR = resource_path(os.path.join(MODEL_DIR, "layout"))
OCR_CONFIG_PATH = resource_path(os.path.join(BASE_DIR, "ocr.yaml"))

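# Expected on-disk layout implied by the paths above (a sketch, not enforced anywhere):
#   <BASE_DIR>/
#     ocr.yaml              # PaddleX OCR pipeline config
#     model/
#       best.pt             # YOLO classification weights
#       layout/             # PP-DocLayout-L model files
#     poppler/Library/bin   # optional local Poppler install
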
def pdf_to_images_bulk(input_folder, output_folder, dpi=300):
    """
    Convert PDFs to high-quality images with pdf2image, numbering every page sequentially.

    Args:
        input_folder: folder containing the PDF files
        output_folder: folder the converted images are written to
        dpi: resolution, 300 by default

    Returns:
        List of records with each image path and its source PDF information.
    """
    os.makedirs(output_folder, exist_ok=True)
    pdf_files = [f for f in os.listdir(input_folder) if f.lower().endswith('.pdf')]
    image_index = 1  # global sequential index
    image_info = []  # collected image records

    logger.info(f"Processing {len(pdf_files)} PDF file(s)...")

    for pdf_file in tqdm(sorted(pdf_files), desc="PDF to images"):
        pdf_path = os.path.join(input_folder, pdf_file)

        try:
            # Check whether a Poppler path is configured in the environment
            poppler_path = os.environ.get("POPPLER_PATH")
            if poppler_path and os.path.exists(poppler_path):
                # Use the configured Poppler path
                logger.info(f"Using Poppler path: {poppler_path}")
                logger.info(f"Processing PDF file: {pdf_path}")

                try:
                    images = convert_from_path(pdf_path, dpi=dpi, poppler_path=poppler_path)
                    logger.info(f"PDF converted successfully, {len(images)} page(s)")
                except Exception as e:
                    logger.error(f"Conversion with the configured Poppler path failed: {str(e)}")
                    # Retry without an explicit path
                    logger.info("Retrying conversion without an explicit Poppler path")
                    images = convert_from_path(pdf_path, dpi=dpi)
            else:
                # Fall back to whatever Poppler is on the system PATH
                logger.warning("No Poppler path found, trying the system default")
                images = convert_from_path(pdf_path, dpi=dpi)
                logger.info("Conversion with the system default Poppler succeeded")

            for i, img in enumerate(images):
                img_path = os.path.join(output_folder, f'{image_index}.png')
                img.save(img_path, 'PNG')

                # Record image metadata
                image_info.append({
                    'image_path': img_path,
                    'pdf_file': pdf_file,
                    'page_number': i + 1,
                    'image_index': image_index
                })

                image_index += 1
        except Exception as e:
            logger.error(f"Error while processing {pdf_file}: {str(e)}")
            # Log the full traceback for diagnosis
            import traceback
            logger.error(f"Traceback: {traceback.format_exc()}")

    logger.info(f"PDF-to-image conversion finished, {len(image_info)} image(s) written")
    return image_info

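# Sketch of the returned records (values are illustrative): a two-page "a.pdf" followed
# by a one-page "b.pdf" would yield
#   [{'image_path': 'output/images/1.png', 'pdf_file': 'a.pdf', 'page_number': 1, 'image_index': 1},
#    {'image_path': 'output/images/2.png', 'pdf_file': 'a.pdf', 'page_number': 2, 'image_index': 2},
#    {'image_path': 'output/images/3.png', 'pdf_file': 'b.pdf', 'page_number': 1, 'image_index': 3}]
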
def classify_images(image_folder, confidence_threshold=0.6):
    """
    Classify images with the YOLO model.

    Args:
        image_folder: folder containing the images
        confidence_threshold: confidence threshold

    Returns:
        List of classification results.
    """
    logger.info(f"Loading YOLO model: {YOLO_MODEL_PATH}")
    model = YOLO(YOLO_MODEL_PATH)

    # Force CPU inference
    model.to('cpu')

    image_files = [f for f in os.listdir(image_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    results = []

    logger.info(f"Classifying {len(image_files)} image(s)...")

    for image_file in tqdm(sorted(image_files, key=lambda x: int(os.path.splitext(x)[0])), desc="Image classification"):
        image_path = os.path.join(image_folder, image_file)
        try:
            result = model.predict(image_path, verbose=False)[0]

            class_id = int(result.probs.top1)
            class_name = result.names[class_id]
            confidence = float(result.probs.top1conf)

            # Derive the rotation angle and target flag from the predicted class
            rotation_angle = 0
            if class_name == 'yes_90':
                rotation_angle = -90  # rotate 90 degrees counter-clockwise
            elif class_name == 'yes_-90':
                rotation_angle = 90  # rotate 90 degrees clockwise

            # All three "yes" classes count as targets
            is_target = (class_name in ['yes', 'yes_90', 'yes_-90']) and confidence >= confidence_threshold

            results.append({
                'image_path': image_path,
                'image_name': image_file,
                'class_id': class_id,
                'class_name': class_name,
                'confidence': confidence,
                'rotation_angle': rotation_angle,
                'is_target': is_target
            })

            logger.info(f"Classified {image_file}: class={class_name}, confidence={confidence:.4f}, rotation={rotation_angle}")
        except Exception as e:
            logger.error(f"Error while processing image {image_file}: {str(e)}")
            continue

    logger.info(f"Image classification finished, {len(results)} image(s) classified")
    return results

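# Usage sketch (assumes the numbered PNGs produced by pdf_to_images_bulk): downstream
# code keeps only the entries flagged as targets, e.g.
#   results = classify_images("output/images", confidence_threshold=0.6)
#   targets = [r for r in results if r['is_target']]
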
def rotate_image(image_path, rotation_angle, output_folder=None):
    """
    Rotate an image by the given angle.

    Args:
        image_path: path of the original image
        rotation_angle: rotation angle (positive = clockwise, negative = counter-clockwise)
        output_folder: output folder; if None, the image's own folder is used

    Returns:
        Path of the rotated image.
    """
    # Nothing to do if no rotation is required
    if rotation_angle == 0:
        return image_path

    # Work out the output path
    image_name = os.path.basename(image_path)
    base_name, ext = os.path.splitext(image_name)

    if output_folder is None:
        output_folder = os.path.dirname(image_path)
    os.makedirs(output_folder, exist_ok=True)

    # Build the rotated image's file name
    rotated_image_path = os.path.join(output_folder, f"{base_name}_rotated{ext}")

    # Read the image
    img = cv2.imread(image_path)
    if img is None:
        logger.error(f"Could not read image: {image_path}")
        return image_path

    # Image dimensions
    height, width = img.shape[:2]

    # Compute the rotation
    if rotation_angle == 90 or rotation_angle == -90:
        # 90-degree rotations swap width and height
        if rotation_angle == 90:  # 90 degrees clockwise
            rotated_img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        else:  # 90 degrees counter-clockwise
            rotated_img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
    else:
        # Arbitrary angles (note: cv2.getRotationMatrix2D treats positive angles as
        # counter-clockwise, unlike the convention stated above; only ±90 is used in practice)
        center = (width // 2, height // 2)
        rotation_matrix = cv2.getRotationMatrix2D(center, rotation_angle, 1.0)

        # Compute the size of the rotated image
        abs_cos = abs(rotation_matrix[0, 0])
        abs_sin = abs(rotation_matrix[0, 1])
        bound_w = int(height * abs_sin + width * abs_cos)
        bound_h = int(height * abs_cos + width * abs_sin)

        # Shift the rotation matrix so the result stays centred
        rotation_matrix[0, 2] += bound_w / 2 - center[0]
        rotation_matrix[1, 2] += bound_h / 2 - center[1]

        # Apply the rotation
        rotated_img = cv2.warpAffine(img, rotation_matrix, (bound_w, bound_h),
                                     flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT,
                                     borderValue=(255, 255, 255))

    # Save the rotated image
    cv2.imwrite(rotated_image_path, rotated_img)
    logger.info(f"Image rotated by {rotation_angle} degrees and saved to: {rotated_image_path}")

    return rotated_image_path

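# Sketch (hypothetical file name): a page classified as 'yes_90' gets rotation_angle -90
# and is written as a sibling file rather than overwriting the original:
#   rotate_image("output/images/3.png", -90, "output/rotated_images")
#   # -> "output/rotated_images/3_rotated.png"
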
def detect_layout_and_ocr(image_path, layout_module, ocr_pipeline, output_dir=None):
    """
    Run layout detection and OCR on an image.

    Args:
        image_path: path of the image
        layout_module: preloaded layout-detection model
        ocr_pipeline: preloaded OCR pipeline
        output_dir: output directory; created automatically if None

    Returns:
        List of texts extracted from all target regions.
    """
    # Create the output directory
    if output_dir is None:
        output_dir = os.path.join("results", "layout_ocr", os.path.splitext(os.path.basename(image_path))[0])
    os.makedirs(output_dir, exist_ok=True)

    # Image file name without extension
    img_name = os.path.splitext(os.path.basename(image_path))[0]

    # Run layout detection
    results_generator = layout_module.predict(image_path, layout_nms=True)

    # Group detected regions by label
    all_regions = {'name': [], 'money': [], 'title': []}
    target_regions = []  # regions that will actually be OCRed (name and money)

    # Iterate over the detection results
    for result_count, result in enumerate(results_generator, 1):
        # Save the result as JSON
        json_path = os.path.join(output_dir, f"{img_name}_result_{result_count}.json")
        result.save_to_json(json_path)

        # Read the detections back from the JSON
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                result_data = json.load(f)

            # Extract target regions from the top-level 'boxes' field
            if 'boxes' in result_data:
                boxes = result_data['boxes']
                logger.debug(f"{len(boxes)} region(s) found in the JSON")

                # Walk over all detected regions
                for box in boxes:
                    label = box.get('label')
                    score = box.get('score')
                    coordinate = box.get('coordinate')

                    # Bucket the region by label
                    if label in all_regions:
                        all_regions[label].append({
                            'coordinate': coordinate,
                            'score': score
                        })

                    # name and money regions go to the target list
                    if label in ['name', 'money']:
                        target_regions.append({
                            'label': label,
                            'coordinate': coordinate,
                            'score': score
                        })
                        logger.debug(f"Found {label} region in JSON: {coordinate}, confidence: {score:.4f}")

            # Log per-label statistics
            for label, regions in all_regions.items():
                if regions:
                    logger.info(f"Detected {len(regions)} {label} region(s)")

            logger.info(f"{len(target_regions)} target region(s) (name + money) to OCR")

        except Exception as e:
            logger.error(f"Error while parsing the JSON result: {str(e)}")

    # Nothing more to do if no target regions were found
    if not target_regions:
        logger.info(f"Image {img_name} has no name or money regions")
        return []

    # Draw and save a visualisation of all detected regions
    try:
        img = cv2.imread(image_path)
        if img is not None:
            vis_img = img.copy()

            # One colour per label
            color_map = {
                'name': (0, 255, 0),   # green
                'money': (0, 0, 255),  # red
                'title': (255, 0, 0)   # blue
            }

            # Draw every region, including title (for reference)
            for label, regions in all_regions.items():
                for idx, region_info in enumerate(regions):
                    coordinate = region_info['coordinate']
                    x1, y1, x2, y2 = map(int, coordinate)
                    color = color_map.get(label, (128, 128, 128))  # grey by default
                    cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
                    cv2.putText(vis_img, f"{label} {idx+1}", (x1, y1 - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

            # Save the annotated image
            vis_img_path = os.path.join(output_dir, f"{img_name}_all_regions.jpg")
            cv2.imwrite(vis_img_path, vis_img)
    except Exception as e:
        logger.error(f"Error while generating the region visualisation: {str(e)}")

    # Run OCR on every target region
    all_region_texts = []  # texts of all target regions

    for idx, region_info in enumerate(target_regions):
        region_label = region_info['label']
        bbox = region_info['coordinate']
        logger.info(f"Processing {region_label} region {idx+1}...")
        region_texts = []  # texts of the current region

        try:
            # Read the original image and crop the region
            img = cv2.imread(image_path)
            if img is None:
                continue

            img_height, img_width = img.shape[:2]

            # Per-label crop margins: positive values shrink the box, negative values grow it
            # (mainly the horizontal margins are tightened)
            if region_label == 'name':
                # name regions: shrink left/right, extend the bottom slightly
                left_padding = 10    # shrink 10 px on the left
                right_padding = 10   # shrink 10 px on the right
                top_padding = 0      # keep the top unchanged
                bottom_padding = -7  # extend the bottom by 7 px
            elif region_label == 'money':
                # money regions: shrink left/right more aggressively, keep top/bottom unchanged
                left_padding = 20    # shrink 20 px on the left
                right_padding = 20   # shrink 20 px on the right
                top_padding = 0      # keep the top unchanged
                bottom_padding = 0   # keep the bottom unchanged
            else:
                # default margins for any other region: grow all sides by 7 px
                left_padding = -7
                right_padding = -7
                top_padding = -7
                bottom_padding = -7

            x1, y1, x2, y2 = map(int, bbox)

            # Apply the crop margins
            x1 = max(0, x1 + left_padding)
            y1 = max(0, y1 + top_padding)
            x2 = min(img_width, x2 - right_padding)
            y2 = min(img_height, y2 - bottom_padding)

            region_img = img[y1:y2, x1:x2].copy()
            region_img_path = os.path.join(output_dir, f"{img_name}_{region_label}_{idx+1}.jpg")
            cv2.imwrite(region_img_path, region_img)

            # Run OCR
            ocr_results = ocr_pipeline.predict(region_img_path)

            # Process the OCR results
            for res_idx, ocr_res in enumerate(ocr_results, 1):
                # Save the JSON result
                json_path = os.path.join(output_dir, f"{img_name}_{region_label}_{idx+1}_ocr_{res_idx}.json")
                ocr_res.save_to_json(save_path=json_path)

                # Read the JSON back and extract the texts
                try:
                    with open(json_path, 'r', encoding='utf-8') as f:
                        json_data = json.load(f)

                    # Top-level 'rec_texts' field
                    if 'rec_texts' in json_data and isinstance(json_data['rec_texts'], list):
                        rec_texts = json_data['rec_texts']
                        region_texts.extend(rec_texts)
                    # Otherwise the JSON may be a list of result items
                    elif isinstance(json_data, list):
                        for item in json_data:
                            if isinstance(item, dict):
                                if 'rec_texts' in item:
                                    # 'rec_texts' may be a list or a single string
                                    if isinstance(item['rec_texts'], list):
                                        region_texts.extend(item['rec_texts'])
                                    else:
                                        region_texts.append(item['rec_texts'])
                                elif 'text' in item:
                                    # Backwards compatibility with the older format
                                    region_texts.append(item['text'])

                    # If the JSON was empty or lacked the expected fields, fall back to the object
                    if not region_texts and hasattr(ocr_res, 'rec_texts'):
                        region_texts.extend(ocr_res.rec_texts)

                except Exception as e:
                    logger.error(f"Error while reading the OCR JSON result: {str(e)}")

                    # Try to pull the texts straight from the result object
                    try:
                        if hasattr(ocr_res, 'rec_texts'):
                            region_texts.extend(ocr_res.rec_texts)
                    except Exception:
                        logger.error("Could not extract texts from the result object")

            # Append the region label and its texts to the overall list
            all_region_texts.append({
                'label': region_label,
                'texts': region_texts
            })

            logger.info(f"Extracted {len(region_texts)} text item(s) from {region_label} region {idx+1}")

            # Save the texts of the current region
            txt_path = os.path.join(output_dir, f"{img_name}_{region_label}_{idx+1}_texts.txt")
            with open(txt_path, 'w', encoding='utf-8') as f:
                for text in region_texts:
                    f.write(f"{text}\n")

        except Exception as e:
            logger.error(f"Error while processing {region_label} region {idx+1}: {str(e)}")
            continue

    return all_region_texts

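# Sketch of the return value (illustrative contents):
#   [{'label': 'name', 'texts': ['张三', '李四']},
#    {'label': 'money', 'texts': ['1,234.', '56']}]
# One entry per detected name/money region, in detection order; money strings may still
# be split across items and are repaired later by post_process_money_texts().
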
def merge_region_texts(all_region_texts, output_file):
    """
    Merge the texts of all regions into a single, readable txt file.

    Args:
        all_region_texts: list of region texts; each entry has 'label' and 'texts' fields
        output_file: path of the output file
    """
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(f"Region recognition results - generated at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write("=" * 80 + "\n\n")

        # Write the output grouped by category
        for category in ['name', 'money']:
            category_regions = [r for r in all_region_texts if r['label'] == category]

            if category_regions:
                f.write(f"{category.upper()} regions ({len(category_regions)} total):\n")
                f.write("-" * 50 + "\n\n")

                for region_idx, region_data in enumerate(category_regions, 1):
                    texts = region_data['texts']
                    f.write(f"{category} {region_idx} ({len(texts)} item(s)):\n")

                    # Write the region's entries, numbered
                    for text_idx, text in enumerate(texts, 1):
                        f.write(f"  {text_idx}. {text}\n")

                    # Separator between regions (except after the last one)
                    if region_idx < len(category_regions):
                        f.write("\n" + "-" * 30 + "\n\n")

                # Separator between categories
                f.write("\n" + "=" * 50 + "\n\n")

    logger.info(f"All region texts merged and saved to: {output_file}")
    return output_file

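# Sketch of the resulting txt layout (timestamp and counts are illustrative):
#   Region recognition results - generated at: 2024-01-01 12:00:00
#   ================================================================================
#
#   NAME regions (1 total):
#   --------------------------------------------------
#
#   name 1 (2 item(s)):
#     1. ...
#     2. ...
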
def post_process_money_texts(money_texts):
    """
    Post-process the money text list, merging amounts that were split by the OCR.

    Args:
        money_texts: list of amount strings

    Returns:
        The processed list of amount strings.
    """
    if not money_texts:
        return money_texts

    # Build a new result list
    processed_texts = []
    skip_next = False

    for i in range(len(money_texts)):
        if skip_next:
            skip_next = False
            continue

        current_text = money_texts[i].strip()

        # Check whether the current text ends at a decimal point or has an incomplete decimal part
        if i < len(money_texts) - 1:
            next_text = money_texts[i + 1].strip()
            needs_merge = False

            # Case 1: the current amount ends with a decimal point
            if current_text.endswith('.'):
                needs_merge = True
            # Case 2: the current amount has a decimal point but fewer than two decimal digits
            elif '.' in current_text:
                decimal_part = current_text.split('.')[-1]
                if len(decimal_part) < 2:
                    needs_merge = True
            # Case 3: the next value is purely numeric and at most two digits long
            is_next_decimal = next_text.isdigit() and len(next_text) <= 2
            # Case 4: the next value starts with 0 and is purely numeric (likely a decimal part)
            is_next_leading_zero = next_text.startswith('0') and next_text.isdigit()
            # Case 5: the current text is an integer and the next text starts with a decimal point
            is_current_integer = current_text.isdigit() and next_text.startswith('.')

            if (needs_merge and (is_next_decimal or is_next_leading_zero)) or is_current_integer:
                # Merge the current amount with the next one
                merged_value = current_text + next_text
                processed_texts.append(merged_value)
                logger.info(f"Merged amounts: {current_text} + {next_text} = {merged_value}")
                skip_next = True
                continue

        # No merge happened, keep the current amount as-is
        processed_texts.append(current_text)

    # Normalise the amount formats
    formatted_texts = []
    for money in processed_texts:
        # Clean and standardise the amount
        cleaned_money = money.strip()

        # Strip everything except digits, decimal points and commas
        cleaned_money = re.sub(r'[^\d.,]', '', cleaned_money)

        # Add a decimal part if the amount has none
        if '.' not in cleaned_money and cleaned_money.isdigit():
            cleaned_money = cleaned_money + '.00'
            logger.info(f"Formatted amount: {money} -> {cleaned_money}")
        # Otherwise make sure there are exactly two decimal digits
        elif '.' in cleaned_money:
            parts = cleaned_money.split('.')
            if len(parts) == 2:
                integer_part, decimal_part = parts
                # Pad the decimal part with zeros if it is too short
                if len(decimal_part) < 2:
                    decimal_part = decimal_part.ljust(2, '0')
                # Truncate it if it is too long
                elif len(decimal_part) > 2:
                    decimal_part = decimal_part[:2]
                cleaned_money = integer_part + '.' + decimal_part
                logger.info(f"Normalised amount: {money} -> {cleaned_money}")

        # Drop any mis-recognised prefix before the first digit or decimal point
        first_digit_pos = -1
        for pos, char in enumerate(cleaned_money):
            if char.isdigit() or char == '.':
                first_digit_pos = pos
                break

        if first_digit_pos > 0:
            cleaned_money = cleaned_money[first_digit_pos:]
            logger.info(f"Removed amount prefix: {money} -> {cleaned_money}")

        # Keep only values that parse as numbers
        try:
            float(cleaned_money.replace(',', ''))
            formatted_texts.append(cleaned_money)
        except ValueError:
            logger.warning(f"Invalid amount format ignored: {money}")

    return formatted_texts

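# Worked example (hypothetical OCR output) of the merging and formatting above:
#   post_process_money_texts(['1,234.', '56', '789'])
#   # case 1 merges '1,234.' + '56' -> '1,234.56'; '789' gains a decimal part -> '789.00'
#   # result: ['1,234.56', '789.00']
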
def process_pdf_to_table(pdf_folder, output_folder, confidence_threshold=0.6):
    """
    End-to-end pipeline from PDFs to extracted region texts.

    Args:
        pdf_folder: folder containing the PDF files
        output_folder: output folder
        confidence_threshold: confidence threshold

    Returns:
        Statistics about the processing run.
    """
    # Create the output directories
    os.makedirs(output_folder, exist_ok=True)
    images_folder = os.path.join(output_folder, "images")
    results_folder = os.path.join(output_folder, "results")
    rotated_folder = os.path.join(output_folder, "rotated_images")
    os.makedirs(images_folder, exist_ok=True)
    os.makedirs(results_folder, exist_ok=True)
    os.makedirs(rotated_folder, exist_ok=True)

    # 1. Convert PDFs to images
    logger.info("Step 1: PDF to images")
    image_records = pdf_to_images_bulk(pdf_folder, images_folder)

    # 2. Classify the images
    logger.info("Step 2: image classification")
    classification_results = classify_images(images_folder, confidence_threshold)

    # 3. Keep only the target images
    target_images = [result for result in classification_results if result['is_target']]
    logger.info(f"Selected {len(target_images)} target image(s)")

    # Save the classification results
    classification_csv = os.path.join(results_folder, "classification_results.csv")
    pd.DataFrame(classification_results).to_csv(classification_csv, index=False, encoding='utf-8-sig')
    logger.info(f"Classification results saved to: {classification_csv}")

    # 4. Rotate images where needed
    logger.info("Step 4: image rotation")
    for target_info in target_images:
        if target_info['rotation_angle'] != 0:
            original_path = target_info['image_path']
            rotated_path = rotate_image(
                original_path,
                target_info['rotation_angle'],
                rotated_folder
            )
            target_info['original_path'] = original_path
            target_info['image_path'] = rotated_path
            logger.info(f"Image {target_info['image_name']} rotated by {target_info['rotation_angle']} degrees")

    # Create the layout-detection model
    logger.info("Creating the layout-detection model")
    try:
        layout_module = pdx.create_model('PP-DocLayout-L', LAYOUT_MODEL_DIR)
        logger.info(f"Using local layout-detection model: {LAYOUT_MODEL_DIR}")
    except Exception as e:
        logger.error(f"Failed to load the local layout-detection model: {str(e)}")
        raise

    # Create the OCR pipeline
    logger.info("Creating the OCR pipeline")
    try:
        ocr_pipeline = pdx.create_pipeline(pipeline=OCR_CONFIG_PATH)
    except Exception as e:
        logger.error(f"Failed to create the OCR pipeline: {str(e)}")
        raise

    # 5. Layout detection and OCR on every target image
    logger.info("Step 5: layout detection and OCR")
    processing_results = []

    for target_info in tqdm(target_images, desc="Processing target images"):
        image_path = target_info['image_path']
        image_name = target_info['image_name']
        rotation_angle = target_info.get('rotation_angle', 0)

        # Per-image output directory
        image_output_dir = os.path.join(results_folder, os.path.splitext(image_name)[0])

        # Run layout detection and OCR
        region_texts = detect_layout_and_ocr(image_path, layout_module, ocr_pipeline, image_output_dir)

        # If target regions were found, merge and save the results
        if region_texts:
            # Post-process money texts, merging amounts that were split by the OCR
            for i, region in enumerate(region_texts):
                if region['label'] == 'money':
                    region_texts[i]['texts'] = post_process_money_texts(region['texts'])
                    logger.info(f"Amounts after post-processing: {len(region_texts[i]['texts'])}")

            merged_txt_path = os.path.join(results_folder, f"{os.path.splitext(image_name)[0]}_regions.txt")
            merge_region_texts(region_texts, merged_txt_path)

            # Count the regions per category
            name_regions = sum(1 for r in region_texts if r['label'] == 'name')
            money_regions = sum(1 for r in region_texts if r['label'] == 'money')
            total_texts = sum(len(r['texts']) for r in region_texts)

            processing_results.append({
                'image_path': image_path,
                'image_name': image_name,
                'rotation_angle': rotation_angle,
                'name_regions': name_regions,
                'money_regions': money_regions,
                'total_regions': len(region_texts),
                'total_texts': total_texts,
                'output_file': merged_txt_path
            })
        else:
            logger.info(f"Image {image_name} has no target regions")
            processing_results.append({
                'image_path': image_path,
                'image_name': image_name,
                'rotation_angle': rotation_angle,
                'name_regions': 0,
                'money_regions': 0,
                'total_regions': 0,
                'total_texts': 0,
                'output_file': None
            })

    # Save the processing results
    processing_csv = os.path.join(results_folder, "processing_results.csv")
    pd.DataFrame(processing_results).to_csv(processing_csv, index=False, encoding='utf-8-sig')
    logger.info(f"Processing results saved to: {processing_csv}")

    # Summary statistics
    stats = {
        'total_images': len(image_records),
        'target_images': len(target_images),
        'rotated_images': sum(1 for img in target_images if img.get('rotation_angle', 0) != 0),
        'successful_extractions': sum(1 for r in processing_results if r['total_regions'] > 0),
        'total_name_regions': sum(r['name_regions'] for r in processing_results),
        'total_money_regions': sum(r['money_regions'] for r in processing_results),
        'total_text_elements': sum(r['total_texts'] for r in processing_results)
    }

    logger.info("Processing summary")
    logger.info(f"Total images: {stats['total_images']}")
    logger.info(f"Target images: {stats['target_images']}")
    logger.info(f"Rotated images: {stats['rotated_images']}")
    logger.info(f"Images with successfully extracted regions: {stats['successful_extractions']}")
    logger.info(f"Total name regions: {stats['total_name_regions']}")
    logger.info(f"Total money regions: {stats['total_money_regions']}")
    logger.info(f"Total text elements: {stats['total_text_elements']}")

    return stats

if __name__ == "__main__":
    # Example usage
    try:
        stats = process_pdf_to_table(
            pdf_folder=r"E:\ocr_cpu\pdf",    # input PDF folder
            output_folder="output",          # output folder
            confidence_threshold=0.6         # confidence threshold
        )
        print("Processing finished!")
        print(f"Target images: {stats['target_images']}")
        print(f"Images with successfully extracted regions: {stats['successful_extractions']}")
    except Exception as e:
        logger.error(f"An error occurred during processing: {str(e)}", exc_info=True)
        sys.exit(1)