131 lines
5.2 KiB
Python
131 lines
5.2 KiB
Python
import os
|
||
from ultralytics import YOLO
|
||
import cv2
|
||
import random
|
||
|
||
def generate_label_colors(labels):
|
||
"""为每个标签生成固定颜色(红色,黄色,蓝色,绿色)"""
|
||
colors = {
|
||
"red": (255, 0, 0),
|
||
"yellow": (255, 255, 0),
|
||
"blue": (0, 0, 255),
|
||
"green": (0, 255, 0),
|
||
}
|
||
# 为标签分配颜色,确保颜色在红、黄、蓝、绿之间循环
|
||
label_colors = {}
|
||
color_list = list(colors.values())
|
||
for i, label in enumerate(labels):
|
||
print(f"generate_label_colors, label: {label}")
|
||
label_colors[label] = color_list[i % len(color_list)] # 循环分配颜色
|
||
|
||
return label_colors
|
||
|
||
def intelligent_annotation_with_filter(model_path, input_folder, output_folder, target_labels, conf_threshold=0.3):
|
||
"""
|
||
使用 YOLOv8 模型对本地文件夹图像进行智能标注,并按指定标签过滤输出。
|
||
|
||
:param model_path: YOLOv8 模型文件路径
|
||
:param input_folder: 输入文件夹路径,包含待标注的图像
|
||
:param output_folder: 输出文件夹路径,存储标注后的图像
|
||
:param target_labels: 用户指定需要标注的目标类别列表
|
||
:param conf_threshold: 置信度阈值,默认 0.5
|
||
"""
|
||
# 加载 YOLO 模型
|
||
model = YOLO(model_path)
|
||
|
||
# 获取所有可能的标签
|
||
all_labels = model.names.values()
|
||
print("所有可能的标签:", all_labels)
|
||
label_colors = generate_label_colors(all_labels)
|
||
print("所有可能的标签对应的颜色:", label_colors)
|
||
|
||
# 保存标注结果的字典
|
||
annotation_results = {}
|
||
|
||
# 创建输出文件夹(如果不存在)
|
||
os.makedirs(output_folder, exist_ok=True)
|
||
|
||
# 遍历输入文件夹中的所有图像
|
||
for file_name in os.listdir(input_folder):
|
||
input_path = os.path.join(input_folder, file_name)
|
||
|
||
# 检查文件是否是图像格式
|
||
if not (file_name.endswith('.jpg') or file_name.endswith('.png') or file_name.endswith('.jpeg')):
|
||
continue
|
||
|
||
# 模型预测
|
||
results = model(input_path, task="detect", conf=conf_threshold)
|
||
|
||
image_annotations = [] # 当前图像的标注信息
|
||
|
||
# 读取原始图像
|
||
image = cv2.imread(input_path)
|
||
|
||
# 获取图像尺寸
|
||
image_height, image_width = image.shape[:2]
|
||
|
||
# 标注的标志,检查是否有符合的目标
|
||
has_annotation = False
|
||
|
||
# 遍历预测结果
|
||
for result in results:
|
||
for box in result.boxes:
|
||
# 提取边界框信息和分类标签
|
||
x1, y1, x2, y2 = map(int, box.xyxy[0]) # 坐标转换为整数
|
||
label = model.names[int(box.cls[0])] # 类别名称
|
||
print(f"label:" + label)
|
||
conf = box.conf[0] # 置信度
|
||
|
||
# 检查是否为目标标签
|
||
if label in target_labels:
|
||
color = label_colors[label] # 获取对应标签的颜色
|
||
has_annotation = True
|
||
|
||
# 绘制边界框
|
||
# color = (0, 255, 0) # 绿色
|
||
# 随机生成颜色
|
||
# color = tuple(random.randint(0, 255) for _ in range(3)) # 生成随机颜色
|
||
cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
|
||
|
||
# 在边界框上方绘制标签和置信度
|
||
label_text = f"{label} {conf:.2f}"
|
||
print(f"label_text:" + label_text)
|
||
cv2.putText(image, label_text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 2, color, 1)
|
||
|
||
# 转换为归一化坐标并保存
|
||
normalized_x = x1 / image_width
|
||
normalized_y = y1 / image_height
|
||
normalized_width = (x2 - x1) / image_width
|
||
normalized_height = (y2 - y1) / image_height
|
||
|
||
annotation = {
|
||
"label": label,
|
||
"x": normalized_x * 100,
|
||
"y": normalized_y * 100,
|
||
"width": normalized_width * 100,
|
||
"height": normalized_height * 100,
|
||
}
|
||
image_annotations.append(annotation)
|
||
|
||
# 保存标注信息到结果字典
|
||
annotation_results[input_path] = image_annotations if image_annotations else []
|
||
|
||
# 如果存在符合的标注目标,保存标注后的图像
|
||
if has_annotation:
|
||
output_path = os.path.join(output_folder, file_name)
|
||
cv2.imwrite(output_path, image)
|
||
|
||
print(f"智能标注完成!标注后的图像保存在:{output_folder}")
|
||
return annotation_results
|
||
|
||
model_path = "./model/sishu_thin_24_12.onnx" # 你的YOLOv8模型权重路径
|
||
input_folder = "./dataset" # 输入文件夹路径
|
||
output_folder = "./output" # 输出文件夹路径
|
||
target_labels = ["tiaojuan", "zhujiesi", "yulinwen"] # 用户指定的目标标签
|
||
annotation_results = intelligent_annotation_with_filter(model_path, input_folder, output_folder, target_labels)
|
||
print(f"标注结果:annotation_results:{annotation_results}")
|
||
|
||
print(f"标注结果:./dataset/1024072305_0724123933.jpg annotation_results:{annotation_results['./dataset/1024072305_0724123933.jpg']}")
|
||
|
||
|