2025-03-03 13:21:11 +08:00
|
|
|
|
import json
|
|
|
|
|
|
import os
|
2025-03-16 14:40:56 +08:00
|
|
|
|
import random
|
2025-03-03 13:21:11 +08:00
|
|
|
|
|
|
|
|
|
|
# Destination directory for the converted ERNIE train/validation files.
directory = "output/ernie"

# Create the directory tree up front; exist_ok=True makes this a no-op when
# the directory already exists.
os.makedirs(name=directory, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 读取 JSON 文件
|
|
|
|
|
|
def load_json(file_path):
|
|
|
|
|
|
try:
|
|
|
|
|
|
with open(file_path, 'r', encoding='utf-8') as f:
|
|
|
|
|
|
return json.load(f)
|
|
|
|
|
|
except FileNotFoundError:
|
|
|
|
|
|
print(f"错误:文件 {file_path} 不存在!")
|
|
|
|
|
|
return []
|
|
|
|
|
|
except json.JSONDecodeError:
|
|
|
|
|
|
print(f"错误:无法解析 {file_path},请检查 JSON 格式!")
|
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 转换数据格式:仅保留 text 和 label(从 prompt 复制)
|
|
|
|
|
|
def convert_data_format(data):
|
|
|
|
|
|
converted_list = []
|
|
|
|
|
|
for item in data:
|
2025-03-16 14:40:56 +08:00
|
|
|
|
if "text" in item and "label" in item:
|
2025-03-03 13:21:11 +08:00
|
|
|
|
converted_list.append({
|
|
|
|
|
|
"text": item["text"],
|
2025-03-16 14:40:56 +08:00
|
|
|
|
"label": item["label"] # prompt → label
|
2025-03-03 13:21:11 +08:00
|
|
|
|
})
|
|
|
|
|
|
return converted_list
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-03-16 14:40:56 +08:00
|
|
|
|
# 随机按 7:3 比例分割 JSON 数据
|
|
|
|
|
|
def split_json_random(input_file, output_file1, output_file2):
|
2025-03-03 13:21:11 +08:00
|
|
|
|
# 读取数据
|
|
|
|
|
|
data = load_json(input_file)
|
|
|
|
|
|
|
|
|
|
|
|
# 确保数据不为空
|
|
|
|
|
|
if not data:
|
|
|
|
|
|
print("错误:输入文件为空或无法解析,未执行分割。")
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
# 转换数据格式
|
|
|
|
|
|
converted_data = convert_data_format(data)
|
|
|
|
|
|
|
2025-03-16 14:40:56 +08:00
|
|
|
|
# 打乱数据顺序
|
|
|
|
|
|
random.shuffle(converted_data)
|
|
|
|
|
|
|
2025-03-03 13:21:11 +08:00
|
|
|
|
# 计算数据的分割点(7:3)
|
2025-03-16 14:40:56 +08:00
|
|
|
|
split_point = int(len(converted_data) * 0.7)
|
2025-03-03 13:21:11 +08:00
|
|
|
|
|
|
|
|
|
|
# 按比例分割数据
|
2025-03-16 14:40:56 +08:00
|
|
|
|
data_part1 = converted_data[:split_point] # 70% 训练数据
|
|
|
|
|
|
data_part2 = converted_data[split_point:] # 30% 验证数据
|
2025-03-03 13:21:11 +08:00
|
|
|
|
|
|
|
|
|
|
# 保存数据到两个文件
|
|
|
|
|
|
with open(output_file1, 'w', encoding='utf-8') as f1:
|
|
|
|
|
|
json.dump(data_part1, f1, ensure_ascii=False, indent=4)
|
|
|
|
|
|
|
|
|
|
|
|
with open(output_file2, 'w', encoding='utf-8') as f2:
|
|
|
|
|
|
json.dump(data_part2, f2, ensure_ascii=False, indent=4)
|
|
|
|
|
|
|
2025-03-16 14:40:56 +08:00
|
|
|
|
print(f"数据已随机打乱并按 7:3 分割,保存至:\n - {output_file1}({len(data_part1)} 条)\n - {output_file2}({len(data_part2)} 条)")
|
2025-03-03 13:21:11 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 输入的 JSON 文件路径
|
|
|
|
|
|
input_file = 'output/merged_data.json'
|
|
|
|
|
|
# 输出的两个文件路径
|
|
|
|
|
|
output_file1 = 'output/ernie/train.json'
|
|
|
|
|
|
output_file2 = 'output/ernie/val.json'
|
|
|
|
|
|
|
2025-03-16 14:40:56 +08:00
|
|
|
|
# 执行数据转换和随机分割
|
|
|
|
|
|
split_json_random(input_file, output_file1, output_file2)
|