# File: IntelligentAnnotate/sishu_label_web.py
# (Header captured from the repository web viewer: 171 lines, 5.9 KiB, Python;
#  viewer toolbar: Raw | Blame | History.)
#
# Viewer notice: this file contains ambiguous Unicode characters, i.e. Unicode
# characters that might be confused with other characters. If that is
# intentional, the warning can be ignored; the viewer's Escape button reveals them.

import logging
import os
import threading
import time
from contextlib import asynccontextmanager
from logging.handlers import RotatingFileHandler
from typing import List, Optional
from urllib.parse import urlparse

import cv2
import numpy as np
import requests
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from ultralytics import YOLO
# Logging configuration: records at DEBUG and above go to logs/app.log.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, "app.log")
logger = logging.getLogger("Sishu_AI_Annotation")
logger.setLevel(logging.DEBUG)
# Rotate the file at 10 MB and keep at most 5 backups.
# The encoding is pinned to UTF-8 because the log messages are Chinese. Without
# it the handler uses the locale's preferred encoding. On e.g. a cp1252 Windows
# host that encoding cannot represent them: logging prints "--- Logging error ---"
# and drops the record. Different hosts would also write files in different
# encodings.
handler = RotatingFileHandler(
    log_file,
    maxBytes=10 * 1024 * 1024,
    backupCount=5,
    encoding="utf-8",
)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
@asynccontextmanager
async def app_lifespan(app: FastAPI):
    """Lifespan handler for the FastAPI application.

    Startup (the code before ``yield``) loads the YOLO detector from the ONNX
    weights file into the module-level global ``model``, which the /annotate
    endpoint reads. Shutdown (after ``yield``) currently does nothing.
    """
    global model
    onnx_weights = "./model/sishu_thin_24_12.onnx"

    logger.info("正在加载模型...")
    model = YOLO(onnx_weights)
    logger.info("模型加载完成")

    # The application serves requests while suspended here; put any teardown
    # after this statement.
    yield
# The FastAPI application; model loading at startup is delegated to app_lifespan.
app = FastAPI(
    lifespan=app_lifespan,
)
# Request body model for the /annotate endpoint.
# NOTE: field order and types define the request schema; comments only here.
class AnnotateRequest(BaseModel):
    image_urls: List[str]  # Image locations: remote http(s) URLs or local file paths
    target_labels: Optional[List[str]] = None  # Keep only detections with these labels; None keeps every label
    conf_threshold: Optional[float] = 0.3  # Confidence threshold forwarded to the model's inference call
# Assign every label a fixed drawing colour (red, yellow, blue, green), cycling.
def generate_label_colors(labels):
    """Map each label to one of four fixed drawing colours.

    Colours are handed out in the order red, yellow, blue, green and repeat
    when there are more than four labels.

    The tuples are in **BGR** order because they are passed straight to OpenCV
    drawing functions (``cv2.rectangle`` / ``cv2.putText``), which interpret a
    colour triple as (blue, green, red). The previous RGB tuples made "red"
    draw blue, "blue" draw red and "yellow" draw cyan.

    Args:
        labels: iterable of label names, e.g. ``model.names.values()``.

    Returns:
        dict mapping each label to a BGR colour tuple; ``{}`` for no labels.
        If a label repeats, its last position decides the colour.
    """
    palette_bgr = (
        (0, 0, 255),    # red
        (0, 255, 255),  # yellow
        (255, 0, 0),    # blue
        (0, 255, 0),    # green
    )
    return {
        label: palette_bgr[index % len(palette_bgr)]
        for index, label in enumerate(labels)
    }
# Download an image over HTTP(S) and decode it with OpenCV.
def load_image_from_url(image_url, timeout=30):
    """Download *image_url* and decode it into an OpenCV image.

    Args:
        image_url: http(s) URL of the image.
            NOTE(review): this URL comes from API clients and is fetched by the
            server as-is, so callers can make the service request internal
            hosts (SSRF). Restrict allowed hosts if the API is exposed publicly.
        timeout: seconds passed to ``requests.get`` as its timeout. Without a
            timeout a server that never answers blocks the call forever.

    Returns:
        The decoded image array, or ``None`` when the request fails, the body
        is empty, or the bytes cannot be decoded. Each failure is logged.
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        logger.error(f"请求失败: {e}")
        return None

    if not response.content:
        # cv2.imdecode asserts on an empty buffer (raises cv2.error) rather
        # than returning None, which would fail the whole request.
        logger.error(f"无法解码图像:{image_url}")
        return None

    image_array = np.frombuffer(response.content, np.uint8)
    try:
        image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
    except cv2.error:
        image = None
    if image is None:
        logger.error(f"无法解码图像:{image_url}")
        return None
    return image
def _is_remote_image(image_url):
    """Return True when *image_url* is an http(s) URL rather than a local path."""
    return urlparse(image_url).scheme in ("http", "https")


def _load_source_image(image_url):
    """Load an image from an http(s) URL or, otherwise, from a local file path."""
    if _is_remote_image(image_url):
        start_time = time.perf_counter()
        image = load_image_from_url(image_url)
        load_time = time.perf_counter() - start_time
        logger.debug(f"加载图像 {image_url} 的时间: {load_time:.3f}")
        return image
    # NOTE(review): any server-side path a client sends is read here; restrict
    # it to an allowed directory if the API is reachable by untrusted callers.
    return cv2.imread(image_url)


def _detect_and_draw(model, image, label_colors, wanted_labels, conf_threshold):
    """Run detection, draw kept boxes onto *image* in place and return their annotations."""
    image_height, image_width = image.shape[:2]
    annotations = []
    for result in model(image, task="detect", conf=conf_threshold):
        for box in result.boxes:
            label = model.names[int(box.cls[0])]
            if wanted_labels is not None and label not in wanted_labels:
                continue
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            conf = float(box.conf[0])
            color = label_colors[label]
            cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
            cv2.putText(image, f"{label} {conf:.2f}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
            # Position and size as percentages of the image dimensions.
            annotations.append({
                "label": label,
                "x": x1 / image_width * 100,
                "y": y1 / image_height * 100,
                "width": (x2 - x1) / image_width * 100,
                "height": (y2 - y1) / image_height * 100,
            })
    return annotations


def _annotated_file_name(image_url, index):
    """Pick the output file name for the annotated copy of *image_url*."""
    if _is_remote_image(image_url):
        # Use only the URL path: the basename of the raw URL kept query strings
        # such as "a.jpg?token=x", for which OpenCV finds no image writer.
        name = os.path.basename(urlparse(image_url).path)
    else:
        name = os.path.basename(image_url)
    if not name:
        name = f"image_{index}.jpg"
    elif not os.path.splitext(name)[1]:
        # cv2.imwrite chooses the encoder from the extension.
        name += ".jpg"
    return name


def _save_annotated_image(output_path, image):
    """Write the annotated image, logging (not raising) when it cannot be saved."""
    try:
        saved = cv2.imwrite(output_path, image)
    except cv2.error as e:
        logger.error(f"保存标注图像失败: {output_path}: {e}")
        return
    if saved:
        logger.info("图片路径:{}".format(output_path))
    else:
        logger.error(f"保存标注图像失败: {output_path}")


# Run object detection on every image, draw the kept boxes and collect annotations.
def annotate_images(model, image_urls, target_labels=None, conf_threshold=0.1,
                    output_dir="./output"):
    """Detect objects in each image, save an annotated copy and return the boxes.

    Args:
        model: loaded Ultralytics YOLO model; ``model.names`` maps class ids
            to label strings.
        image_urls: image locations. http(s) URLs are downloaded with
            ``load_image_from_url``; anything else is read with ``cv2.imread``.
        target_labels: if given, only detections whose label is in this
            collection are drawn and returned; ``None`` keeps every label.
        conf_threshold: confidence threshold passed to the model call.
        output_dir: directory the annotated copies are written to; created if
            missing. Images that share a file name overwrite each other.

    Returns:
        dict mapping each successfully loaded image location to a list of
        ``{"label", "x", "y", "width", "height"}`` dicts, where position and
        size are percentages of the image width/height. Images that cannot be
        loaded are logged and omitted.
    """
    label_colors = generate_label_colors(model.names.values())
    # Membership is tested once per detected box, so build a set once.
    wanted_labels = None if target_labels is None else set(target_labels)
    logger.debug("标注的置信度阈值为:%s", conf_threshold)
    # cv2.imwrite does not create missing directories; it just returns False.
    # Without this, every annotated image was silently dropped.
    os.makedirs(output_dir, exist_ok=True)

    annotation_results = {}
    for index, image_url in enumerate(image_urls):
        image = _load_source_image(image_url)
        if image is None:
            logger.error(f"无法读取图像:{image_url}")
            continue
        annotation_results[image_url] = _detect_and_draw(
            model, image, label_colors, wanted_labels, conf_threshold
        )
        # Save the annotated image
        output_path = os.path.join(output_dir, _annotated_file_name(image_url, index))
        _save_annotated_image(output_path, image)
    return annotation_results
# Serialises access to the shared YOLO model. Inference now runs in a worker
# thread, and Ultralytics documents that one model instance must not be used
# from several threads at once.
_model_lock = threading.Lock()


def _run_annotation(request: AnnotateRequest):
    """Run annotate_images for *request* while holding the model lock."""
    with _model_lock:
        return annotate_images(
            model=model,
            image_urls=request.image_urls,
            target_labels=request.target_labels,
            conf_threshold=request.conf_threshold,
        )


# FastAPI endpoint - image annotation API
@app.post("/annotate")
async def annotate(request: AnnotateRequest):
    """Annotate the requested images and return ``{"annotations": {...}}``.

    The work (HTTP downloads, model inference, file writes) is blocking, so it
    runs in FastAPI's threadpool. Running it directly inside this coroutine
    would stall the event loop for every other request.

    Raises:
        HTTPException: status 500 with the error text if annotation fails.
    """
    try:
        logger.info("开始处理请求: %s", request.dict())
        annotation_results = await run_in_threadpool(_run_annotation, request)
        logger.info("请求处理完成")
        # Use a %s placeholder: logging formats args with %, and the former
        # "{}" left the arg unconverted, so the record failed to format and was lost.
        logger.info("annotations:%s", annotation_results)
        return {"annotations": annotation_results}
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"发生错误: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
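# Example usage: POST /annotate with a JSON body such as
#   {"image_urls": ["http://host/a.jpg"], "target_labels": null, "conf_threshold": 0.3}
# The response is JSON of the form
#   {"annotations": {"<image_url>": [{"label": ..., "x": ..., "y": ..., "width": ..., "height": ...}, ...]}}
# where box position and size are percentages of the image dimensions. The
# annotated images themselves are saved under ./output.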