"""FastAPI service that runs a YOLO detection model over images.

``POST /annotate`` returns bounding-box annotations, expressed as percentages
of the image size. It also writes an annotated copy of each image to
``OUTPUT_DIR``. Logs go to a rotating file under ``logs/``.
"""
import logging
import os
import threading
import time
from contextlib import asynccontextmanager
from logging.handlers import RotatingFileHandler
from typing import List, Optional
from urllib.parse import urlparse

import cv2
import numpy as np
import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from ultralytics import YOLO

# --- Logging configuration -------------------------------------------------
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)

log_file = os.path.join(log_dir, "app.log")
logger = logging.getLogger("Sishu_AI_Annotation")
logger.setLevel(logging.DEBUG)

# Only attach the handler once. If the module is imported again (e.g. by a
# reloading dev server), a second handler would duplicate every log line.
if not logger.handlers:
    # Rotate at 10 MB and keep at most 5 backups. The encoding is explicit
    # because the log messages are Chinese and the platform default may not
    # be able to encode them.
    handler = RotatingFileHandler(
        log_file, maxBytes=10 * 1024 * 1024, backupCount=5, encoding="utf-8"
    )
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

# Directory that annotated images are written to.
OUTPUT_DIR = "./output"

# Timeout (seconds) for downloading one image. Without a timeout, a stalled
# server would hang the request indefinitely.
IMAGE_DOWNLOAD_TIMEOUT = 30

# YOLO model instance; populated by ``app_lifespan`` at startup.
model = None

# Ultralytics' thread-safe inference guide warns against sharing one model
# instance across threads. The endpoint below runs in FastAPI's thread pool,
# so inference calls are serialized through this lock.
_model_lock = threading.Lock()


@asynccontextmanager
async def app_lifespan(app: FastAPI):
    """Load the YOLO model at startup; no cleanup is performed on shutdown."""
    global model
    logger.info("正在加载模型...")
    model = YOLO("./model/sishu_thin_24_12.onnx")
    logger.info("模型加载完成")
    yield  # application runs while suspended here


app = FastAPI(lifespan=app_lifespan)


class AnnotateRequest(BaseModel):
    """Request body for ``POST /annotate``."""

    image_urls: List[str]  # http(s) image URLs or server-local file paths
    target_labels: Optional[List[str]] = None  # keep only these labels; None keeps all
    conf_threshold: Optional[float] = 0.3  # confidence threshold passed to the model


def generate_label_colors(labels):
    """Assign each label a fixed color, cycling red, yellow, blue, green.

    The tuples are in BGR order. OpenCV decodes images as BGR and its drawing
    functions interpret color tuples in the image's channel order.

    Args:
        labels: iterable of label names.

    Returns:
        dict mapping each label to a ``(B, G, R)`` tuple.
    """
    palette = (
        (0, 0, 255),    # red
        (0, 255, 255),  # yellow
        (255, 0, 0),    # blue
        (0, 255, 0),    # green
    )
    return {label: palette[i % len(palette)] for i, label in enumerate(labels)}


