commit 1dadf70601fc5fe762c8fa30e5d4b8e115f81de8
Author: weiweiw <14335254+weiweiw22@user.noreply.gitee.com>
Date:   Thu Dec 12 17:05:39 2024 +0800

    first commit

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e9a196f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,16 @@
+*.log
+*.log.*
+*.bak
+logs
+**/target
+.idea/
+*.class
+*.manifest
+*.spec
+.DS_Store
+
+
+
+
+
+
diff --git a/dataset/1014072304_0724045231.jpg b/dataset/1014072304_0724045231.jpg
new file mode 100644
index 0000000..d192653
Binary files /dev/null and b/dataset/1014072304_0724045231.jpg differ
diff --git a/dataset/1024072305_0724123756.jpg b/dataset/1024072305_0724123756.jpg
new file mode 100644
index 0000000..afda32d
Binary files /dev/null and b/dataset/1024072305_0724123756.jpg differ
diff --git a/dataset/1024072305_0724123822.jpg b/dataset/1024072305_0724123822.jpg
new file mode 100644
index 0000000..c9fddfa
Binary files /dev/null and b/dataset/1024072305_0724123822.jpg differ
diff --git a/dataset/1024072305_0724123933.jpg b/dataset/1024072305_0724123933.jpg
new file mode 100644
index 0000000..488619c
Binary files /dev/null and b/dataset/1024072305_0724123933.jpg differ
diff --git a/model/sishu_middle_24_12.onnx b/model/sishu_middle_24_12.onnx
new file mode 100644
index 0000000..d6334b8
Binary files /dev/null and b/model/sishu_middle_24_12.onnx differ
diff --git a/model/sishu_thin_24_12.onnx b/model/sishu_thin_24_12.onnx
new file mode 100644
index 0000000..5bad2ba
Binary files /dev/null and b/model/sishu_thin_24_12.onnx differ
diff --git a/readme.md b/readme.md
new file mode 100644
index 0000000..7eed776
--- /dev/null
+++ b/readme.md
@@ -0,0 +1,28 @@
+## Overview
+This project implements intelligent (AI-assisted) image annotation. It supports both local annotation testing and an annotation service exposed over the web.
+## model
+The model directory holds the trained models and their corresponding configuration files.
+
+## dataset
+The dataset directory holds the test dataset.
+
+## output
+The output directory holds the images with annotation results drawn on them.
+
+## logs
+The logs directory holds the log files produced during annotation. Each log file is at most 10 MB, and at most 5 files are kept.
+
+## File descriptions
+sishu_label_local.py is the local annotation test script: it annotates the sishu (tow) images in dataset and saves the results, including labels and bounding boxes, to the output directory.
+sishu_label_web.py provides the web annotation service: it annotates images fetched from the given network paths according to the given labels and returns the annotation results.
+
+
+## How to run sishu_label_local
+python sishu_label_local.py
+
+## How to run sishu_label_web
+Single-process start:
+uvicorn sishu_label_web:app --reload --host '0.0.0.0' --port 8081
+
+Multi-process start:
+uvicorn sishu_label_web:app --host 0.0.0.0 --port 8081 --workers 5
\ No newline at end of file
diff --git a/sishu_label_local.py b/sishu_label_local.py
new file mode 100644
index 0000000..7e7343e
--- /dev/null
+++ b/sishu_label_local.py
@@ -0,0 +1,130 @@
+import os
+from ultralytics import YOLO
+import cv2
+import random
+
+def generate_label_colors(labels):
+    """Assign each label a fixed color (red, yellow, blue, green), as BGR tuples for OpenCV."""
+    colors = {
+        "red": (0, 0, 255),
+        "yellow": (0, 255, 255),
+        "blue": (255, 0, 0),
+        "green": (0, 255, 0),
+    }
+    # Assign a color to every label, cycling through red, yellow, blue and green
+    label_colors = {}
+    color_list = list(colors.values())
+    for i, label in enumerate(labels):
+        print(f"generate_label_colors, label: {label}")
+        label_colors[label] = color_list[i % len(color_list)]  # cycle through the colors
+
+    return label_colors
+
+def intelligent_annotation_with_filter(model_path, input_folder, output_folder, target_labels, conf_threshold=0.3):
+    """
+    Annotate the images in a local folder with a YOLOv8 model and keep only the specified target labels.
+
+    :param model_path: path to the YOLOv8 model file
+    :param input_folder: input folder containing the images to annotate
+    :param output_folder: output folder for the annotated images
+    :param target_labels: list of class names the user wants annotated
+    :param conf_threshold: confidence threshold, default 0.3
+    """
+    # Load the YOLO model
+    model = YOLO(model_path)
+
+    # Get all class names the model can predict
+    all_labels = model.names.values()
+    print("All possible labels:", all_labels)
+    label_colors = generate_label_colors(all_labels)
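+    # label_colors maps every class name the model knows to one of the four fixed
+    # colors above, cycling through them when the model has more than four classes.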
print("所有可能的标签对应的颜色:", label_colors) + + # 保存标注结果的字典 + annotation_results = {} + + # 创建输出文件夹(如果不存在) + os.makedirs(output_folder, exist_ok=True) + + # 遍历输入文件夹中的所有图像 + for file_name in os.listdir(input_folder): + input_path = os.path.join(input_folder, file_name) + + # 检查文件是否是图像格式 + if not (file_name.endswith('.jpg') or file_name.endswith('.png') or file_name.endswith('.jpeg')): + continue + + # 模型预测 + results = model(input_path, task="detect", conf=conf_threshold) + + image_annotations = [] # 当前图像的标注信息 + + # 读取原始图像 + image = cv2.imread(input_path) + + # 获取图像尺寸 + image_height, image_width = image.shape[:2] + + # 标注的标志,检查是否有符合的目标 + has_annotation = False + + # 遍历预测结果 + for result in results: + for box in result.boxes: + # 提取边界框信息和分类标签 + x1, y1, x2, y2 = map(int, box.xyxy[0]) # 坐标转换为整数 + label = model.names[int(box.cls[0])] # 类别名称 + print(f"label:" + label) + conf = box.conf[0] # 置信度 + + # 检查是否为目标标签 + if label in target_labels: + color = label_colors[label] # 获取对应标签的颜色 + has_annotation = True + + # 绘制边界框 + # color = (0, 255, 0) # 绿色 + # 随机生成颜色 + # color = tuple(random.randint(0, 255) for _ in range(3)) # 生成随机颜色 + cv2.rectangle(image, (x1, y1), (x2, y2), color, 2) + + # 在边界框上方绘制标签和置信度 + label_text = f"{label} {conf:.2f}" + print(f"label_text:" + label_text) + cv2.putText(image, label_text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 2, color, 1) + + # 转换为归一化坐标并保存 + normalized_x = x1 / image_width + normalized_y = y1 / image_height + normalized_width = (x2 - x1) / image_width + normalized_height = (y2 - y1) / image_height + + annotation = { + "label": label, + "x": normalized_x * 100, + "y": normalized_y * 100, + "width": normalized_width * 100, + "height": normalized_height * 100, + } + image_annotations.append(annotation) + + # 保存标注信息到结果字典 + annotation_results[input_path] = image_annotations if image_annotations else [] + + # 如果存在符合的标注目标,保存标注后的图像 + if has_annotation: + output_path = os.path.join(output_folder, file_name) + cv2.imwrite(output_path, image) + + print(f"智能标注完成!标注后的图像保存在:{output_folder}") + return annotation_results + +model_path = "./model/sishu_thin_24_12.onnx" # 你的YOLOv8模型权重路径 +input_folder = "./dataset" # 输入文件夹路径 +output_folder = "./output" # 输出文件夹路径 +target_labels = ["tiaojuan", "zhujiesi", "yulinwen"] # 用户指定的目标标签 +annotation_results = intelligent_annotation_with_filter(model_path, input_folder, output_folder, target_labels) +print(f"标注结果:annotation_results:{annotation_results}") + +print(f"标注结果:./dataset/1024072305_0724123933.jpg annotation_results:{annotation_results['./dataset/1024072305_0724123933.jpg']}") + + diff --git a/sishu_label_web.py b/sishu_label_web.py new file mode 100644 index 0000000..a9135bd --- /dev/null +++ b/sishu_label_web.py @@ -0,0 +1,169 @@ +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +from typing import List, Optional +import cv2 +from ultralytics import YOLO +import requests +import numpy as np +import time +from contextlib import asynccontextmanager +import logging +import os +from logging.handlers import RotatingFileHandler +from urllib.parse import urlparse + +# 设置日志配置 +log_dir = "logs" +os.makedirs(log_dir, exist_ok=True) + +# 配置日志 +log_file = os.path.join(log_dir, "app.log") +logger = logging.getLogger("Sishu_AI_Annotation") +logger.setLevel(logging.DEBUG) + +# 设置文件回滚,每个日志文件最大10MB,最多保留5个文件 +handler = RotatingFileHandler(log_file, maxBytes=10 * 1024 * 1024, backupCount=5) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) +logger.addHandler(handler) + 
+@asynccontextmanager
+async def app_lifespan(app: FastAPI):
+    """Define the application lifespan events."""
+    global model
+    logger.info("Loading model...")
+    model = YOLO("./model/sishu_thin_24_12.onnx")  # load the YOLO model
+    logger.info("Model loaded")
+    yield  # the application runs while suspended here
+    # any necessary cleanup can be performed after the yield
+
+# FastAPI application
+app = FastAPI(lifespan=app_lifespan)
+
+# Request body model
+class AnnotateRequest(BaseModel):
+    image_urls: List[str]  # list of image URLs (network paths)
+    target_labels: Optional[List[str]] = None  # labels to keep; None means keep all labels
+    conf_threshold: Optional[float] = 0.3  # confidence threshold for model inference
+
+# Assign each label a fixed color (red, yellow, blue, green), as BGR tuples for OpenCV
+def generate_label_colors(labels):
+    colors = {
+        "red": (0, 0, 255),
+        "yellow": (0, 255, 255),
+        "blue": (255, 0, 0),
+        "green": (0, 255, 0),
+    }
+    label_colors = {}
+    color_list = list(colors.values())
+    for i, label in enumerate(labels):
+        label_colors[label] = color_list[i % len(color_list)]  # cycle through the colors
+
+    return label_colors
+
+
+# Load an image from a URL
+def load_image_from_url(image_url):
+    try:
+        response = requests.get(image_url)
+        response.raise_for_status()
+        image_array = np.frombuffer(response.content, np.uint8)
+        image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
+
+        if image is None:
+            logger.error(f"Failed to decode image: {image_url}")
+            return None
+
+        return image
+    except requests.exceptions.RequestException as e:
+        logger.error(f"Request failed: {e}")
+        return None
+
+
+# Run intelligent annotation on the given images
+def annotate_images(model, image_urls, target_labels=None, conf_threshold=0.1):
+    annotation_results = {}
+    all_labels = model.names.values()
+    label_colors = generate_label_colors(all_labels)
+    os.makedirs("./output", exist_ok=True)  # cv2.imwrite does not create missing folders
+
+    for image_url in image_urls:
+        # image_url may be a network URL or a local path
+        if image_url.startswith("http"):
+            # URL: download the image
+            start_time = time.time()  # start timer
+            image = load_image_from_url(image_url)
+            end_time = time.time()  # stop timer
+            load_time = end_time - start_time
+            logger.debug(f"Time to load image {image_url}: {load_time:.3f} s")
+        else:
+            # local path: read it with cv2
+            image = cv2.imread(image_url)
+
+        if image is None:
+            logger.error(f"Failed to read image: {image_url}")
+            continue
+
+        image_height, image_width = image.shape[:2]
+        results = model(image, task="detect", conf=conf_threshold)
+
+        image_annotations = []
+
+        path = urlparse(image_url).path
+        file_name = os.path.basename(path)
+
+        for result in results:
+            for box in result.boxes:
+                x1, y1, x2, y2 = map(int, box.xyxy[0])
+                label = model.names[int(box.cls[0])]
+                conf = box.conf[0]
+
+                if target_labels is None or label in target_labels:
+                    color = label_colors[label]
+
+                    cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
+                    label_text = f"{label} {conf:.2f}"
+                    cv2.putText(image, label_text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
+
+                    normalized_x = x1 / image_width
+                    normalized_y = y1 / image_height
+                    normalized_width = (x2 - x1) / image_width
+                    normalized_height = (y2 - y1) / image_height
+
+                    annotation = {
+                        "label": label,
+                        "x": normalized_x * 100,
+                        "y": normalized_y * 100,
+                        "width": normalized_width * 100,
+                        "height": normalized_height * 100,
+                    }
+                    image_annotations.append(annotation)
+
+        # Record the annotations for this image
+        annotation_results[image_url] = image_annotations if image_annotations else []
+        # Save the annotated image (file_name comes from the URL path, so query strings are stripped)
+        output_path = os.path.join("./output", file_name)
+        cv2.imwrite(output_path, image)
+        logger.info("Annotated image saved to: %s", output_path)
+
+    return annotation_results
+
+# FastAPI endpoint - image annotation API
+@app.post("/annotate")
+async def annotate(request: AnnotateRequest):
+    try:
+        logger.info("Start processing request: %s", request.dict())
+        annotation_results = annotate_images(
+            model=model,
+            image_urls=request.image_urls,
+            target_labels=request.target_labels,
+            conf_threshold=request.conf_threshold
+        )
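+        # annotation_results maps each requested image URL to its detections
+        # (label plus x/y/width/height expressed as percentages of the image size).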
logger.info("请求处理完成") + logger.info("annotations:{}", annotation_results) + return {"annotations": annotation_results} + + except Exception as e: + logger.error(f"发生错误: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# 示例用法: 调用接口时会返回JSON格式的标注结果,包含图片的路径和标注信息。