Upload files to /
This commit is contained in:
parent d945a732ed
commit 1ebdd0845b

@@ -0,0 +1,698 @@
import os
import sys
import json
import logging
import time
import glob
import re
import argparse
import shutil
import traceback
from datetime import datetime
from pdf_to_table import process_pdf_to_table

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

def parse_regions_file(txt_file):
    """Extract name and money entries from a regions text file."""
    name_data = []
    money_data = []
    current_section = None

    try:
        logger.info(f"Parsing file: {txt_file}")

        # Read the whole file at once to cut down on I/O
        with open(txt_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        # Pre-compile the item pattern. Note: lines are stripped below, so the
        # pattern must not require leading whitespace (the original r'^\s+(\d+)...'
        # could never match a stripped line).
        item_pattern = re.compile(r'^(\d+)\.\s+(.+)$')

        # Single pass over the lines, tracking the current section
        for line in lines:
            line = line.strip()
            if not line:
                continue

            if 'NAME 区域' in line:
                current_section = 'name'
                continue
            elif 'MONEY 区域' in line:
                current_section = 'money'
                continue
            elif '===' in line:  # section separator: reset the current section
                current_section = None
                continue

            if current_section:
                match = item_pattern.match(line)
                if match:
                    content = match.group(2).strip()
                    if current_section == 'name':
                        name_data.append(content)
                    elif current_section == 'money':
                        money_data.append(content)

        # Log the outcome once, after the whole file has been processed
        logger.info(f"Parsed {txt_file}: {len(name_data)} names, {len(money_data)} amounts")

        # Log a few examples only when there is actual content
        if name_data:
            logger.info(f"Name examples: {name_data[:3]}")
        if money_data:
            logger.info(f"Amount examples: {money_data[:3]}")

    except Exception as e:
        logger.error(f"Error while parsing result file: {str(e)}")
        logger.error(f"Details: {traceback.format_exc()}")

    return {
        'name': name_data,
        'money': money_data
    }
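
# For reference, a minimal sketch of the regions file that parse_regions_file
# expects, reconstructed from the section markers and the item regex above
# (the exact wording around the 'NAME 区域' / 'MONEY 区域' headers may differ
# in the real OCR output):
#
#   === regions of image 5 ===     <- any line containing "===" resets the section
#   NAME 区域:
#   1. 张三
#   2. 李四
#   MONEY 区域:
#   1. 6500.00
#   2. 980.50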

def extract_texts_from_directory(directory, image_prefix):
    """
    Extract name and money texts directly from the text files in a directory.
    Fallback for parse_regions_file.

    Args:
        directory: path of the directory containing the text files
        image_prefix: image prefix (e.g. "5" selects the 5_name_*.txt and 5_money_*.txt files)

    Returns:
        dict with 'name' and 'money' lists
    """
    name_data = []
    money_data = []

    try:
        # Look up matching files directly with glob patterns
        name_patterns = [
            os.path.join(directory, f"{image_prefix}_name_*_texts.txt"),
            os.path.join(directory, f"{image_prefix}_rotated_name_*_texts.txt")
        ]
        money_patterns = [
            os.path.join(directory, f"{image_prefix}_money_*_texts.txt"),
            os.path.join(directory, f"{image_prefix}_rotated_money_*_texts.txt")
        ]

        name_files = []
        for pattern in name_patterns:
            name_files.extend(glob.glob(pattern))

        money_files = []
        for pattern in money_patterns:
            money_files.extend(glob.glob(pattern))

        # If nothing is found in the directory itself, check its per-image subdirectory
        if not name_files and not money_files:
            image_dir = os.path.join(directory, image_prefix)
            if os.path.isdir(image_dir):
                # Again support both the rotated and the non-rotated file name formats
                name_patterns = [
                    os.path.join(image_dir, f"{image_prefix}_name_*_texts.txt"),
                    os.path.join(image_dir, f"{image_prefix}_rotated_name_*_texts.txt")
                ]
                money_patterns = [
                    os.path.join(image_dir, f"{image_prefix}_money_*_texts.txt"),
                    os.path.join(image_dir, f"{image_prefix}_rotated_money_*_texts.txt")
                ]

                name_files = []
                for pattern in name_patterns:
                    name_files.extend(glob.glob(pattern))

                money_files = []
                for pattern in money_patterns:
                    money_files.extend(glob.glob(pattern))

        # Sort key: the numeric index embedded in the file name
        def extract_index(file_path):
            try:
                parts = os.path.basename(file_path).split('_')
                # Rotated files carry the index one position further to the right
                if 'rotated' in parts:
                    idx_pos = 3 if len(parts) >= 4 else 0
                else:
                    idx_pos = 2 if len(parts) >= 3 else 0
                return int(parts[idx_pos]) if idx_pos < len(parts) else 0
            except (IndexError, ValueError):
                return float('inf')

        name_files.sort(key=extract_index)
        money_files.sort(key=extract_index)

        # Log only when files were actually found
        if name_files or money_files:
            logger.info(f"Found {len(name_files)} name text files and {len(money_files)} money text files")

        for file_path in name_files:
            with open(file_path, 'r', encoding='utf-8') as f:
                # Read and filter in one pass
                lines = [line.strip() for line in f if line.strip()]
                name_data.extend(lines)
                logger.info(f"Extracted {len(lines)} names from {os.path.basename(file_path)}")

        for file_path in money_files:
            with open(file_path, 'r', encoding='utf-8') as f:
                lines = [line.strip() for line in f if line.strip()]
                money_data.extend(lines)
                logger.info(f"Extracted {len(lines)} amounts from {os.path.basename(file_path)}")

        if name_data or money_data:
            logger.info(f"Extracted directly from text files: {len(name_data)} names, {len(money_data)} amounts")

    except Exception as e:
        logger.error(f"Error while extracting texts from the directory: {str(e)}")
        logger.error(f"Details: {traceback.format_exc()}")

    return {
        'name': name_data,
        'money': money_data
    }
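
# A sketch of the per-image file layout this fallback expects, inferred from
# the glob patterns and the index-based sort above (file names illustrative):
#
#   results/5_name_1_texts.txt          names of region 1 of image 5 (index at split position 2)
#   results/5_money_1_texts.txt         amounts of region 1 of image 5
#   results/5_rotated_name_2_texts.txt  rotated variant (index at split position 3)
#   results/5/5_name_1_texts.txt        same layout inside the per-image subdirectory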

def process_pdf(pdf_path, output_dir, confidence_threshold=0.6):
    """
    Process a single PDF file.

    Args:
        pdf_path: path to the PDF file
        output_dir: output directory
        confidence_threshold: confidence threshold

    Returns:
        dict with the processing result
    """
    try:
        # Create the output directory
        os.makedirs(output_dir, exist_ok=True)

        # Derive the PDF file name and stem
        pdf_filename = os.path.basename(pdf_path)
        pdf_name = os.path.splitext(pdf_filename)[0]

        logger.info(f"Processing PDF: {pdf_path}")

        # PDF file size in MB
        file_size = os.path.getsize(pdf_path) / (1024 * 1024)
        logger.info(f"File size: {file_size:.2f} MB")

        # Temporary directory that holds only this PDF
        pdf_folder = os.path.join(output_dir, "temp_pdf")
        os.makedirs(pdf_folder, exist_ok=True)

        # Copy the PDF into the temporary directory
        temp_pdf_path = os.path.join(pdf_folder, pdf_filename)
        shutil.copy2(pdf_path, temp_pdf_path)

        # Record the start time
        start_time = time.time()

        # Run the OCR pipeline
        stats = process_pdf_to_table(
            pdf_folder=pdf_folder,
            output_folder=output_dir,
            confidence_threshold=confidence_threshold
        )

        # Measure the processing time
        processing_time = time.time() - start_time
        logger.info(f"Processing took {processing_time:.2f} s")

        # Record performance data
        stats['performance'] = {
            'file_size_mb': round(file_size, 2),
            'processing_time_seconds': round(processing_time, 2),
            'processing_speed_mb_per_second': round(file_size / processing_time, 4) if processing_time > 0 else 0
        }

        # Parse the results
        all_results = []
        results_folder = os.path.join(output_dir, "results")

        if os.path.exists(results_folder):
            # Find all *_regions.txt files quickly with glob
            region_files = glob.glob(os.path.join(results_folder, "*_regions.txt"))

            # Sort by the leading number in the file name
            def extract_number(file_path):
                try:
                    filename = os.path.basename(file_path)
                    number_str = filename.split('_')[0]
                    return int(number_str)
                except (IndexError, ValueError):
                    return float('inf')

            region_files.sort(key=extract_number)

            if region_files:
                logger.info(f"Found {len(region_files)} region files to parse")

                # Process the region files sequentially
                for region_file in region_files:
                    image_prefix = os.path.basename(region_file).split('_')[0]
                    result = parse_regions_file(region_file)

                    # Only record the result when something was parsed
                    if result['name'] or result['money']:
                        # Post-process and match the data of the current image
                        result = process_single_image_data(result, image_prefix)
                        all_results.append(result)
                        logger.info(f"Parsed content from {image_prefix} and matched it")
                    else:
                        # Fallback: read directly from the texts.txt files
                        backup_result = extract_texts_from_directory(results_folder, image_prefix)
                        if backup_result['name'] or backup_result['money']:
                            # Match the fallback data as well
                            backup_result = process_single_image_data(backup_result, image_prefix)
                            all_results.append(backup_result)
                            logger.info(f"Fallback extraction from {image_prefix} succeeded and was matched")
            else:
                logger.warning("No regions.txt files found")

        # If nothing came out of the regions.txt files, try the subdirectories directly
        if not all_results:
            logger.warning("No content parsed from regions.txt files, extracting from subdirectories instead")

            # Collect and sort all subdirectories
            subdirs = []
            if os.path.exists(results_folder):
                subdirs = [d for d in os.listdir(results_folder)
                           if os.path.isdir(os.path.join(results_folder, d))]

            # Sort subdirectories by their numeric name
            def extract_dir_number(dirname):
                try:
                    return int(dirname)
                except ValueError:
                    return float('inf')

            subdirs.sort(key=extract_dir_number)

            if subdirs:
                logger.info(f"Found {len(subdirs)} subdirectories to extract from")

                # Process the subdirectories sequentially
                ordered_results = []
                for dirname in subdirs:
                    result_dir = os.path.join(results_folder, dirname)
                    backup_result = extract_texts_from_directory(result_dir, dirname)
                    if backup_result['name'] or backup_result['money']:
                        # Match the subdirectory data as well
                        backup_result = process_single_image_data(backup_result, dirname)
                        ordered_results.append((int(dirname) if dirname.isdigit() else float('inf'), backup_result))

                # Append the results in subdirectory-number order
                ordered_results.sort(key=lambda x: x[0])
                for _, result in ordered_results:
                    all_results.append(result)

        # Merge all results
        merged_result = {'name': [], 'money': [], 'name_money': []}

        if all_results:
            # Merge the per-image results while preserving the per-image matching
            for result in all_results:
                merged_result['name'].extend(result.get('name', []))
                merged_result['money'].extend(result.get('money', []))
                merged_result['name_money'].extend(result.get('name_money', []))

        # Note: each image has already been matched individually, so no global matching is needed here

        logger.info(f"PDF {pdf_filename} done: {len(merged_result['name'])} names, {len(merged_result['money'])} amounts, {len(merged_result['name_money'])} name-amount pairs")

        # Save the final result as JSON
        result_json_path = os.path.join(output_dir, f"{pdf_name}_result.json")
        with open(result_json_path, 'w', encoding='utf-8') as f:
            json.dump({
                'pdf_file': pdf_filename,
                'name': merged_result['name'],
                'money': merged_result['money'],
                'name_money': merged_result['name_money'],
                'stats': stats
            }, f, ensure_ascii=False, indent=2)
        logger.info(f"JSON result saved to: {result_json_path}")

        # Save the final result as TXT
        result_txt_path = os.path.join(output_dir, f"{pdf_name}_result.txt")
        with open(result_txt_path, 'w', encoding='utf-8') as f:
            f.write(f"Results for PDF file: {pdf_filename}\n")
            f.write(f"Processed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write("=" * 80 + "\n\n")

            # Write the name-money pairs
            f.write(f"Found {len(merged_result['name_money'])} name-amount pairs:\n\n")
            for i, pair in enumerate(merged_result['name_money'], 1):
                f.write(f"{i}. {pair}\n")

            f.write("\n" + "=" * 80 + "\n")
        logger.info(f"TXT result saved to: {result_txt_path}")

        # Clean up the temporary directory
        try:
            shutil.rmtree(pdf_folder)
            logger.info("Temporary PDF directory removed")
        except Exception as e:
            logger.warning(f"Error while cleaning up the temporary directory: {str(e)}")

        return {
            'pdf_file': pdf_filename,
            'result': merged_result,
            'stats': stats
        }

    except Exception as e:
        logger.error(f"Error while processing PDF {pdf_path}: {str(e)}")
        logger.error(f"Details: {traceback.format_exc()}")
        return {
            'pdf_file': os.path.basename(pdf_path),
            'error': str(e)
        }

def process_single_image_data(result, image_prefix):
    """
    Post-process the names and amounts of a single image and pair them up.

    Args:
        result: dict with 'name' and 'money' lists
        image_prefix: identifier of the image

    Returns:
        processed dict that additionally contains the matched name_money pairs
    """
    logger.info(f"Matching data for image {image_prefix}")

    # Work on copies so the caller's data stays untouched
    processed_result = {
        'name': result.get('name', []).copy(),
        'money': result.get('money', []).copy(),
        'name_money': []
    }

    # Post-process the amounts: merge amounts that OCR split apart. The import
    # is deferred so the built-in fallback below can take over when the helper
    # from pdf_to_table is unavailable.
    try:
        from pdf_to_table import post_process_money_texts
        processed_result['money'] = post_process_money_texts(processed_result['money'])
        logger.info(f"Image {image_prefix}: {len(processed_result['money'])} amounts after post-processing")
    except ImportError:
        logger.warning("post_process_money_texts could not be imported, using the built-in amount merging logic")
        # Built-in logic: merge amounts that were split apart
        processed_money = []
        i = 0
        while i < len(processed_result['money']):
            current = processed_result['money'][i].strip()

            # Check whether the current amount needs to be merged with the next one
            if i < len(processed_result['money']) - 1:
                next_money = processed_result['money'][i+1].strip()

                # Case 1: the current amount ends with a decimal point
                if current.endswith('.'):
                    merged_value = current + next_money
                    processed_money.append(merged_value)
                    logger.info(f"Merged amount: {current} + {next_money} = {merged_value}")
                    i += 2  # skip the next token, it has been merged
                    continue

                # Case 2: the current amount has a decimal point but fewer than two decimal digits
                elif '.' in current:
                    decimal_part = current.split('.')[-1]
                    if len(decimal_part) < 2:
                        merged_value = current + next_money
                        processed_money.append(merged_value)
                        logger.info(f"Merged decimal amount: {current} + {next_money} = {merged_value}")
                        i += 2
                        continue

                # Case 3: the next token is a bare number with at most two digits
                # (probably a decimal part whose point was lost entirely)
                elif next_money.isdigit() and len(next_money) <= 2:
                    if current.isdigit():
                        # Reinsert the decimal point that OCR dropped between the tokens
                        merged_value = current + '.' + next_money
                        processed_money.append(merged_value)
                        logger.info(f"Merged probable decimal part: {current} + {next_money} = {merged_value}")
                        i += 2
                        continue

                # Case 4: a detached decimal part (e.g. "6500" and ".00" split apart)
                elif '.' not in current and next_money.startswith('.') and next_money[1:].isdigit():
                    merged_value = current + next_money
                    processed_money.append(merged_value)
                    logger.info(f"Merged detached decimal point: {current} + {next_money} = {merged_value}")
                    i += 2
                    continue

            # No merge applied: keep the current amount as-is
            processed_money.append(current)
            i += 1

        # Store the merged amount list
        processed_result['money'] = processed_money

    # Format and validate the amounts
    validated_money = []
    for money in processed_result['money']:
        # Clean and normalize the amount format
        cleaned_money = money.strip()

        # Drop stray non-numeric characters (keep digits, decimal points and commas)
        cleaned_money = re.sub(r'[^\d.,]', '', cleaned_money)

        # Append a decimal part when there is none
        if '.' not in cleaned_money and cleaned_money.isdigit():
            cleaned_money = cleaned_money + '.00'
            logger.info(f"Formatted amount: {money} -> {cleaned_money}")
        # Otherwise make sure there are exactly two digits after the decimal point
        elif '.' in cleaned_money:
            parts = cleaned_money.split('.')
            if len(parts) == 2:
                integer_part, decimal_part = parts
                # Pad short decimal parts with zeros
                if len(decimal_part) < 2:
                    decimal_part = decimal_part.ljust(2, '0')
                # Truncate long decimal parts (rounding would also be defensible)
                elif len(decimal_part) > 2:
                    decimal_part = decimal_part[:2]
                cleaned_money = integer_part + '.' + decimal_part
                logger.info(f"Normalized amount: {money} -> {cleaned_money}")

        # Keep only amounts that parse as plausible numbers
        try:
            float_value = float(cleaned_money.replace(',', ''))
            if float_value < 0.01:
                logger.warning(f"Ignoring abnormally small amount: {cleaned_money}")
                continue
            validated_money.append(cleaned_money)
        except ValueError:
            logger.warning(f"Ignoring invalid amount format: {cleaned_money}")

    # Keep only the validated amounts
    processed_result['money'] = validated_money

    # Handle mismatched name/amount counts
    max_name_len = len(processed_result['name'])
    max_money_len = len(processed_result['money'])

    logger.info(f"Image {image_prefix} before pairing: {max_name_len} names, {max_money_len} amounts")

    if max_name_len != max_money_len:
        logger.warning(f"Image {image_prefix}: name/amount counts differ: {max_name_len} names vs {max_money_len} amounts")

    # Clean names that contain obviously misrecognized characters
    cleaned_names = []
    for name in processed_result['name']:
        # Strip digits, Latin letters and punctuation that OCR tends to hallucinate into names
        clean_name = re.sub(r'[0-9a-zA-Z.,:;!@#$%^&*()_+={}\[\]|\\/<>?~`-]', '', name)
        if clean_name != name:
            logger.info(f"Cleaned misrecognized characters from name: {name} -> {clean_name}")

        # Keep the name only when something is left after cleaning
        if clean_name.strip():
            cleaned_names.append(clean_name.strip())
        else:
            logger.warning(f"Name '{name}' became empty after cleaning and was dropped")

    # Store the cleaned names
    processed_result['name'] = cleaned_names

    # Re-check the counts
    max_name_len = len(processed_result['name'])
    logger.info(f"After cleaning: {max_name_len} names, {max_money_len} amounts")

    # Build the final pairs. Truncating to the shorter list keeps the data
    # from shifting out of alignment.
    match_len = min(max_name_len, max_money_len)

    if match_len > 0:
        for i in range(match_len):
            name = processed_result['name'][i]
            money = processed_result['money'][i]
            pair = f"{name}-{money}"
            processed_result['name_money'].append(pair)
            logger.info(f"Image {image_prefix} pair #{i+1}: {pair}")

    # Log leftover names or amounts instead of force-matching them
    if max_name_len > match_len:
        extra_names = processed_result['name'][match_len:]
        logger.warning(f"Image {image_prefix} has {len(extra_names)} unmatched names: {extra_names}")

    if max_money_len > match_len:
        extra_money = processed_result['money'][match_len:]
        logger.warning(f"Image {image_prefix} has {len(extra_money)} unmatched amounts: {extra_money}")

    logger.info(f"Image {image_prefix} done: {len(processed_result['name_money'])} name-amount pairs")

    return processed_result
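
# A worked example of the built-in merging and the validation above, with
# illustrative OCR tokens:
#
#   money before merging:  ["6500.", "00", "1200", ".50", "abc", "980.5"]
#   after merging:         ["6500.00", "1200.50", "abc", "980.5"]   (cases 1 and 4)
#   after validation:      ["6500.00", "1200.50", "980.50"]         ("abc" is cleaned to "" and
#                                                                    dropped, "980.5" is padded)
#   pairing then truncates to min(#names, #amounts) and emits "name-money" strings.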

def process_pdfs_in_directory(input_dir, output_dir, confidence_threshold=0.6):
    """
    Process every PDF file in a directory.

    Args:
        input_dir: input directory containing the PDF files
        output_dir: output directory
        confidence_threshold: confidence threshold

    Returns:
        summary statistics of the run
    """
    # Make sure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Collect all PDF files
    pdf_files = [f for f in os.listdir(input_dir) if f.lower().endswith('.pdf')]

    if not pdf_files:
        logger.warning(f"No PDF files found in {input_dir}")
        return {'processed': 0, 'success': 0, 'failed': 0, 'files': []}

    logger.info(f"Found {len(pdf_files)} PDF files in {input_dir}")

    # Sort by file name for a deterministic processing order
    pdf_files.sort()

    results = []
    success_count = 0
    failed_count = 0

    # Process the PDFs one after another
    for pdf_file in pdf_files:
        pdf_path = os.path.join(input_dir, pdf_file)
        pdf_name = os.path.splitext(pdf_file)[0]

        # Each PDF gets its own output directory
        pdf_output_dir = os.path.join(output_dir, pdf_name)

        logger.info(f"Processing {pdf_file} ({len(results)+1}/{len(pdf_files)})")

        # Process the PDF
        result = process_pdf(pdf_path, pdf_output_dir, confidence_threshold)

        if 'error' in result:
            logger.error(f"Processing {pdf_file} failed: {result['error']}")
            failed_count += 1
        else:
            logger.info(f"Successfully processed {pdf_file}")
            success_count += 1

        results.append(result)

    # Build the summary report
    summary = {
        'processed': len(pdf_files),
        'success': success_count,
        'failed': failed_count,
        'files': [r['pdf_file'] for r in results]
    }

    # Save the summary report
    summary_path = os.path.join(output_dir, "processing_summary.json")
    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    logger.info(f"Done: {len(pdf_files)} PDFs processed, {success_count} succeeded, {failed_count} failed")
    logger.info(f"Summary report saved to: {summary_path}")

    return summary

def main():
    """Entry point."""
    parser = argparse.ArgumentParser(description='OCR tool: process PDFs and extract names and amounts')

    # Arguments
    parser.add_argument('--input', '-i', required=True, help='path of a PDF file or of a directory containing PDF files')
    parser.add_argument('--output', '-o', required=True, help='output directory')
    parser.add_argument('--confidence', '-c', type=float, default=0.6, help='confidence threshold, default 0.6')

    args = parser.parse_args()

    # Record the start time
    start_time = time.time()

    # Validate the input path
    if not os.path.exists(args.input):
        logger.error(f"Input path does not exist: {args.input}")
        return 1

    # Create the output directory
    os.makedirs(args.output, exist_ok=True)

    # Process either a single file or a directory
    if os.path.isfile(args.input):
        if not args.input.lower().endswith('.pdf'):
            logger.error(f"Input file is not a PDF: {args.input}")
            return 1

        logger.info(f"Processing a single PDF file: {args.input}")
        result = process_pdf(args.input, args.output, args.confidence)

        if 'error' in result:
            logger.error(f"Processing failed: {result['error']}")
            return 1
        else:
            logger.info("Processing succeeded")
    else:
        logger.info(f"Processing the PDF files in directory: {args.input}")
        summary = process_pdfs_in_directory(args.input, args.output, args.confidence)

        if summary['success'] == 0:
            logger.warning("No PDF file was processed successfully")
            if summary['processed'] > 0:
                return 1

    # Report the total processing time
    total_time = time.time() - start_time
    logger.info(f"Total processing time: {total_time:.2f} s")

    return 0

if __name__ == "__main__":
    try:
        sys.exit(main())
    except Exception as e:
        logger.error(f"Unhandled error during execution: {str(e)}")
        logger.error(f"Details: {traceback.format_exc()}")
        sys.exit(1)
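
# Typical invocations, assuming the script is saved as batch_ocr.py (the file
# name is not part of this diff):
#
#   python batch_ocr.py --input ./invoice.pdf --output ./out
#   python batch_ocr.py --input ./pdf_dir --output ./out --confidence 0.7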

@@ -0,0 +1,803 @@
import os
import sys
import json
import logging
import time
import glob
import re
import argparse
import shutil
import traceback
import concurrent.futures
from datetime import datetime
from pdf_to_table import process_pdf_to_table, post_process_money_texts

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Thread count used for batch processing
DEFAULT_MAX_WORKERS = max(4, os.cpu_count() or 4)

# An environment variable controls how chatty the logging is
VERBOSE_LOGGING = os.environ.get('OCR_VERBOSE_LOGGING', '0') == '1'

def log_verbose(message):
    """Log the message only when verbose logging is enabled."""
    if VERBOSE_LOGGING:
        logger.info(message)
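
# Verbose logging can be enabled either through the environment or, as wired
# up in main() below, with the --verbose flag; for example (script name assumed):
#
#   OCR_VERBOSE_LOGGING=1 python batch_ocr_parallel.py --input ./pdf_dir --output ./out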

def parse_regions_file(txt_file):
    """Extract name and money entries from a regions text file."""
    name_data = []
    money_data = []
    current_section = None

    try:
        logger.info(f"Parsing file: {txt_file}")

        # Read the whole file at once to cut down on I/O
        with open(txt_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        # Pre-compile the item pattern. Note: lines are stripped below, so the
        # pattern must not require leading whitespace (the original r'^\s+(\d+)...'
        # could never match a stripped line).
        item_pattern = re.compile(r'^(\d+)\.\s+(.+)$')

        # Single pass over the lines, tracking the current section
        for line in lines:
            line = line.strip()
            if not line:
                continue

            if 'NAME 区域' in line:
                current_section = 'name'
                continue
            elif 'MONEY 区域' in line:
                current_section = 'money'
                continue
            elif '===' in line:  # section separator: reset the current section
                current_section = None
                continue

            if current_section:
                match = item_pattern.match(line)
                if match:
                    content = match.group(2).strip()
                    if current_section == 'name':
                        name_data.append(content)
                    elif current_section == 'money':
                        money_data.append(content)

        # Log the outcome once, after the whole file has been processed
        logger.info(f"Parsed {txt_file}: {len(name_data)} names, {len(money_data)} amounts")

        # Log a few examples only when there is actual content
        if name_data:
            logger.info(f"Name examples: {name_data[:3]}")
        if money_data:
            logger.info(f"Amount examples: {money_data[:3]}")

    except Exception as e:
        logger.error(f"Error while parsing result file: {str(e)}")
        logger.error(f"Details: {traceback.format_exc()}")

    return {
        'name': name_data,
        'money': money_data
    }

def extract_texts_from_directory(directory, image_prefix):
    """
    Extract name and money texts directly from the text files in a directory.
    Fallback for parse_regions_file.

    Args:
        directory: path of the directory containing the text files
        image_prefix: image prefix (e.g. "5" selects the 5_name_*.txt and 5_money_*.txt files)

    Returns:
        dict with 'name' and 'money' lists
    """
    name_data = []
    money_data = []

    try:
        # Look up matching files directly with glob patterns
        name_patterns = [
            os.path.join(directory, f"{image_prefix}_name_*_texts.txt"),
            os.path.join(directory, f"{image_prefix}_rotated_name_*_texts.txt")
        ]
        money_patterns = [
            os.path.join(directory, f"{image_prefix}_money_*_texts.txt"),
            os.path.join(directory, f"{image_prefix}_rotated_money_*_texts.txt")
        ]

        name_files = []
        for pattern in name_patterns:
            name_files.extend(glob.glob(pattern))

        money_files = []
        for pattern in money_patterns:
            money_files.extend(glob.glob(pattern))

        # If nothing is found in the directory itself, check its per-image subdirectory
        if not name_files and not money_files:
            image_dir = os.path.join(directory, image_prefix)
            if os.path.isdir(image_dir):
                # Again support both the rotated and the non-rotated file name formats
                name_patterns = [
                    os.path.join(image_dir, f"{image_prefix}_name_*_texts.txt"),
                    os.path.join(image_dir, f"{image_prefix}_rotated_name_*_texts.txt")
                ]
                money_patterns = [
                    os.path.join(image_dir, f"{image_prefix}_money_*_texts.txt"),
                    os.path.join(image_dir, f"{image_prefix}_rotated_money_*_texts.txt")
                ]

                name_files = []
                for pattern in name_patterns:
                    name_files.extend(glob.glob(pattern))

                money_files = []
                for pattern in money_patterns:
                    money_files.extend(glob.glob(pattern))

        # Sort key: the numeric index embedded in the file name
        def extract_index(file_path):
            try:
                parts = os.path.basename(file_path).split('_')
                # Rotated files carry the index one position further to the right
                if 'rotated' in parts:
                    idx_pos = 3 if len(parts) >= 4 else 0
                else:
                    idx_pos = 2 if len(parts) >= 3 else 0
                return int(parts[idx_pos]) if idx_pos < len(parts) else 0
            except (IndexError, ValueError):
                return float('inf')

        name_files.sort(key=extract_index)
        money_files.sort(key=extract_index)

        # Log only when files were actually found
        if name_files or money_files:
            logger.info(f"Found {len(name_files)} name text files and {len(money_files)} money text files")

        for file_path in name_files:
            with open(file_path, 'r', encoding='utf-8') as f:
                # Read and filter in one pass
                lines = [line.strip() for line in f if line.strip()]
                name_data.extend(lines)
                logger.info(f"Extracted {len(lines)} names from {os.path.basename(file_path)}")

        for file_path in money_files:
            with open(file_path, 'r', encoding='utf-8') as f:
                lines = [line.strip() for line in f if line.strip()]
                money_data.extend(lines)
                logger.info(f"Extracted {len(lines)} amounts from {os.path.basename(file_path)}")

        if name_data or money_data:
            logger.info(f"Extracted directly from text files: {len(name_data)} names, {len(money_data)} amounts")

    except Exception as e:
        logger.error(f"Error while extracting texts from the directory: {str(e)}")
        logger.error(f"Details: {traceback.format_exc()}")

    return {
        'name': name_data,
        'money': money_data
    }

def process_pdf(pdf_path, output_dir, confidence_threshold=0.6):
    """
    Process a single PDF file.

    Args:
        pdf_path: path to the PDF file
        output_dir: output directory
        confidence_threshold: confidence threshold

    Returns:
        dict with the processing result
    """
    try:
        # Create the output directory
        os.makedirs(output_dir, exist_ok=True)

        # Derive the PDF file name and stem
        pdf_filename = os.path.basename(pdf_path)
        pdf_name = os.path.splitext(pdf_filename)[0]

        logger.info(f"Processing PDF: {pdf_path}")

        # PDF file size in MB
        file_size = os.path.getsize(pdf_path) / (1024 * 1024)
        logger.info(f"File size: {file_size:.2f} MB")

        # Temporary directory that holds only this PDF
        pdf_folder = os.path.join(output_dir, "temp_pdf")
        os.makedirs(pdf_folder, exist_ok=True)

        # Copy the PDF into the temporary directory
        temp_pdf_path = os.path.join(pdf_folder, pdf_filename)
        shutil.copy2(pdf_path, temp_pdf_path)

        # Record the start time
        start_time = time.time()

        # Run the OCR pipeline
        stats = process_pdf_to_table(
            pdf_folder=pdf_folder,
            output_folder=output_dir,
            confidence_threshold=confidence_threshold
        )

        # Measure the processing time
        processing_time = time.time() - start_time
        logger.info(f"Processing took {processing_time:.2f} s")

        # Record performance data
        stats['performance'] = {
            'file_size_mb': round(file_size, 2),
            'processing_time_seconds': round(processing_time, 2),
            'processing_speed_mb_per_second': round(file_size / processing_time, 4) if processing_time > 0 else 0
        }

        # Parse the results
        all_results = []
        results_folder = os.path.join(output_dir, "results")

        if os.path.exists(results_folder):
            # Find all *_regions.txt files quickly with glob
            region_files = glob.glob(os.path.join(results_folder, "*_regions.txt"))

            # Sort by the leading number in the file name
            def extract_number(file_path):
                try:
                    filename = os.path.basename(file_path)
                    number_str = filename.split('_')[0]
                    return int(number_str)
                except (IndexError, ValueError):
                    return float('inf')

            region_files.sort(key=extract_number)

            if region_files:
                logger.info(f"Found {len(region_files)} region files to parse")

                # Process the region files sequentially
                for region_file in region_files:
                    image_prefix = os.path.basename(region_file).split('_')[0]
                    result = parse_regions_file(region_file)

                    # Only record the result when something was parsed
                    if result['name'] or result['money']:
                        # Post-process and match the data of the current image
                        result = process_single_image_data(result, image_prefix)
                        all_results.append(result)
                        logger.info(f"Parsed content from {image_prefix} and matched it")
                    else:
                        # Fallback: read directly from the texts.txt files
                        backup_result = extract_texts_from_directory(results_folder, image_prefix)
                        if backup_result['name'] or backup_result['money']:
                            # Match the fallback data as well
                            backup_result = process_single_image_data(backup_result, image_prefix)
                            all_results.append(backup_result)
                            logger.info(f"Fallback extraction from {image_prefix} succeeded and was matched")
            else:
                logger.warning("No regions.txt files found")

        # If nothing came out of the regions.txt files, try the subdirectories directly
        if not all_results:
            logger.warning("No content parsed from regions.txt files, extracting from subdirectories instead")

            # Collect and sort all subdirectories
            subdirs = []
            if os.path.exists(results_folder):
                subdirs = [d for d in os.listdir(results_folder)
                           if os.path.isdir(os.path.join(results_folder, d))]

            # Sort subdirectories by their numeric name
            def extract_dir_number(dirname):
                try:
                    return int(dirname)
                except ValueError:
                    return float('inf')

            subdirs.sort(key=extract_dir_number)

            if subdirs:
                logger.info(f"Found {len(subdirs)} subdirectories to extract from")

                # Process the subdirectories sequentially
                ordered_results = []
                for dirname in subdirs:
                    result_dir = os.path.join(results_folder, dirname)
                    backup_result = extract_texts_from_directory(result_dir, dirname)
                    if backup_result['name'] or backup_result['money']:
                        # Match the subdirectory data as well
                        backup_result = process_single_image_data(backup_result, dirname)
                        ordered_results.append((int(dirname) if dirname.isdigit() else float('inf'), backup_result))

                # Append the results in subdirectory-number order
                ordered_results.sort(key=lambda x: x[0])
                for _, result in ordered_results:
                    all_results.append(result)

        # Merge all results
        merged_result = {'name': [], 'money': [], 'name_money': []}

        if all_results:
            # Merge the per-image results while preserving the per-image matching
            for result in all_results:
                merged_result['name'].extend(result.get('name', []))
                merged_result['money'].extend(result.get('money', []))
                merged_result['name_money'].extend(result.get('name_money', []))

        # Note: each image has already been matched individually, so no global matching is needed here

        logger.info(f"PDF {pdf_filename} done: {len(merged_result['name'])} names, {len(merged_result['money'])} amounts, {len(merged_result['name_money'])} name-amount pairs")

        # Save the final result as JSON
        result_json_path = os.path.join(output_dir, f"{pdf_name}_result.json")
        with open(result_json_path, 'w', encoding='utf-8') as f:
            json.dump({
                'pdf_file': pdf_filename,
                'name': merged_result['name'],
                'money': merged_result['money'],
                'name_money': merged_result['name_money'],
                'stats': stats
            }, f, ensure_ascii=False, indent=2)
        logger.info(f"JSON result saved to: {result_json_path}")

        # Save the final result as TXT
        result_txt_path = os.path.join(output_dir, f"{pdf_name}_result.txt")
        with open(result_txt_path, 'w', encoding='utf-8') as f:
            f.write(f"Results for PDF file: {pdf_filename}\n")
            f.write(f"Processed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write("=" * 80 + "\n\n")

            # Write the name-money pairs
            f.write(f"Found {len(merged_result['name_money'])} name-amount pairs:\n\n")
            for i, pair in enumerate(merged_result['name_money'], 1):
                f.write(f"{i}. {pair}\n")

            f.write("\n" + "=" * 80 + "\n")
        logger.info(f"TXT result saved to: {result_txt_path}")

        # Clean up the temporary directory
        try:
            shutil.rmtree(pdf_folder)
            logger.info("Temporary PDF directory removed")
        except Exception as e:
            logger.warning(f"Error while cleaning up the temporary directory: {str(e)}")

        return {
            'pdf_file': pdf_filename,
            'result': merged_result,
            'stats': stats
        }

    except Exception as e:
        logger.error(f"Error while processing PDF {pdf_path}: {str(e)}")
        logger.error(f"Details: {traceback.format_exc()}")
        return {
            'pdf_file': os.path.basename(pdf_path),
            'error': str(e)
        }

def process_single_image_data(result, image_prefix):
    """
    Post-process the names and amounts of a single image and pair them up.

    Args:
        result: dict with 'name' and 'money' lists
        image_prefix: identifier of the image

    Returns:
        processed dict that additionally contains the matched name_money pairs
    """
    log_verbose(f"Matching data for image {image_prefix}")

    # Work on copies so the caller's data stays untouched
    processed_result = {
        'name': result.get('name', []).copy(),
        'money': result.get('money', []).copy(),
        'name_money': []
    }

    # Post-process the amounts with the dedicated helper (merges split amounts)
    processed_result['money'] = post_process_money_texts(processed_result['money'])
    log_verbose(f"Image {image_prefix}: {len(processed_result['money'])} amounts after post-processing")

    # Format and validate the amounts
    validated_money = []
    for money in processed_result['money']:
        # Clean and normalize the amount format
        cleaned_money = money.strip()

        # Drop stray non-numeric characters (keep digits, decimal points and commas)
        cleaned_money = re.sub(r'[^\d.,]', '', cleaned_money)

        # Append a decimal part when there is none
        if '.' not in cleaned_money and cleaned_money.isdigit():
            cleaned_money = cleaned_money + '.00'
            log_verbose(f"Formatted amount: {money} -> {cleaned_money}")
        # Otherwise make sure there are exactly two digits after the decimal point
        elif '.' in cleaned_money:
            parts = cleaned_money.split('.')
            if len(parts) == 2:
                integer_part, decimal_part = parts
                # Pad short decimal parts with zeros
                if len(decimal_part) < 2:
                    decimal_part = decimal_part.ljust(2, '0')
                # Truncate long decimal parts (rounding would also be defensible)
                elif len(decimal_part) > 2:
                    decimal_part = decimal_part[:2]
                cleaned_money = integer_part + '.' + decimal_part
                log_verbose(f"Normalized amount: {money} -> {cleaned_money}")

        # Keep only amounts that parse as plausible numbers
        try:
            float_value = float(cleaned_money.replace(',', ''))
            if float_value < 0.01:
                logger.warning(f"Ignoring abnormally small amount: {cleaned_money}")
                continue
            validated_money.append(cleaned_money)
        except ValueError:
            logger.warning(f"Ignoring invalid amount format: {cleaned_money}")

    # Keep only the validated amounts
    processed_result['money'] = validated_money

    # Handle mismatched name/amount counts
    max_name_len = len(processed_result['name'])
    max_money_len = len(processed_result['money'])

    log_verbose(f"Image {image_prefix} before pairing: {max_name_len} names, {max_money_len} amounts")

    if max_name_len != max_money_len:
        logger.warning(f"Image {image_prefix}: name/amount counts differ: {max_name_len} names vs {max_money_len} amounts")

    # Clean names that contain obviously misrecognized characters
    cleaned_names = []
    for name in processed_result['name']:
        # Strip digits, Latin letters and punctuation that OCR tends to hallucinate into names
        clean_name = re.sub(r'[0-9a-zA-Z.,:;!@#$%^&*()_+={}\[\]|\\/<>?~`-]', '', name)
        if clean_name != name:
            log_verbose(f"Cleaned misrecognized characters from name: {name} -> {clean_name}")

        # Keep the name only when something is left after cleaning
        if clean_name.strip():
            cleaned_names.append(clean_name.strip())
        else:
            logger.warning(f"Name '{name}' became empty after cleaning and was dropped")

    # Store the cleaned names
    processed_result['name'] = cleaned_names

    # Re-check the counts
    max_name_len = len(processed_result['name'])
    log_verbose(f"After cleaning: {max_name_len} names, {max_money_len} amounts")

    # Build the final pairs. Truncating to the shorter list keeps the data
    # from shifting out of alignment.
    match_len = min(max_name_len, max_money_len)

    if match_len > 0:
        for i in range(match_len):
            name = processed_result['name'][i]
            money = processed_result['money'][i]
            pair = f"{name}-{money}"
            processed_result['name_money'].append(pair)
            log_verbose(f"Image {image_prefix} pair #{i+1}: {pair}")

    # Log leftover names or amounts instead of force-matching them
    if max_name_len > match_len:
        extra_names = processed_result['name'][match_len:]
        logger.warning(f"Image {image_prefix} has {len(extra_names)} unmatched names: {extra_names}")

    if max_money_len > match_len:
        extra_money = processed_result['money'][match_len:]
        logger.warning(f"Image {image_prefix} has {len(extra_money)} unmatched amounts: {extra_money}")

    log_verbose(f"Image {image_prefix} done: {len(processed_result['name_money'])} name-amount pairs")

    return processed_result

# The original sequential processing function
def process_pdfs_in_directory_sequential(input_dir, output_dir, confidence_threshold=0.6):
    """
    Process every PDF file in a directory sequentially.

    Args:
        input_dir: input directory containing the PDF files
        output_dir: output directory
        confidence_threshold: confidence threshold

    Returns:
        summary statistics of the run
    """
    # Make sure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Collect all PDF files
    pdf_files = [f for f in os.listdir(input_dir) if f.lower().endswith('.pdf')]

    if not pdf_files:
        logger.warning(f"No PDF files found in {input_dir}")
        return {'processed': 0, 'success': 0, 'failed': 0, 'files': []}

    logger.info(f"Found {len(pdf_files)} PDF files in {input_dir}")

    # Sort by file name for a deterministic processing order
    pdf_files.sort()

    results = []
    success_count = 0
    failed_count = 0

    # Process the PDFs one after another
    for pdf_file in pdf_files:
        pdf_path = os.path.join(input_dir, pdf_file)
        pdf_name = os.path.splitext(pdf_file)[0]

        # Each PDF gets its own output directory
        pdf_output_dir = os.path.join(output_dir, pdf_name)

        logger.info(f"Processing {pdf_file} ({len(results)+1}/{len(pdf_files)})")

        # Process the PDF
        result = process_pdf(pdf_path, pdf_output_dir, confidence_threshold)

        if 'error' in result:
            logger.error(f"Processing {pdf_file} failed: {result['error']}")
            failed_count += 1
        else:
            logger.info(f"Successfully processed {pdf_file}")
            success_count += 1

        results.append(result)

    # Build the summary report
    summary = {
        'processed': len(pdf_files),
        'success': success_count,
        'failed': failed_count,
        'files': [r['pdf_file'] for r in results]
    }

    # Save the summary report
    summary_path = os.path.join(output_dir, "processing_summary.json")
    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    logger.info(f"Done: {len(pdf_files)} PDFs processed, {success_count} succeeded, {failed_count} failed")
    logger.info(f"Summary report saved to: {summary_path}")

    return summary

def process_pdfs_in_directory_parallel(input_dir, output_dir, confidence_threshold=0.6, max_workers=None):
    """
    Process every PDF file in a directory in parallel.

    Args:
        input_dir: input directory containing the PDF files
        output_dir: output directory
        confidence_threshold: confidence threshold
        max_workers: maximum number of worker threads, defaults to the CPU count

    Returns:
        summary statistics of the run
    """
    # Make sure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Pick the worker count
    if max_workers is None:
        max_workers = DEFAULT_MAX_WORKERS

    # Collect all PDF files
    pdf_files = [f for f in os.listdir(input_dir) if f.lower().endswith('.pdf')]

    if not pdf_files:
        logger.warning(f"No PDF files found in {input_dir}")
        return {'processed': 0, 'success': 0, 'failed': 0, 'files': []}

    logger.info(f"Found {len(pdf_files)} PDF files in {input_dir}, processing with {max_workers} threads")

    # Sort by file name for a deterministic submission order
    pdf_files.sort()

    # Thread-safe counters and result list
    from threading import Lock
    results_lock = Lock()
    results = []
    success_count = 0
    failed_count = 0

    # Worker: process a single PDF
    def process_single_pdf(pdf_file):
        nonlocal success_count, failed_count

        pdf_path = os.path.join(input_dir, pdf_file)
        pdf_name = os.path.splitext(pdf_file)[0]

        # Each PDF gets its own output directory
        pdf_output_dir = os.path.join(output_dir, pdf_name)

        logger.info(f"Processing {pdf_file}")

        try:
            # Process the PDF
            result = process_pdf(pdf_path, pdf_output_dir, confidence_threshold)

            # Update the shared state under the lock
            with results_lock:
                if 'error' in result:
                    logger.error(f"Processing {pdf_file} failed: {result['error']}")
                    failed_count += 1
                else:
                    logger.info(f"Successfully processed {pdf_file}")
                    success_count += 1

                results.append(result)

            return result
        except Exception as e:
            logger.error(f"Exception while processing {pdf_file}: {str(e)}")
            logger.error(traceback.format_exc())

            # Update the shared state under the lock
            with results_lock:
                failed_count += 1
                error_result = {
                    'pdf_file': pdf_file,
                    'error': str(e)
                }
                results.append(error_result)

            return error_result

    # Fan out over a thread pool
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        futures = [executor.submit(process_single_pdf, pdf_file) for pdf_file in pdf_files]

        # Wait for all tasks to finish
        concurrent.futures.wait(futures)

    # Build the summary report
    summary = {
        'processed': len(pdf_files),
        'success': success_count,
        'failed': failed_count,
        'files': [r['pdf_file'] for r in results]
    }

    # Save the summary report
    summary_path = os.path.join(output_dir, "processing_summary.json")
    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    logger.info(f"Done: {len(pdf_files)} PDFs processed, {success_count} succeeded, {failed_count} failed")
    logger.info(f"Summary report saved to: {summary_path}")

    return summary

# Default to the parallel variant
process_pdfs_in_directory = process_pdfs_in_directory_parallel
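
# A note on the design choice: a ThreadPoolExecutor (rather than a process
# pool) only pays off here if process_pdf spends most of its time outside the
# GIL, which is plausible for Poppler-backed pdf2image conversion and native
# OCR inference but is not guaranteed. The shared Lock protects the counters
# and the results list, and the per-PDF output directories keep the workers'
# file writes from colliding.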

def main():
    """Entry point."""
    parser = argparse.ArgumentParser(description='OCR tool: process PDFs and extract names and amounts')

    # Arguments
    parser.add_argument('--input', '-i', required=True, help='path of a PDF file or of a directory containing PDF files')
    parser.add_argument('--output', '-o', required=True, help='output directory')
    parser.add_argument('--confidence', '-c', type=float, default=0.6, help='confidence threshold, default 0.6')
    parser.add_argument('--parallel', '-p', action='store_true', help='process directories in parallel')
    parser.add_argument('--workers', '-w', type=int, default=None, help='maximum number of worker threads, defaults to the CPU count')
    parser.add_argument('--verbose', '-v', action='store_true', help='enable verbose logging')
    parser.add_argument('--use-cache', action='store_true', help='enable the cache to speed up repeated runs')
    parser.add_argument('--cache-dir', help='cache directory, defaults to the system temp directory')
    parser.add_argument('--clear-cache', action='store_true', help='clear the cache directory before processing')

    args = parser.parse_args()

    # Switch on verbose logging if requested
    global VERBOSE_LOGGING
    if args.verbose:
        VERBOSE_LOGGING = True
        os.environ['OCR_VERBOSE_LOGGING'] = '1'

    # Propagate the cache settings through environment variables
    if args.use_cache:
        os.environ['OCR_USE_CACHE'] = '1'
    else:
        os.environ['OCR_USE_CACHE'] = '0'

    if args.cache_dir:
        os.environ['OCR_CACHE_DIR'] = args.cache_dir

    # Clear the cache when requested
    if args.clear_cache:
        # The cache directory is resolved from the environment
        import tempfile
        cache_dir = os.environ.get('OCR_CACHE_DIR', os.path.join(tempfile.gettempdir(), 'ocr_cache'))
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)
            logger.info(f"Cleared cache directory: {cache_dir}")
        os.makedirs(cache_dir, exist_ok=True)

    # Record the start time
    start_time = time.time()

    # Validate the input path
    if not os.path.exists(args.input):
        logger.error(f"Input path does not exist: {args.input}")
        return 1

    # Create the output directory
    os.makedirs(args.output, exist_ok=True)

    # Process either a single file or a directory
    if os.path.isfile(args.input):
        if not args.input.lower().endswith('.pdf'):
            logger.error(f"Input file is not a PDF: {args.input}")
            return 1

        logger.info(f"Processing a single PDF file: {args.input}")
        result = process_pdf(args.input, args.output, args.confidence)

        if 'error' in result:
            logger.error(f"Processing failed: {result['error']}")
            return 1
        else:
            logger.info("Processing succeeded")
    else:
        logger.info(f"Processing the PDF files in directory: {args.input}")

        # Choose the processing strategy
        if args.parallel:
            summary = process_pdfs_in_directory_parallel(args.input, args.output, args.confidence, args.workers)
        else:
            summary = process_pdfs_in_directory_sequential(args.input, args.output, args.confidence)

        if summary['success'] == 0:
            logger.warning("No PDF file was processed successfully")
            if summary['processed'] > 0:
                return 1

    # Report the total processing time
    total_time = time.time() - start_time
    logger.info(f"Total processing time: {total_time:.2f} s")

    return 0

if __name__ == "__main__":
    try:
        sys.exit(main())
    except Exception as e:
        logger.error(f"Unhandled error during execution: {str(e)}")
        logger.error(f"Details: {traceback.format_exc()}")
        sys.exit(1)
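
# An illustrative invocation of the parallel variant with caching enabled
# (script name assumed):
#
#   python batch_ocr_parallel.py -i ./pdf_dir -o ./out --parallel --workers 8 --use-cache --cache-dir ./cache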

@@ -0,0 +1,94 @@
import os
import sys
import json
import logging
import time
import glob
import re
import argparse
import shutil
import traceback
from datetime import datetime
from pdf_to_table import process_pdf_to_table

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

def process_pdf(pdf_path, output_dir, confidence_threshold=0.6):
    """Process a single PDF file."""
    try:
        # Create the output directory
        os.makedirs(output_dir, exist_ok=True)

        # Derive the PDF file name and stem
        pdf_filename = os.path.basename(pdf_path)
        pdf_name = os.path.splitext(pdf_filename)[0]

        logger.info(f"Processing PDF: {pdf_path}")

        # Run the OCR pipeline. Note: process_pdf_to_table scans the whole
        # folder, so any sibling PDFs next to pdf_path are processed as well
        # (the batch scripts avoid this by copying into a temp folder first).
        result = process_pdf_to_table(
            pdf_folder=os.path.dirname(pdf_path),
            output_folder=output_dir,
            confidence_threshold=confidence_threshold
        )

        # Build the JSON result
        result_json = {
            'status': 'success',
            'pdf_file': pdf_filename,
            'result': result,
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }

        # Save the result
        result_path = os.path.join(output_dir, f"{pdf_name}_result.json")
        with open(result_path, 'w', encoding='utf-8') as f:
            json.dump(result_json, f, ensure_ascii=False, indent=2)

        # Write the result path to stdout (read by the Java side)
        print(json.dumps({
            'status': 'success',
            'result_path': result_path
        }))

        return 0

    except Exception as e:
        error_msg = f"Error while processing the PDF: {str(e)}"
        logger.error(error_msg)
        # Write the error to stdout (read by the Java side)
        print(json.dumps({
            'status': 'error',
            'error': error_msg
        }))
        return 1

def main():
    """Entry point."""
    # Parse the command-line arguments
    parser = argparse.ArgumentParser(description='PDF OCR processing tool')
    parser.add_argument('--input', '-i', required=True, help='path of the input PDF file')
    parser.add_argument('--output', '-o', required=True, help='output directory')
    parser.add_argument('--confidence', '-c', type=float, default=0.6, help='confidence threshold')

    args = parser.parse_args()

    # Process the PDF
    return process_pdf(args.input, args.output, args.confidence)

if __name__ == "__main__":
    try:
        sys.exit(main())
    except Exception as e:
        logger.error(f"Error during execution: {str(e)}")
        print(json.dumps({
            'status': 'error',
            'error': str(e)
        }))
        sys.exit(1)
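
# The stdout contract of this wrapper, which the Java side parses, is a single
# JSON line per run, for example (paths illustrative):
#
#   {"status": "success", "result_path": "out/invoice_result.json"}
#   {"status": "error", "error": "Error while processing the PDF: ..."}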

@@ -0,0 +1,44 @@
pipeline_name: OCR

text_type: general

use_doc_preprocessor: False
use_textline_orientation: False

SubPipelines:
  DocPreprocessor:
    pipeline_name: doc_preprocessor
    use_doc_orientation_classify: False
    use_doc_unwarping: True
    SubModules:
      DocOrientationClassify:
        module_name: doc_text_orientation
        model_name: PP-LCNet_x1_0_doc_ori
        model_dir: null
      DocUnwarping:
        module_name: image_unwarping
        model_name: UVDoc
        model_dir: null

SubModules:
  TextDetection:
    module_name: text_detection
    model_name: PP-OCRv4_mobile_det
    model_dir: ./model/PP-OCRv4_mobile_det
    limit_side_len: 960
    limit_type: max
    thresh: 0.3
    box_thresh: 0.6
    unclip_ratio: 2.0
  TextLineOrientation:
    module_name: textline_orientation
    model_name: PP-LCNet_x0_25_textline_ori
    model_dir: null
    batch_size: 6
  TextRecognition:
    module_name: text_recognition
    model_name: PP-OCRv4_mobile_rec
    model_dir: ./model/PP-OCRv4_mobile_rec
    batch_size: 6
    score_thresh: 0.0
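
# A minimal Python sketch for sanity-checking the configuration above,
# assuming PyYAML is installed and the file is saved as ocr.yaml
# (presumably the file that OCR_CONFIG_PATH in the pdf_to_table module resolves):
#
#   import yaml
#
#   with open("ocr.yaml", "r", encoding="utf-8") as f:
#       cfg = yaml.safe_load(f)
#
#   print(cfg["pipeline_name"])                             # OCR
#   print(cfg["SubModules"]["TextDetection"]["model_dir"])  # ./model/PP-OCRv4_mobile_det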

@@ -0,0 +1,890 @@
import os
import warnings
import json
import logging
from tqdm import tqdm
from pdf2image import convert_from_path
from ultralytics import YOLO
import pandas as pd
from datetime import datetime
import paddlex as pdx
import cv2
import numpy as np
import sys
from pathlib import Path
import requests
import zipfile
import shutil
import tempfile
import platform
import re

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Reduce PaddlePaddle's log output
os.environ['FLAGS_call_stack_level'] = '2'

# Suppress warnings
warnings.filterwarnings('ignore')

def resource_path(relative_path):
    """Resolve an absolute resource path, both in development and in a PyInstaller bundle."""
    try:
        # PyInstaller unpacks into a temp folder and exposes it as sys._MEIPASS
        base_path = getattr(sys, '_MEIPASS', os.path.dirname(os.path.abspath(__file__)))
        return os.path.join(base_path, relative_path)
    except Exception as e:
        logger.error(f"Error while resolving a resource path: {str(e)}")
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), relative_path)

def install_poppler_from_local(local_zip_path, install_dir=None):
    """Install Poppler from a local ZIP archive."""
    try:
        if not os.path.exists(local_zip_path):
            logger.error(f"Local Poppler ZIP file does not exist: {local_zip_path}")
            return None

        # Fall back to the default install directory when none is given
        if install_dir is None:
            install_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "poppler")

        logger.info(f"Installing Poppler from the local archive into: {install_dir}")

        # Create the install directory
        if not os.path.exists(install_dir):
            os.makedirs(install_dir)

        # Extract into a temporary directory first
        temp_dir = tempfile.mkdtemp()

        try:
            # Extract the archive
            logger.info("Extracting the archive...")
            with zipfile.ZipFile(local_zip_path, 'r') as zip_ref:
                zip_ref.extractall(temp_dir)

            # Move the extracted tree into the install directory
            extracted_dir = os.path.join(temp_dir, os.path.splitext(os.path.basename(local_zip_path))[0])
            if not os.path.exists(extracted_dir):
                # Search for the extracted directory by prefix instead
                for item in os.listdir(temp_dir):
                    if item.startswith("poppler-"):
                        extracted_dir = os.path.join(temp_dir, item)
                        break

            if os.path.exists(extracted_dir):
                for item in os.listdir(extracted_dir):
                    src = os.path.join(extracted_dir, item)
                    dst = os.path.join(install_dir, item)
                    if os.path.isdir(src):
                        shutil.copytree(src, dst, dirs_exist_ok=True)
                    else:
                        shutil.copy2(src, dst)

                logger.info(f"Poppler installed into: {install_dir}")
                return install_dir
            else:
                logger.error("Could not locate the extracted Poppler directory")
                return None

        finally:
            # Clean up the temporary files
            shutil.rmtree(temp_dir)

    except Exception as e:
        logger.error(f"Error while installing Poppler from the local archive: {str(e)}")
        return None

def setup_poppler_path():
    """Configure the Poppler path."""
    # Respect an already configured POPPLER_PATH environment variable
    if 'POPPLER_PATH' in os.environ:
        logger.info(f"Using the configured Poppler path: {os.environ['POPPLER_PATH']}")
        return

    # Probe common install locations
    possible_paths = [
        r"C:\Program Files\poppler\Library\bin",
        r"C:\Program Files (x86)\poppler\Library\bin",
        r"C:\poppler\Library\bin",
        r"D:\Program Files\poppler\Library\bin",
        r"D:\poppler\Library\bin",
        os.path.join(os.path.dirname(os.path.abspath(__file__)), "poppler", "Library", "bin")
    ]

    for path in possible_paths:
        if os.path.exists(path):
            os.environ["POPPLER_PATH"] = path
            logger.info(f"Found and set the Poppler path: {path}")
            return

    # Try installing from a bundled local archive
    local_zip_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "poppler.zip")
    if os.path.exists(local_zip_path):
        logger.info("Found a local Poppler archive, installing...")
        install_dir = install_poppler_from_local(local_zip_path)

        if install_dir:
            bin_path = os.path.join(install_dir, "Library", "bin")
            os.environ["POPPLER_PATH"] = bin_path

            # Prepend to the system PATH as well
            if bin_path not in os.environ.get("PATH", ""):
                os.environ["PATH"] = bin_path + os.pathsep + os.environ.get("PATH", "")

            logger.info(f"Poppler installed and configured at: {bin_path}")
            return

    logger.warning("No Poppler path found; PDF conversion may fail. Install Poppler and configure its path.")

# Configure the Poppler path at import time
setup_poppler_path()
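
# If the automatic discovery above fails, the lookup can be short-circuited by
# exporting POPPLER_PATH before launching, for example (paths illustrative):
#
#   set POPPLER_PATH=C:\poppler\Library\bin        (Windows cmd)
#   export POPPLER_PATH=/path/to/poppler/bin       (POSIX shells)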
|
||||
|
||||
# 获取程序所在目录
|
||||
if getattr(sys, 'frozen', False):
|
||||
# 打包环境
|
||||
BASE_DIR = getattr(sys, '_MEIPASS', os.path.dirname(sys.executable))
|
||||
else:
|
||||
# 开发环境
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# 模型路径配置
|
||||
MODEL_DIR = resource_path(os.path.join(BASE_DIR, "model"))
|
||||
YOLO_MODEL_PATH = resource_path(os.path.join(MODEL_DIR, "best.pt"))
|
||||
LAYOUT_MODEL_DIR = resource_path(os.path.join(MODEL_DIR, "layout"))
|
||||
OCR_CONFIG_PATH = resource_path(os.path.join(BASE_DIR, "ocr.yaml"))
|
||||
|
||||
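# resource_path is assumed to be defined earlier in this module. For reference,
# a typical PyInstaller-style implementation looks like the sketch below; this
# is an assumption about its behavior, not the module's actual definition:
#
#   def resource_path(relative_path):
#       # Resolve against the PyInstaller temp dir when frozen,
#       # otherwise against the current directory.
#       base = getattr(sys, '_MEIPASS', os.path.abspath("."))
#       return os.path.join(base, relative_path)
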
def pdf_to_images_bulk(input_folder, output_folder, dpi=300):
    """
    Convert PDFs to high-quality images with pdf2image, numbering pages sequentially.

    Args:
        input_folder: Folder containing the PDF files
        output_folder: Folder where converted images are saved
        dpi: Resolution, 300 by default

    Returns:
        A list of image paths with their source-PDF information
    """
    os.makedirs(output_folder, exist_ok=True)
    pdf_files = [f for f in os.listdir(input_folder) if f.lower().endswith('.pdf')]
    image_index = 1  # Sequential numbering across all PDFs
    image_info = []  # Collected image metadata

    logger.info(f"Processing {len(pdf_files)} PDF file(s)...")

    for pdf_file in tqdm(sorted(pdf_files), desc="PDF to images"):
        pdf_path = os.path.join(input_folder, pdf_file)

        try:
            # Check whether a Poppler path is available in the environment
            poppler_path = os.environ.get("POPPLER_PATH")
            if poppler_path and os.path.exists(poppler_path):
                # Use the configured Poppler path
                logger.info(f"Using Poppler path: {poppler_path}")
                logger.info(f"Processing PDF file: {pdf_path}")

                try:
                    images = convert_from_path(pdf_path, dpi=dpi, poppler_path=poppler_path)
                    logger.info(f"PDF converted successfully, {len(images)} page(s)")
                except Exception as e:
                    logger.error(f"Conversion with the configured Poppler path failed: {str(e)}")
                    # Retry without an explicit path
                    logger.info("Retrying conversion without an explicit Poppler path")
                    images = convert_from_path(pdf_path, dpi=dpi)
            else:
                # Fall back to the system default path
                logger.warning("Poppler path not found, trying the system default")
                images = convert_from_path(pdf_path, dpi=dpi)
                logger.info("Conversion with the system default Poppler succeeded")

            for i, img in enumerate(images):
                img_path = os.path.join(output_folder, f'{image_index}.png')
                img.save(img_path, 'PNG')

                # Record image metadata
                image_info.append({
                    'image_path': img_path,
                    'pdf_file': pdf_file,
                    'page_number': i + 1,
                    'image_index': image_index
                })

                image_index += 1
        except Exception as e:
            logger.error(f"Error while processing {pdf_file}: {str(e)}")
            logger.error(f"Details: {traceback.format_exc()}")

    logger.info(f"PDF-to-image conversion finished, {len(image_info)} image(s) produced")
    return image_info

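# Illustrative call (a sketch; folder names are hypothetical). Each returned
# entry maps one output PNG back to its source PDF and page:
#
#   info = pdf_to_images_bulk("pdfs", "output/images", dpi=300)
#   # info[0] == {'image_path': 'output/images/1.png', 'pdf_file': 'a.pdf',
#   #             'page_number': 1, 'image_index': 1}
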
def classify_images(image_folder, confidence_threshold=0.6):
    """
    Classify images with a YOLO model.

    Args:
        image_folder: Folder containing the images
        confidence_threshold: Confidence threshold

    Returns:
        A list of classification results
    """
    logger.info(f"Loading YOLO model: {YOLO_MODEL_PATH}")
    model = YOLO(YOLO_MODEL_PATH)

    # Force CPU inference
    model.to('cpu')

    image_files = [f for f in os.listdir(image_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    results = []

    logger.info(f"Classifying {len(image_files)} image(s)...")

    for image_file in tqdm(sorted(image_files, key=lambda x: int(os.path.splitext(x)[0])), desc="Image classification"):
        image_path = os.path.join(image_folder, image_file)
        try:
            result = model.predict(image_path, verbose=False)[0]

            class_id = int(result.probs.top1)
            class_name = result.names[class_id]
            confidence = float(result.probs.top1conf)

            # Derive the rotation angle and target flag from the class
            rotation_angle = 0
            if class_name == 'yes_90':
                rotation_angle = -90  # Rotate 90 degrees counter-clockwise
            elif class_name == 'yes_-90':
                rotation_angle = 90  # Rotate 90 degrees clockwise

            # All three "yes" classes count as target images
            is_target = (class_name in ['yes', 'yes_90', 'yes_-90']) and confidence >= confidence_threshold

            results.append({
                'image_path': image_path,
                'image_name': image_file,
                'class_id': class_id,
                'class_name': class_name,
                'confidence': confidence,
                'rotation_angle': rotation_angle,
                'is_target': is_target
            })

            logger.info(f"Classified {image_file}: class={class_name}, confidence={confidence:.4f}, rotation={rotation_angle}")
        except Exception as e:
            logger.error(f"Error while processing image {image_file}: {str(e)}")
            continue

    logger.info(f"Image classification finished, {len(results)} image(s) classified")
    return results

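# Illustrative filtering of the classification output (a sketch):
#
#   results = classify_images("output/images", confidence_threshold=0.6)
#   targets = [r for r in results if r['is_target']]
#   to_rotate = [r for r in targets if r['rotation_angle'] != 0]
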
def rotate_image(image_path, rotation_angle, output_folder=None):
    """
    Rotate an image by the given angle.

    Args:
        image_path: Path to the original image
        rotation_angle: Rotation angle (positive = clockwise, negative = counter-clockwise)
        output_folder: Output folder; if None, the image's own folder is used

    Returns:
        Path to the rotated image
    """
    # No rotation needed: return the original path
    if rotation_angle == 0:
        return image_path

    # Work out the output path
    image_name = os.path.basename(image_path)
    base_name, ext = os.path.splitext(image_name)

    if output_folder is None:
        output_folder = os.path.dirname(image_path)
    os.makedirs(output_folder, exist_ok=True)

    # Build the rotated image's file name
    rotated_image_path = os.path.join(output_folder, f"{base_name}_rotated{ext}")

    # Read the image
    img = cv2.imread(image_path)
    if img is None:
        logger.error(f"Cannot read image: {image_path}")
        return image_path

    # Image dimensions
    height, width = img.shape[:2]

    if rotation_angle == 90 or rotation_angle == -90:
        # 90-degree rotations swap width and height; use the fast built-ins
        if rotation_angle == 90:  # 90 degrees clockwise
            rotated_img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        else:  # 90 degrees counter-clockwise
            rotated_img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
    else:
        # Arbitrary-angle rotation. getRotationMatrix2D treats positive angles
        # as counter-clockwise, so negate to keep this function's
        # clockwise-positive convention.
        center = (width // 2, height // 2)
        rotation_matrix = cv2.getRotationMatrix2D(center, -rotation_angle, 1.0)

        # Compute the bounding size of the rotated image
        abs_cos = abs(rotation_matrix[0, 0])
        abs_sin = abs(rotation_matrix[0, 1])
        bound_w = int(height * abs_sin + width * abs_cos)
        bound_h = int(height * abs_cos + width * abs_sin)

        # Shift the rotation matrix so the result stays centered
        rotation_matrix[0, 2] += bound_w / 2 - center[0]
        rotation_matrix[1, 2] += bound_h / 2 - center[1]

        # Apply the rotation, padding with white
        rotated_img = cv2.warpAffine(img, rotation_matrix, (bound_w, bound_h),
                                     flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT,
                                     borderValue=(255, 255, 255))

    # Save the rotated image
    cv2.imwrite(rotated_image_path, rotated_img)
    logger.info(f"Image rotated by {rotation_angle} degrees and saved to: {rotated_image_path}")

    return rotated_image_path

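# Illustrative call (a sketch; the file name is hypothetical). A page classified
# as 'yes_90' gets rotation_angle = -90, i.e. it is rotated back counter-clockwise:
#
#   fixed = rotate_image("output/images/5.png", -90, "output/rotated_images")
#   # fixed == "output/rotated_images/5_rotated.png"
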
def detect_layout_and_ocr(image_path, layout_module, ocr_pipeline, output_dir=None):
    """
    Run layout detection and OCR on an image.

    Args:
        image_path: Path to the image
        layout_module: Preloaded layout-detection model
        ocr_pipeline: Preloaded OCR pipeline
        output_dir: Output directory; created automatically if None

    Returns:
        A list of texts for all recognized target regions
    """
    # Create the output directory
    if output_dir is None:
        output_dir = os.path.join("results", "layout_ocr", os.path.splitext(os.path.basename(image_path))[0])
    os.makedirs(output_dir, exist_ok=True)

    # Image file name without extension
    img_name = os.path.splitext(os.path.basename(image_path))[0]

    # Run layout detection
    results_generator = layout_module.predict(image_path, layout_nms=True)

    # Group detected regions by category
    all_regions = {'name': [], 'money': [], 'title': []}
    target_regions = []  # Regions that will actually be processed (name and money)
    result_count = 0

    for result_count, result in enumerate(results_generator, 1):
        # Save the result as JSON
        json_path = os.path.join(output_dir, f"{img_name}_result_{result_count}.json")
        result.save_to_json(json_path)

        # Read the detections back from the JSON
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                result_data = json.load(f)

            # Extract target regions from the top-level 'boxes' field
            if 'boxes' in result_data:
                boxes = result_data['boxes']
                logger.debug(f"JSON contains {len(boxes)} detected region(s)")

                # Walk all detected regions
                for box in boxes:
                    label = box.get('label')
                    score = box.get('score')
                    coordinate = box.get('coordinate')

                    # Bucket the region by category
                    if label in all_regions:
                        all_regions[label].append({
                            'coordinate': coordinate,
                            'score': score
                        })

                        # name and money regions go into the target list
                        if label in ['name', 'money']:
                            target_regions.append({
                                'label': label,
                                'coordinate': coordinate,
                                'score': score
                            })
                            logger.debug(f"Found {label} region in JSON: {coordinate}, confidence: {score:.4f}")

            # Log per-category statistics
            for label, regions in all_regions.items():
                if regions:
                    logger.info(f"Detected {len(regions)} {label} region(s)")

            logger.info(f"{len(target_regions)} target region(s) (name + money) to OCR")

        except Exception as e:
            logger.error(f"Error while parsing the JSON result: {str(e)}")

    # No target regions found: return an empty list
    if not target_regions:
        logger.info(f"No name or money regions found in image {img_name}")
        return []

    # Draw and save a visualization of all detected regions
    try:
        img = cv2.imread(image_path)
        if img is not None:
            vis_img = img.copy()

            # One color per category (BGR)
            color_map = {
                'name': (0, 255, 0),   # green
                'money': (0, 0, 255),  # red
                'title': (255, 0, 0)   # blue
            }

            # Draw every region, including title (for reference)
            for label, regions in all_regions.items():
                for idx, region_info in enumerate(regions):
                    coordinate = region_info['coordinate']
                    x1, y1, x2, y2 = map(int, coordinate)
                    color = color_map.get(label, (128, 128, 128))  # grey by default
                    cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
                    cv2.putText(vis_img, f"{label} {idx+1}", (x1, y1-10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

            # Save the annotated image
            vis_img_path = os.path.join(output_dir, f"{img_name}_all_regions.jpg")
            cv2.imwrite(vis_img_path, vis_img)
    except Exception as e:
        logger.error(f"Error while rendering the region visualization: {str(e)}")

    # OCR each target region
    all_region_texts = []  # Texts for every target region

    for idx, region_info in enumerate(target_regions):
        region_label = region_info['label']
        bbox = region_info['coordinate']
        logger.info(f"Processing {region_label} region {idx+1}...")
        region_texts = []  # Texts for the current region

        try:
            # Read the original image and crop the region
            img = cv2.imread(image_path)
            if img is None:
                continue

            img_height, img_width = img.shape[:2]

            # Per-category crop margins. Positive values shrink an edge inward,
            # negative values expand it outward; only the horizontal margins
            # are tightened, the vertical extent is mostly kept.
            if region_label == 'name':
                # name regions: shrink left/right, expand the bottom slightly
                left_padding = 10    # shrink left edge by 10 px
                right_padding = 10   # shrink right edge by 10 px
                top_padding = 0      # keep top edge unchanged
                bottom_padding = -7  # expand bottom edge by 7 px
            elif region_label == 'money':
                # money regions: shrink left/right more aggressively
                left_padding = 20    # shrink left edge by 20 px
                right_padding = 20   # shrink right edge by 20 px
                top_padding = 0      # keep top edge unchanged
                bottom_padding = 0   # keep bottom edge unchanged
            else:
                # Other regions: expand every edge by 7 px
                left_padding = -7
                right_padding = -7
                top_padding = -7
                bottom_padding = -7

            x1, y1, x2, y2 = map(int, bbox)

            # Apply the crop margins, clamped to the image bounds
            x1 = max(0, x1 + left_padding)
            y1 = max(0, y1 + top_padding)
            x2 = min(img_width, x2 - right_padding)
            y2 = min(img_height, y2 - bottom_padding)

            region_img = img[y1:y2, x1:x2].copy()
            region_img_path = os.path.join(output_dir, f"{img_name}_{region_label}_{idx+1}.jpg")
            cv2.imwrite(region_img_path, region_img)

            # Run OCR on the cropped region
            ocr_results = ocr_pipeline.predict(region_img_path)

            # Process the OCR results
            for res_idx, ocr_res in enumerate(ocr_results, 1):
                # Save the JSON result
                json_path = os.path.join(output_dir, f"{img_name}_{region_label}_{idx+1}_ocr_{res_idx}.json")
                ocr_res.save_to_json(save_path=json_path)

                # Extract the texts from the JSON
                try:
                    with open(json_path, 'r', encoding='utf-8') as f:
                        json_data = json.load(f)

                    # Top-level 'rec_texts' field
                    if 'rec_texts' in json_data and isinstance(json_data['rec_texts'], list):
                        rec_texts = json_data['rec_texts']
                        region_texts.extend(rec_texts)
                    # Otherwise check the elements of a top-level list
                    elif isinstance(json_data, list):
                        for item in json_data:
                            if isinstance(item, dict):
                                if 'rec_texts' in item:
                                    # 'rec_texts' may be a list or a single value
                                    if isinstance(item['rec_texts'], list):
                                        region_texts.extend(item['rec_texts'])
                                    else:
                                        region_texts.append(item['rec_texts'])
                                elif 'text' in item:
                                    # Legacy format
                                    region_texts.append(item['text'])

                    # If the JSON was empty or lacked the fields, fall back to the object
                    if not region_texts and hasattr(ocr_res, 'rec_texts'):
                        region_texts.extend(ocr_res.rec_texts)

                except Exception as e:
                    logger.error(f"Error while reading the OCR JSON result: {str(e)}")

                    # Try extracting the texts directly from the result object
                    try:
                        if hasattr(ocr_res, 'rec_texts'):
                            region_texts.extend(ocr_res.rec_texts)
                    except Exception:
                        logger.error("Could not extract texts directly from the result object")

            # Append the region's label and texts to the overall list
            all_region_texts.append({
                'label': region_label,
                'texts': region_texts
            })

            logger.info(f"{region_label} region {idx+1}: extracted {len(region_texts)} text(s)")

            # Save the current region's texts
            txt_path = os.path.join(output_dir, f"{img_name}_{region_label}_{idx+1}_texts.txt")
            with open(txt_path, 'w', encoding='utf-8') as f:
                for text in region_texts:
                    f.write(f"{text}\n")

        except Exception as e:
            logger.error(f"Error while processing {region_label} region {idx+1}: {str(e)}")
            continue

    return all_region_texts

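# Illustrative shape of the return value (values are made up):
#
#   [{'label': 'name',  'texts': ['Zhang San', 'Li Si']},
#    {'label': 'money', 'texts': ['1,234.', '56']}]
#
# money texts may still contain split amounts at this point; they are merged
# later by post_process_money_texts().
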
def merge_region_texts(all_region_texts, output_file):
    """
    Merge all region texts into a single, human-readable txt file.

    NOTE: parse_regions_file() matches the literal Chinese section markers
    written below ("NAME 区域" / "MONEY 区域", the "===" separators and the
    numbered items), so those strings must stay unchanged.

    Args:
        all_region_texts: List of region entries, each with 'label' and 'texts' fields
        output_file: Output file path
    """
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(f"区域识别结果 - 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write("=" * 80 + "\n\n")

        # Write the output grouped by category
        for category in ['name', 'money']:
            category_regions = [r for r in all_region_texts if r['label'] == category]

            if category_regions:
                f.write(f"{category.upper()} 区域 (共{len(category_regions)}个):\n")
                f.write("-" * 50 + "\n\n")

                for region_idx, region_data in enumerate(category_regions, 1):
                    texts = region_data['texts']
                    f.write(f"{category} {region_idx} (共{len(texts)}项):\n")

                    # Write the region's entries, numbered
                    for text_idx, text in enumerate(texts, 1):
                        f.write(f"  {text_idx}. {text}\n")

                    # Separator between regions, except after the last one
                    if region_idx < len(category_regions):
                        f.write("\n" + "-" * 30 + "\n\n")

                # Separator between categories
                f.write("\n" + "=" * 50 + "\n\n")

    logger.info(f"All region texts merged and saved to: {output_file}")
    return output_file

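# Sketch of the produced file layout (values are illustrative). The literal
# "NAME 区域" / "MONEY 区域" markers and the "  N. " item numbering are what
# parse_regions_file() keys on:
#
#   NAME 区域 (共1个):
#   --------------------------------------------------
#
#   name 1 (共2项):
#     1. Zhang San
#     2. Li Si
#
#   ==================================================
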
def post_process_money_texts(money_texts):
    """
    Post-process the money text list, merging amounts that OCR split apart.

    Args:
        money_texts: List of amount strings

    Returns:
        The processed list of amount strings
    """
    if not money_texts:
        return money_texts

    # Build a new result list
    processed_texts = []
    skip_next = False

    for i in range(len(money_texts)):
        if skip_next:
            skip_next = False
            continue

        current_text = money_texts[i].strip()

        # Check whether the current text ends at a decimal point or has an
        # incomplete decimal part
        if i < len(money_texts) - 1:
            next_text = money_texts[i+1].strip()
            needs_merge = False

            # Case 1: the current amount ends with a decimal point
            if current_text.endswith('.'):
                needs_merge = True
            # Case 2: the current amount has fewer than two decimal digits
            elif '.' in current_text:
                decimal_part = current_text.split('.')[-1]
                if len(decimal_part) < 2:
                    needs_merge = True
            # Case 3: the next value is a bare number of at most two digits
            is_next_decimal = next_text.isdigit() and len(next_text) <= 2
            # Case 4: the next value is numeric with a leading zero (likely a
            # decimal part); plain isdigit() also accepts all-zero strings
            # such as "00", which the earlier replace('0','').isdigit() check
            # wrongly rejected
            is_next_leading_zero = next_text.startswith('0') and next_text.isdigit()
            # Case 5: the current text is an integer and the next starts with '.'
            is_current_integer = current_text.isdigit() and next_text.startswith('.')

            if (needs_merge and (is_next_decimal or is_next_leading_zero)) or is_current_integer:
                # Merge the current amount with the next one
                merged_value = current_text + next_text
                processed_texts.append(merged_value)
                logger.info(f"Merged amounts: {current_text} + {next_text} = {merged_value}")
                skip_next = True
                continue

        # No merge: keep the current amount as-is
        processed_texts.append(current_text)

    # Normalize the amount formats
    formatted_texts = []
    for money in processed_texts:
        # Clean and standardize the amount
        cleaned_money = money.strip()

        # Strip everything except digits, decimal points and commas
        cleaned_money = re.sub(r'[^\d.,]', '', cleaned_money)

        # Add a decimal part if the amount has none
        if '.' not in cleaned_money and cleaned_money.isdigit():
            cleaned_money = cleaned_money + '.00'
            logger.info(f"Formatted amount: {money} -> {cleaned_money}")
        # Otherwise make sure there are exactly two decimal digits
        elif '.' in cleaned_money:
            parts = cleaned_money.split('.')
            if len(parts) == 2:
                integer_part, decimal_part = parts
                # Pad a short decimal part with zeros
                if len(decimal_part) < 2:
                    decimal_part = decimal_part.ljust(2, '0')
                # Truncate a decimal part longer than two digits
                elif len(decimal_part) > 2:
                    decimal_part = decimal_part[:2]
                cleaned_money = integer_part + '.' + decimal_part
                logger.info(f"Normalized amount: {money} -> {cleaned_money}")

        # Strip any misrecognized prefix before the first digit
        first_digit_pos = -1
        for pos, char in enumerate(cleaned_money):
            if char.isdigit() or char == '.':
                first_digit_pos = pos
                break

        if first_digit_pos > 0:
            cleaned_money = cleaned_money[first_digit_pos:]
            logger.info(f"Stripped amount prefix: {money} -> {cleaned_money}")

        # Keep only values that parse as valid amounts
        try:
            float(cleaned_money.replace(',', ''))  # raises ValueError if invalid
            formatted_texts.append(cleaned_money)
        except ValueError:
            logger.warning(f"Invalid amount format, ignored: {money}")

    return formatted_texts

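# Worked example (a sketch): "1,234." and "56" merge because the first entry
# ends at a decimal point and the second is a two-digit number; "789" just
# gains the missing decimal part during normalization.
#
#   post_process_money_texts(["1,234.", "56", "789"])
#   # -> ["1,234.56", "789.00"]
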
def process_pdf_to_table(pdf_folder, output_folder, confidence_threshold=0.6):
    """
    End-to-end pipeline from PDFs to region texts.

    Args:
        pdf_folder: Folder containing the PDF files
        output_folder: Output folder
        confidence_threshold: Confidence threshold

    Returns:
        Statistics about the processing run
    """
    # Create the output directories
    os.makedirs(output_folder, exist_ok=True)
    images_folder = os.path.join(output_folder, "images")
    results_folder = os.path.join(output_folder, "results")
    rotated_folder = os.path.join(output_folder, "rotated_images")
    os.makedirs(images_folder, exist_ok=True)
    os.makedirs(results_folder, exist_ok=True)
    os.makedirs(rotated_folder, exist_ok=True)

    # 1. Convert PDFs to images
    logger.info("Step 1: PDF to images")
    image_info = pdf_to_images_bulk(pdf_folder, images_folder)

    # 2. Classify the images
    logger.info("Step 2: image classification")
    classification_results = classify_images(images_folder, confidence_threshold)

    # 3. Filter the target images
    target_images = [result for result in classification_results if result['is_target']]
    logger.info(f"Selected {len(target_images)} target image(s)")

    # Save the classification results
    classification_csv = os.path.join(results_folder, "classification_results.csv")
    pd.DataFrame(classification_results).to_csv(classification_csv, index=False, encoding='utf-8-sig')
    logger.info(f"Classification results saved to: {classification_csv}")

    # 4. Rotate images where needed. Note: the loop variable must not be named
    # image_info, or it would shadow the list returned in step 1 and corrupt
    # the total_images statistic below.
    logger.info("Step 4: image rotation")
    for idx, target_info in enumerate(target_images):
        if target_info['rotation_angle'] != 0:
            original_path = target_info['image_path']
            rotated_path = rotate_image(
                original_path,
                target_info['rotation_angle'],
                rotated_folder
            )
            target_info['original_path'] = original_path
            target_info['image_path'] = rotated_path
            logger.info(f"Image {target_info['image_name']} rotated by {target_info['rotation_angle']} degrees")

    # Create the layout-detection model
    logger.info("Creating the layout-detection model")
    try:
        layout_module = pdx.create_model('PP-DocLayout-L', LAYOUT_MODEL_DIR)
        logger.info(f"Using local layout-detection model: {LAYOUT_MODEL_DIR}")
    except Exception as e:
        logger.error(f"Failed to load the local layout-detection model: {str(e)}")
        raise

    # Create the OCR pipeline
    logger.info("Creating the OCR pipeline")
    try:
        ocr_pipeline = pdx.create_pipeline(pipeline=OCR_CONFIG_PATH)
    except Exception as e:
        logger.error(f"Failed to create the OCR pipeline: {str(e)}")
        raise

    # 5. Layout detection and OCR for every target image
    logger.info("Step 5: layout detection and OCR")
    processing_results = []

    for idx, target_info in enumerate(tqdm(target_images, desc="Processing target images")):
        image_path = target_info['image_path']
        image_name = target_info['image_name']
        rotation_angle = target_info.get('rotation_angle', 0)

        # Create a per-image output directory
        image_output_dir = os.path.join(results_folder, os.path.splitext(image_name)[0])

        # Run layout detection and OCR
        region_texts = detect_layout_and_ocr(image_path, layout_module, ocr_pipeline, image_output_dir)

        # If target regions were found, merge and save the results
        if region_texts:
            # Post-process money texts, merging amounts split by OCR
            for i, region in enumerate(region_texts):
                if region['label'] == 'money':
                    region_texts[i]['texts'] = post_process_money_texts(region['texts'])
                    logger.info(f"Amounts after post-processing: {len(region_texts[i]['texts'])}")

            merged_txt_path = os.path.join(results_folder, f"{os.path.splitext(image_name)[0]}_regions.txt")
            merge_region_texts(region_texts, merged_txt_path)

            # Count the regions per category
            name_regions = sum(1 for r in region_texts if r['label'] == 'name')
            money_regions = sum(1 for r in region_texts if r['label'] == 'money')
            total_texts = sum(len(r['texts']) for r in region_texts)

            processing_results.append({
                'image_path': image_path,
                'image_name': image_name,
                'rotation_angle': rotation_angle,
                'name_regions': name_regions,
                'money_regions': money_regions,
                'total_regions': len(region_texts),
                'total_texts': total_texts,
                'output_file': merged_txt_path
            })
        else:
            logger.info(f"No target regions found in image {image_name}")
            processing_results.append({
                'image_path': image_path,
                'image_name': image_name,
                'rotation_angle': rotation_angle,
                'name_regions': 0,
                'money_regions': 0,
                'total_regions': 0,
                'total_texts': 0,
                'output_file': None
            })

    # Save the processing results
    processing_csv = os.path.join(results_folder, "processing_results.csv")
    pd.DataFrame(processing_results).to_csv(processing_csv, index=False, encoding='utf-8-sig')
    logger.info(f"Processing results saved to: {processing_csv}")

    # Statistics
    stats = {
        'total_images': len(image_info),
        'target_images': len(target_images),
        'rotated_images': sum(1 for img in target_images if img.get('rotation_angle', 0) != 0),
        'successful_extractions': sum(1 for r in processing_results if r['total_regions'] > 0),
        'total_name_regions': sum(r['name_regions'] for r in processing_results),
        'total_money_regions': sum(r['money_regions'] for r in processing_results),
        'total_text_elements': sum(r['total_texts'] for r in processing_results)
    }

    logger.info("Processing statistics")
    logger.info(f"Total images: {stats['total_images']}")
    logger.info(f"Target images: {stats['target_images']}")
    logger.info(f"Rotated images: {stats['rotated_images']}")
    logger.info(f"Images with extracted regions: {stats['successful_extractions']}")
    logger.info(f"Total name regions: {stats['total_name_regions']}")
    logger.info(f"Total money regions: {stats['total_money_regions']}")
    logger.info(f"Total text elements: {stats['total_text_elements']}")

    return stats

if __name__ == "__main__":
    # Example usage
    try:
        stats = process_pdf_to_table(
            pdf_folder=r"E:\ocr_cpu\pdf",     # Input PDF folder (raw string keeps backslashes literal)
            output_folder="output",           # Output folder
            confidence_threshold=0.6          # Confidence threshold
        )
        print("Processing finished!")
        print(f"Target images: {stats['target_images']}")
        print(f"Images with extracted regions: {stats['successful_extractions']}")
    except Exception as e:
        logger.error(f"Error during processing: {str(e)}", exc_info=True)
        sys.exit(1)