170 lines
5.9 KiB
Python
170 lines
5.9 KiB
Python
import logging
import os
import time
from contextlib import asynccontextmanager
from logging.handlers import RotatingFileHandler
from typing import List, Optional
from urllib.parse import urlparse

import cv2
import numpy as np
import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from ultralytics import YOLO
|
||
|
||
# Logging setup: the service logger writes to logs/app.log.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)

log_file = os.path.join(log_dir, "app.log")
logger = logging.getLogger("Sishu_AI_Annotation")
logger.setLevel(logging.DEBUG)

# Rotate each log file at 10 MB and keep at most 5 old files.
# The log messages contain non-ASCII (Chinese) text, so pin the file encoding
# instead of depending on the platform's default locale encoding.
# delay=True defers opening the file until the first record, so a handler
# that ends up not being attached (see below) never holds a file open.
handler = RotatingFileHandler(
    log_file,
    maxBytes=10 * 1024 * 1024,
    backupCount=5,
    encoding="utf-8",
    delay=True,
)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)

# Named loggers are process-wide singletons: if this module is imported twice
# (e.g. once as __main__ and once by the ASGI server), attach the file handler
# only once so every record is not written twice.
if not any(
    getattr(existing, "baseFilename", None) == handler.baseFilename
    for existing in logger.handlers
):
    logger.addHandler(handler)
|
||
|
||
@asynccontextmanager
async def app_lifespan(app: FastAPI):
    """Lifespan hook for the FastAPI app: load the YOLO model, then serve.

    The loaded model is published as the module-level global ``model`` so the
    request handlers can reach it. Nothing is released at shutdown.
    """
    global model
    model_path = "./model/sishu_thin_24_12.onnx"

    logger.info("正在加载模型...")
    model = YOLO(model_path)
    logger.info("模型加载完成")

    # Control returns here on shutdown; no cleanup is currently required.
    yield
|
||
|
||
# The FastAPI application; app_lifespan loads the detection model at startup.
app = FastAPI(
    lifespan=app_lifespan,
)
|
||
|
||
# Request body for the /annotate endpoint.
class AnnotateRequest(BaseModel):
    # HTTP(S) URLs or server-local file paths of the images to annotate.
    image_urls: List[str]
    # Only detections whose label is in this list are kept; None keeps every label.
    target_labels: Optional[List[str]] = None
    # Minimum detection confidence forwarded to the model. Confidences are
    # fractions, so values outside [0, 1] are rejected at validation time
    # instead of silently producing no (or all) detections.
    conf_threshold: Optional[float] = Field(0.3, ge=0.0, le=1.0)
|
||
|
||
def generate_label_colors(labels):
    """Assign every label a fixed drawing color, cycling red, yellow, blue, green.

    The tuples are in OpenCV's BGR channel order, because the colors are passed
    directly to ``cv2.rectangle`` / ``cv2.putText``. Written as RGB, "red" would
    render blue and "yellow" would render cyan.

    Args:
        labels: Iterable of label names, e.g. ``model.names.values()``.

    Returns:
        dict mapping each label to a ``(B, G, R)`` tuple. Colors repeat every
        four labels; a label that appears twice keeps its last assignment.
    """
    palette = (
        (0, 0, 255),    # red
        (0, 255, 255),  # yellow
        (255, 0, 0),    # blue
        (0, 255, 0),    # green
    )
    return {label: palette[i % len(palette)] for i, label in enumerate(labels)}
|
||
|
||
|
||
def load_image_from_url(image_url, timeout=10):
    """Download an image over HTTP(S) and decode it with OpenCV.

    Args:
        image_url: URL of the image to fetch.
        timeout: Seconds to wait for the server, forwarded to ``requests.get``.
            Without a timeout a stalled server would block the caller forever.

    Returns:
        The image decoded by ``cv2.imdecode`` with ``cv2.IMREAD_COLOR``, or
        None when the request fails or the payload cannot be decoded (both
        cases are logged).
    """
    # NOTE(review): image_url arrives from API clients; if the service is
    # reachable by untrusted callers, consider restricting which hosts may be
    # fetched to avoid server-side request forgery.
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        logger.error(f"请求失败: {e}")
        return None

    image_array = np.frombuffer(response.content, np.uint8)
    image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
    if image is None:
        logger.error(f"无法解码图像:{image_url}")
        return None
    return image
|
||
|
||
|
||
def _is_remote(image_url):
    """Return True when *image_url* must be fetched over HTTP(S)."""
    return image_url.startswith(("http://", "https://"))


def _read_image(image_url):
    """Load an image from an HTTP(S) URL or a local path; None on failure."""
    if not _is_remote(image_url):
        # NOTE(review): any non-URL string is read from the server's own
        # filesystem; restrict allowed paths if clients are untrusted.
        return cv2.imread(image_url)

    start_time = time.perf_counter()
    image = load_image_from_url(image_url)
    load_time = time.perf_counter() - start_time
    logger.debug(f"加载图像 {image_url} 的时间: {load_time:.3f} 秒")
    return image


def _output_file_name(image_url):
    """Derive the file name for the annotated copy of *image_url*."""
    if _is_remote(image_url):
        # Use only the URL path: a query string or fragment in the file name
        # would hide the extension that cv2.imwrite needs to pick an encoder.
        name = os.path.basename(urlparse(image_url).path)
    else:
        name = os.path.basename(image_url)
    root, ext = os.path.splitext(name)
    # Guarantee a non-empty name and an extension cv2.imwrite understands.
    return (root or "image") + (ext or ".jpg")


def _draw_detection(image, x1, y1, x2, y2, label_text, color):
    """Draw one bounding box and its caption onto *image* in place."""
    cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
    # Caption sits above the box; when the box touches the top edge it is
    # moved inside the box so it stays visible.
    text_y = y1 - 10 if y1 >= 20 else y1 + 20
    cv2.putText(image, label_text, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)


def _save_annotated_image(image, output_path):
    """Write the annotated image to disk; a failed save is logged, not raised."""
    try:
        saved = cv2.imwrite(output_path, image)
    except cv2.error as e:
        logger.error(f"保存标注图像失败: {output_path}: {e}")
        return
    if saved:
        logger.info("图片路径:{}".format(output_path))
    else:
        logger.error(f"保存标注图像失败: {output_path}")


def annotate_images(model, image_urls, target_labels=None, conf_threshold=0.1, output_dir="./output"):
    """Detect objects in each image, draw the kept boxes, save a copy, collect annotations.

    Args:
        model: Loaded YOLO model; ``model.names`` maps class ids to label names.
        image_urls: HTTP(S) URLs or local file paths of the images.
        target_labels: Labels to keep; None keeps every label the model predicts.
        conf_threshold: Confidence threshold forwarded to the model.
        output_dir: Directory that receives the annotated copies; created if missing.

    Returns:
        dict mapping each successfully read image URL/path to a list of
        ``{"label", "x", "y", "width", "height"}`` dicts. The coordinates are
        percentages of the image width/height: top-left corner plus box size.
        Images that cannot be read are logged and left out of the result.
    """
    annotation_results = {}
    label_colors = generate_label_colors(model.names.values())
    # Set lookup instead of a list scan per box; None still means "keep all".
    wanted = None if target_labels is None else set(target_labels)
    # cv2.imwrite does not create directories, so make sure the target exists.
    os.makedirs(output_dir, exist_ok=True)

    for image_url in image_urls:
        image = _read_image(image_url)
        if image is None:
            logger.error(f"无法读取图像:{image_url}")
            continue

        image_height, image_width = image.shape[:2]
        results = model(image, task="detect", conf=conf_threshold)

        image_annotations = []
        for result in results:
            for box in result.boxes:
                label = model.names[int(box.cls[0])]
                if wanted is not None and label not in wanted:
                    continue

                x1, y1, x2, y2 = map(int, box.xyxy[0])
                conf = float(box.conf[0])
                _draw_detection(image, x1, y1, x2, y2, f"{label} {conf:.2f}", label_colors[label])

                image_annotations.append({
                    "label": label,
                    "x": x1 / image_width * 100,
                    "y": y1 / image_height * 100,
                    "width": (x2 - x1) / image_width * 100,
                    "height": (y2 - y1) / image_height * 100,
                })

        annotation_results[image_url] = image_annotations
        # NOTE(review): different URLs with the same file name overwrite each
        # other's annotated copy in output_dir.
        _save_annotated_image(image, os.path.join(output_dir, _output_file_name(image_url)))

    return annotation_results
|
||
|
||
@app.post("/annotate")
async def annotate(request: AnnotateRequest):
    """POST /annotate: run annotate_images on the requested images.

    Returns:
        ``{"annotations": {image_url: [annotation, ...]}}``.

    Raises:
        HTTPException: status 500 with the error text if annotation fails.
    """
    # NOTE(review): annotate_images does blocking HTTP downloads and model
    # inference, so this coroutine stalls the event loop for the whole
    # request. Moving it to a worker thread would need a lock around the
    # shared model, which is not guaranteed to be thread-safe.
    try:
        logger.info("开始处理请求: %s", request.dict())
        annotation_results = annotate_images(
            model=model,
            image_urls=request.image_urls,
            target_labels=request.target_labels,
            conf_threshold=request.conf_threshold,
        )
        logger.info("请求处理完成")
        # logging uses %-style lazy formatting; a str.format "{}" placeholder
        # combined with an argument makes the record fail to format.
        logger.info("annotations:%s", annotation_results)
        return {"annotations": annotation_results}
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"发生错误: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||
|
||
|
||
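# 示例用法: 调用接口时会返回JSON格式的标注结果,包含图片的路径和标注信息。
# Example: POST /annotate {"image_urls": ["https://example.com/a.jpg"], "target_labels": null, "conf_threshold": 0.3}
#   -> {"annotations": {"https://example.com/a.jpg": [{"label": "...", "x": 12.5, "y": 30.0, "width": 20.0, "height": 15.0}]}}
#   (x/y/width/height are percentages of the image width/height)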
|