# Intention/generated_data/按比例分配ernie数据.py
#
# 83 lines, 2.4 KiB, Python
#
# Note: this file contains Unicode characters that might be confused with
# other characters. If that is intentional, this warning can be ignored.

import json
import os
import random
# Output folder for the ERNIE-format data.
directory = "output/ernie"
# Create the folder (and any missing parents) up front; an already-existing
# folder is not treated as an error.
os.makedirs(name=directory, exist_ok=True)
def load_json(file_path):
    """Parse the UTF-8 encoded JSON document stored at ``file_path``.

    Returns the decoded JSON value. If the file does not exist, or its
    contents are not valid JSON, an error message is printed and an empty
    list is returned instead of raising. Any other exception propagates.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            raw_text = handle.read()
        parsed = json.loads(raw_text)
    except json.JSONDecodeError:
        print(f"错误:无法解析 {file_path},请检查 JSON 格式!")
        return []
    except FileNotFoundError:
        print(f"错误:文件 {file_path} 不存在!")
        return []
    return parsed
def convert_data_format(data):
    """Reduce every record to a ``{"text": ..., "label": ...}`` pair.

    A record is kept when it has a ``"text"`` key together with either a
    ``"label"`` key or a ``"prompt"`` key; a ``"prompt"`` value is copied
    into ``"label"``. All other keys are discarded, and records without
    ``"text"`` or without both ``"label"`` and ``"prompt"`` are dropped.

    Each input record yields at most one output record. When a record has
    both ``"label"`` and ``"prompt"``, ``"label"`` takes precedence (such a
    record used to be emitted twice, once per key, with different labels).

    Args:
        data: Iterable of dict records.

    Returns:
        A new list of two-key dicts, in input order.
    """
    converted_list = []
    for item in data:
        if "text" not in item:
            continue
        if "label" in item:
            label = item["label"]
        elif "prompt" in item:
            # prompt -> label
            label = item["prompt"]
        else:
            continue
        converted_list.append({"text": item["text"], "label": label})
    return converted_list
def split_json_random(input_file, output_file1, output_file2, *,
                      train_ratio=0.7, seed=None):
    """Convert a JSON dataset and randomly split it into two JSON files.

    The records loaded from ``input_file`` are passed through
    ``convert_data_format``, shuffled, and cut at
    ``int(len(records) * train_ratio)``. The first part is written to
    ``output_file1`` and the remainder to ``output_file2``, both as UTF-8
    JSON with ``indent=4`` and non-ASCII characters left unescaped.

    If ``load_json`` returns an empty/falsy value, an error is printed and
    nothing is written.

    Args:
        input_file: Path of the JSON file to read.
        output_file1: Path that receives the first (training) part.
        output_file2: Path that receives the second (validation) part.
        train_ratio: Fraction of records written to ``output_file1``;
            must satisfy ``0 < train_ratio < 1``. Defaults to 0.7 (7:3).
        seed: Optional seed for a reproducible shuffle. With ``None`` the
            module-level ``random`` generator is used, as before.

    Raises:
        ValueError: If ``train_ratio`` is not strictly between 0 and 1.
    """
    if not 0 < train_ratio < 1:
        raise ValueError(
            f"train_ratio must be strictly between 0 and 1, got {train_ratio!r}"
        )

    data = load_json(input_file)
    if not data:
        print("错误:输入文件为空或无法解析,未执行分割。")
        return

    converted_data = convert_data_format(data)

    # A private generator keeps a seeded shuffle from disturbing the global
    # random state; without a seed the global generator is used unchanged.
    if seed is None:
        random.shuffle(converted_data)
    else:
        random.Random(seed).shuffle(converted_data)

    split_point = int(len(converted_data) * train_ratio)
    data_part1 = converted_data[:split_point]
    data_part2 = converted_data[split_point:]

    for path, part in ((output_file1, data_part1), (output_file2, data_part2)):
        # Make sure the destination folder exists so the write cannot fail
        # merely because the caller did not create it beforehand.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(part, f, ensure_ascii=False, indent=4)

    # ``:g`` hides float noise, so the default ratio still prints as "7:3".
    ratio_text = f"{train_ratio * 10:g}:{(1 - train_ratio) * 10:g}"
    print(
        f"数据已随机打乱并按 {ratio_text} 分割,保存至:\n"
        f" - {output_file1}({len(data_part1)} 条)\n"
        f" - {output_file2}({len(data_part2)} 条)"
    )
# Path of the JSON file that is read, converted and split.
input_file = 'output/merged_data.json'
# Paths of the two result files (first part / second part of the split).
output_file1 = 'output/ernie/train.json'
output_file2 = 'output/ernie/val.json'
# Run the conversion and the random split.
split_json_random(
    input_file=input_file,
    output_file1=output_file1,
    output_file2=output_file2,
)