first commit
This commit is contained in:
commit
1dadf70601
|
|
@ -0,0 +1,16 @@
|
||||||
|
*.log
|
||||||
|
*.log.*
|
||||||
|
*.bak
|
||||||
|
logs
|
||||||
|
**/target
|
||||||
|
.idea/
|
||||||
|
*.class
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Binary file not shown.
|
After Width: | Height: | Size: 4.6 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 4.4 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 4.0 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 4.4 MiB |
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,28 @@
|
||||||
|
## 简介
|
||||||
|
本项目主要用于实现智能标注,支持本地标注的测试和web标注的服务提供。’
|
||||||
|
## model
|
||||||
|
model 目录用于存放训练好的模型,以及模型对应的配置文件。
|
||||||
|
|
||||||
|
## dataset
|
||||||
|
dataset 目录用于存放测试数据集。
|
||||||
|
|
||||||
|
## output
|
||||||
|
output 目录用于存放含有标注结果的图片。
|
||||||
|
|
||||||
|
## logs
|
||||||
|
logs 目录用于存放标注过程中产生的日志文件。每个日志文件最大10MB,最多保留5个文件。
|
||||||
|
|
||||||
|
## 文件说明
|
||||||
|
sishu_label_local.py 为本地标注的测试代码,目的是将dataset里的丝束的图片进行标注,并将标注结果包括标签和矩形框保存到output目录下。
|
||||||
|
sishu_label_web.py 为web标注的服务提供代码,目的是根据传入的网络图像路径按照传入的标签进行标注,返回标注结果。
|
||||||
|
|
||||||
|
|
||||||
|
## 启动sishu_label_local的方法
|
||||||
|
python sishu_label_local.py
|
||||||
|
|
||||||
|
## 启动sishu_label_web的方法
|
||||||
|
单进程启动:
|
||||||
|
uvicorn sishu_label_web:app --reload --host '0.0.0.0' --port 8081
|
||||||
|
|
||||||
|
多进程启动:
|
||||||
|
uvicorn sishu_label_web:app --host 0.0.0.0 --port 8081 --workers 5
|
||||||
|
|
@ -0,0 +1,130 @@
|
||||||
|
import os
|
||||||
|
from ultralytics import YOLO
|
||||||
|
import cv2
|
||||||
|
import random
|
||||||
|
|
||||||
|
def generate_label_colors(labels):
|
||||||
|
"""为每个标签生成固定颜色(红色,黄色,蓝色,绿色)"""
|
||||||
|
colors = {
|
||||||
|
"red": (255, 0, 0),
|
||||||
|
"yellow": (255, 255, 0),
|
||||||
|
"blue": (0, 0, 255),
|
||||||
|
"green": (0, 255, 0),
|
||||||
|
}
|
||||||
|
# 为标签分配颜色,确保颜色在红、黄、蓝、绿之间循环
|
||||||
|
label_colors = {}
|
||||||
|
color_list = list(colors.values())
|
||||||
|
for i, label in enumerate(labels):
|
||||||
|
print(f"generate_label_colors, label: {label}")
|
||||||
|
label_colors[label] = color_list[i % len(color_list)] # 循环分配颜色
|
||||||
|
|
||||||
|
return label_colors
|
||||||
|
|
||||||
|
def intelligent_annotation_with_filter(model_path, input_folder, output_folder, target_labels, conf_threshold=0.3):
|
||||||
|
"""
|
||||||
|
使用 YOLOv8 模型对本地文件夹图像进行智能标注,并按指定标签过滤输出。
|
||||||
|
|
||||||
|
:param model_path: YOLOv8 模型文件路径
|
||||||
|
:param input_folder: 输入文件夹路径,包含待标注的图像
|
||||||
|
:param output_folder: 输出文件夹路径,存储标注后的图像
|
||||||
|
:param target_labels: 用户指定需要标注的目标类别列表
|
||||||
|
:param conf_threshold: 置信度阈值,默认 0.5
|
||||||
|
"""
|
||||||
|
# 加载 YOLO 模型
|
||||||
|
model = YOLO(model_path)
|
||||||
|
|
||||||
|
# 获取所有可能的标签
|
||||||
|
all_labels = model.names.values()
|
||||||
|
print("所有可能的标签:", all_labels)
|
||||||
|
label_colors = generate_label_colors(all_labels)
|
||||||
|
print("所有可能的标签对应的颜色:", label_colors)
|
||||||
|
|
||||||
|
# 保存标注结果的字典
|
||||||
|
annotation_results = {}
|
||||||
|
|
||||||
|
# 创建输出文件夹(如果不存在)
|
||||||
|
os.makedirs(output_folder, exist_ok=True)
|
||||||
|
|
||||||
|
# 遍历输入文件夹中的所有图像
|
||||||
|
for file_name in os.listdir(input_folder):
|
||||||
|
input_path = os.path.join(input_folder, file_name)
|
||||||
|
|
||||||
|
# 检查文件是否是图像格式
|
||||||
|
if not (file_name.endswith('.jpg') or file_name.endswith('.png') or file_name.endswith('.jpeg')):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 模型预测
|
||||||
|
results = model(input_path, task="detect", conf=conf_threshold)
|
||||||
|
|
||||||
|
image_annotations = [] # 当前图像的标注信息
|
||||||
|
|
||||||
|
# 读取原始图像
|
||||||
|
image = cv2.imread(input_path)
|
||||||
|
|
||||||
|
# 获取图像尺寸
|
||||||
|
image_height, image_width = image.shape[:2]
|
||||||
|
|
||||||
|
# 标注的标志,检查是否有符合的目标
|
||||||
|
has_annotation = False
|
||||||
|
|
||||||
|
# 遍历预测结果
|
||||||
|
for result in results:
|
||||||
|
for box in result.boxes:
|
||||||
|
# 提取边界框信息和分类标签
|
||||||
|
x1, y1, x2, y2 = map(int, box.xyxy[0]) # 坐标转换为整数
|
||||||
|
label = model.names[int(box.cls[0])] # 类别名称
|
||||||
|
print(f"label:" + label)
|
||||||
|
conf = box.conf[0] # 置信度
|
||||||
|
|
||||||
|
# 检查是否为目标标签
|
||||||
|
if label in target_labels:
|
||||||
|
color = label_colors[label] # 获取对应标签的颜色
|
||||||
|
has_annotation = True
|
||||||
|
|
||||||
|
# 绘制边界框
|
||||||
|
# color = (0, 255, 0) # 绿色
|
||||||
|
# 随机生成颜色
|
||||||
|
# color = tuple(random.randint(0, 255) for _ in range(3)) # 生成随机颜色
|
||||||
|
cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
|
||||||
|
|
||||||
|
# 在边界框上方绘制标签和置信度
|
||||||
|
label_text = f"{label} {conf:.2f}"
|
||||||
|
print(f"label_text:" + label_text)
|
||||||
|
cv2.putText(image, label_text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 2, color, 1)
|
||||||
|
|
||||||
|
# 转换为归一化坐标并保存
|
||||||
|
normalized_x = x1 / image_width
|
||||||
|
normalized_y = y1 / image_height
|
||||||
|
normalized_width = (x2 - x1) / image_width
|
||||||
|
normalized_height = (y2 - y1) / image_height
|
||||||
|
|
||||||
|
annotation = {
|
||||||
|
"label": label,
|
||||||
|
"x": normalized_x * 100,
|
||||||
|
"y": normalized_y * 100,
|
||||||
|
"width": normalized_width * 100,
|
||||||
|
"height": normalized_height * 100,
|
||||||
|
}
|
||||||
|
image_annotations.append(annotation)
|
||||||
|
|
||||||
|
# 保存标注信息到结果字典
|
||||||
|
annotation_results[input_path] = image_annotations if image_annotations else []
|
||||||
|
|
||||||
|
# 如果存在符合的标注目标,保存标注后的图像
|
||||||
|
if has_annotation:
|
||||||
|
output_path = os.path.join(output_folder, file_name)
|
||||||
|
cv2.imwrite(output_path, image)
|
||||||
|
|
||||||
|
print(f"智能标注完成!标注后的图像保存在:{output_folder}")
|
||||||
|
return annotation_results
|
||||||
|
|
||||||
|
model_path = "./model/sishu_thin_24_12.onnx" # 你的YOLOv8模型权重路径
|
||||||
|
input_folder = "./dataset" # 输入文件夹路径
|
||||||
|
output_folder = "./output" # 输出文件夹路径
|
||||||
|
target_labels = ["tiaojuan", "zhujiesi", "yulinwen"] # 用户指定的目标标签
|
||||||
|
annotation_results = intelligent_annotation_with_filter(model_path, input_folder, output_folder, target_labels)
|
||||||
|
print(f"标注结果:annotation_results:{annotation_results}")
|
||||||
|
|
||||||
|
print(f"标注结果:./dataset/1024072305_0724123933.jpg annotation_results:{annotation_results['./dataset/1024072305_0724123933.jpg']}")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,169 @@
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List, Optional
|
||||||
|
import cv2
|
||||||
|
from ultralytics import YOLO
|
||||||
|
import requests
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from logging.handlers import RotatingFileHandler
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
# 设置日志配置
|
||||||
|
log_dir = "logs"
|
||||||
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
log_file = os.path.join(log_dir, "app.log")
|
||||||
|
logger = logging.getLogger("Sishu_AI_Annotation")
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
# 设置文件回滚,每个日志文件最大10MB,最多保留5个文件
|
||||||
|
handler = RotatingFileHandler(log_file, maxBytes=10 * 1024 * 1024, backupCount=5)
|
||||||
|
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def app_lifespan(app: FastAPI):
|
||||||
|
"""定义应用的生命周期事件"""
|
||||||
|
global model
|
||||||
|
logger.info("正在加载模型...")
|
||||||
|
model = YOLO("./model/sishu_thin_24_12.onnx") # 加载YOLO模型
|
||||||
|
logger.info("模型加载完成")
|
||||||
|
yield # 生命周期的运行阶段
|
||||||
|
# 在这里可以执行任何必要的清理操作
|
||||||
|
|
||||||
|
# FastAPI应用
|
||||||
|
app = FastAPI(lifespan=app_lifespan)
|
||||||
|
|
||||||
|
# 定义请求参数模型
|
||||||
|
class AnnotateRequest(BaseModel):
|
||||||
|
image_urls: List[str] # 图片的网络路径列表
|
||||||
|
target_labels: Optional[List[str]] = None # 标签列表
|
||||||
|
conf_threshold: Optional[float] = 0.3 # 模型推理的阈值
|
||||||
|
|
||||||
|
# 为每个标签生成固定颜色(红色,黄色,蓝色,绿色)
|
||||||
|
def generate_label_colors(labels):
|
||||||
|
colors = {
|
||||||
|
"red": (255, 0, 0),
|
||||||
|
"yellow": (255, 255, 0),
|
||||||
|
"blue": (0, 0, 255),
|
||||||
|
"green": (0, 255, 0),
|
||||||
|
}
|
||||||
|
label_colors = {}
|
||||||
|
color_list = list(colors.values())
|
||||||
|
for i, label in enumerate(labels):
|
||||||
|
label_colors[label] = color_list[i % len(color_list)] # 循环分配颜色
|
||||||
|
|
||||||
|
return label_colors
|
||||||
|
|
||||||
|
|
||||||
|
# 通过URL加载图像
|
||||||
|
def load_image_from_url(image_url):
|
||||||
|
try:
|
||||||
|
response = requests.get(image_url)
|
||||||
|
response.raise_for_status()
|
||||||
|
image_array = np.frombuffer(response.content, np.uint8)
|
||||||
|
image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
||||||
|
|
||||||
|
if image is None:
|
||||||
|
logger.error(f"无法解码图像:{image_url}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return image
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"请求失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# 对图像进行智能标注
|
||||||
|
def annotate_images(model, image_urls, target_labels=None, conf_threshold=0.1):
|
||||||
|
annotation_results = {}
|
||||||
|
all_labels = model.names.values()
|
||||||
|
label_colors = generate_label_colors(all_labels)
|
||||||
|
|
||||||
|
for image_url in image_urls:
|
||||||
|
# 假设 image_path 可能是一个 URL 或者本地路径
|
||||||
|
if image_url.startswith("http"):
|
||||||
|
# 如果是URL,加载在线图片
|
||||||
|
start_time = time.time() # 记录开始时间
|
||||||
|
image = load_image_from_url(image_url)
|
||||||
|
end_time = time.time() # 记录结束时间
|
||||||
|
load_time = end_time - start_time
|
||||||
|
logger.debug(f"加载图像 {image_url} 的时间: {load_time:.3f} 秒")
|
||||||
|
else:
|
||||||
|
# 如果是本地路径,使用cv2读取
|
||||||
|
image = cv2.imread(image_url)
|
||||||
|
|
||||||
|
if image is None:
|
||||||
|
logger.error(f"无法读取图像:{image_url}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
image_height, image_width = image.shape[:2]
|
||||||
|
results = model(image, task="detect", conf=conf_threshold)
|
||||||
|
|
||||||
|
image_annotations = []
|
||||||
|
|
||||||
|
path = urlparse(image_url).path
|
||||||
|
file_name = os.path.basename(path)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
for box in result.boxes:
|
||||||
|
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
||||||
|
label = model.names[int(box.cls[0])]
|
||||||
|
conf = box.conf[0]
|
||||||
|
|
||||||
|
if target_labels is None or label in target_labels:
|
||||||
|
color = label_colors[label]
|
||||||
|
|
||||||
|
cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
|
||||||
|
label_text = f"{label} {conf:.2f}"
|
||||||
|
cv2.putText(image, label_text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
|
||||||
|
|
||||||
|
normalized_x = x1 / image_width
|
||||||
|
normalized_y = y1 / image_height
|
||||||
|
normalized_width = (x2 - x1) / image_width
|
||||||
|
normalized_height = (y2 - y1) / image_height
|
||||||
|
|
||||||
|
annotation = {
|
||||||
|
"label": label,
|
||||||
|
"x": normalized_x * 100,
|
||||||
|
"y": normalized_y * 100,
|
||||||
|
"width": normalized_width * 100,
|
||||||
|
"height": normalized_height * 100,
|
||||||
|
}
|
||||||
|
image_annotations.append(annotation)
|
||||||
|
|
||||||
|
annotation_results[image_url] = image_annotations if image_annotations else [] # 保存标注后的图像
|
||||||
|
output_path = os.path.join("./output", os.path.basename(image_url))
|
||||||
|
cv2.imwrite(output_path, image)
|
||||||
|
logger.info("图片路径:{}".format(output_path))
|
||||||
|
|
||||||
|
|
||||||
|
return annotation_results
|
||||||
|
|
||||||
|
# FastAPI端点 - 图像标注接口
|
||||||
|
@app.post("/annotate")
|
||||||
|
async def annotate(request: AnnotateRequest):
|
||||||
|
try:
|
||||||
|
logger.info("开始处理请求: %s", request.dict())
|
||||||
|
annotation_results = annotate_images(
|
||||||
|
model=model,
|
||||||
|
image_urls=request.image_urls,
|
||||||
|
target_labels=request.target_labels,
|
||||||
|
conf_threshold=request.conf_threshold
|
||||||
|
)
|
||||||
|
logger.info("请求处理完成")
|
||||||
|
logger.info("annotations:{}", annotation_results)
|
||||||
|
return {"annotations": annotation_results}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发生错误: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
# 示例用法: 调用接口时会返回JSON格式的标注结果,包含图片的路径和标注信息。
|
||||||
Loading…
Reference in New Issue