def load_image_from_url(image_url, timeout=IMAGE_DOWNLOAD_TIMEOUT):
    """Download an image over HTTP(S) and decode it with OpenCV.

    Args:
        image_url: URL of the image.
        timeout: ``requests`` timeout in seconds.

    Returns:
        The decoded BGR image as a numpy array, or None if the download or
        the decoding failed. The failure is logged.
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        logger.error(f"请求失败: {e}")
        return None

    # cv2.imdecode raises (rather than returning None) on an empty buffer.
    if not response.content:
        logger.error(f"无法解码图像:{image_url}")
        return None

    image_array = np.frombuffer(response.content, np.uint8)
    image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
    if image is None:
        logger.error(f"无法解码图像:{image_url}")
    return image


def _is_remote(image_url):
    """Return True if *image_url* is an http(s) URL rather than a local path."""
    return image_url.lower().startswith(("http://", "https://"))


def _output_path_for(image_url, is_remote):
    """Return the path under OUTPUT_DIR for the annotated copy of an image.

    For URLs, only the last component of the URL path is used, so query
    strings and fragments are dropped. When there is no extension, ``.jpg``
    is appended, because ``cv2.imwrite`` picks the encoder from the
    extension and raises without one.
    """
    name_source = urlparse(image_url).path if is_remote else image_url
    file_name = os.path.basename(name_source) or "image"
    if not os.path.splitext(file_name)[1]:
        file_name += ".jpg"
    return os.path.join(OUTPUT_DIR, file_name)


def annotate_images(model, image_urls, target_labels=None, conf_threshold=0.1):
    """Run detection on each image, draw the boxes and collect annotations.

    Args:
        model: a loaded ultralytics ``YOLO`` model.
        image_urls: list of http(s) URLs or local file paths.
        target_labels: if given, only detections with these labels are kept
            and drawn; None keeps every label.
        conf_threshold: confidence threshold passed to the model.

    Returns:
        dict mapping each successfully loaded image URL/path to a list of
        annotations ``{"label", "x", "y", "width", "height"}``. The
        coordinates are percentages (0-100) of the image width/height, with
        (x, y) as the top-left corner of the box. Images that could not be
        loaded are omitted.

    Side effects:
        Writes an annotated copy of each loaded image into ``OUTPUT_DIR``.
        Images whose file names collide overwrite each other.
    """
    annotation_results = {}
    label_colors = generate_label_colors(model.names.values())
    # Build the filter set once instead of scanning a list for every box.
    wanted = None if target_labels is None else set(target_labels)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    for image_url in image_urls:
        is_remote = _is_remote(image_url)
        if is_remote:
            start_time = time.time()
            image = load_image_from_url(image_url)
            load_time = time.time() - start_time
            logger.debug(f"加载图像 {image_url} 的时间: {load_time:.3f} 秒")
        else:
            # NOTE(security): any server-local path sent by the client is read
            # here. Restrict this to an allowed directory if the endpoint is
            # exposed to untrusted clients.
            image = cv2.imread(image_url)

        # Covers both failed downloads and unreadable local files.
        if image is None:
            logger.error(f"无法读取图像:{image_url}")
            continue

        image_height, image_width = image.shape[:2]
        with _model_lock:
            results = model(image, task="detect", conf=conf_threshold)

        image_annotations = []
        for result in results:
            for box in result.boxes:
                label = model.names[int(box.cls[0])]
                if wanted is not None and label not in wanted:
                    continue

                x1, y1, x2, y2 = map(int, box.xyxy[0])
                conf = float(box.conf[0])
                color = label_colors[label]
                cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
                cv2.putText(image, f"{label} {conf:.2f}", (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

                image_annotations.append({
                    "label": label,
                    "x": (x1 / image_width) * 100,
                    "y": (y1 / image_height) * 100,
                    "width": ((x2 - x1) / image_width) * 100,
                    "height": ((y2 - y1) / image_height) * 100,
                })

        annotation_results[image_url] = image_annotations

        # Save the annotated image. A failed save is logged but does not
        # discard the annotations already computed for the batch.
        output_path = _output_path_for(image_url, is_remote)
        try:
            saved = cv2.imwrite(output_path, image)
        except cv2.error as e:
            logger.error(f"无法保存标注图像 {output_path}: {e}")
            saved = False
        if saved:
            logger.info("图片路径:{}".format(output_path))
        else:
            logger.error(f"无法保存标注图像: {output_path}")

    return annotation_results


# Declared as a plain ``def``, not ``async def``, because the work is blocking
# (HTTP downloads, model inference). FastAPI runs sync endpoints in a thread
# pool instead of stalling the event loop.
@app.post("/annotate")
def annotate(request: AnnotateRequest):
    """Annotate the requested images and return ``{"annotations": {...}}``.

    Raises:
        HTTPException: 503 if the model has not been loaded yet; 500 for any
            other failure during processing.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="model not loaded")
    try:
        logger.info("开始处理请求: %s", request.dict())
        annotation_results = annotate_images(
            model=model,
            image_urls=request.image_urls,
            target_labels=request.target_labels,
            conf_threshold=request.conf_threshold,
        )
        logger.info("请求处理完成")
        # logging uses %-style lazy formatting; "{}" placeholders are not supported.
        logger.info("annotations:%s", annotation_results)
        return {"annotations": annotation_results}
    except Exception as e:
        logger.exception(f"发生错误: {e}")  # records the traceback as well
        raise HTTPException(status_code=500, detail=str(e)) from e

# Example: POST /annotate with
#   {"image_urls": ["https://host/a.jpg"], "target_labels": null, "conf_threshold": 0.3}
# returns {"annotations": {"https://host/a.jpg": [{"label": ..., "x": ..., "y": ...,
# "width": ..., "height": ...}, ...]}}.