first commit

This commit is contained in:
weiweiw 2024-12-12 17:05:39 +08:00
commit 1dadf70601
10 changed files with 343 additions and 0 deletions

16
.gitignore vendored Normal file
View File

@ -0,0 +1,16 @@
*.log
*.log.*
*.bak
logs
**/target
.idea/
*.class
*.manifest
*.spec
.DS_Store

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.6 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.4 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.0 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.4 MiB

Binary file not shown.

BIN
model/sishu_thin_24_12.onnx Normal file

Binary file not shown.

28
readme.md Normal file
View File

@ -0,0 +1,28 @@
## 简介
本项目主要用于实现智能标注,支持本地标注的测试和 Web 标注的服务提供。
## model
model 目录用于存放训练好的模型,以及模型对应的配置文件。
## dataset
dataset 目录用于存放测试数据集。
## output
output 目录用于存放含有标注结果的图片。
## logs
logs 目录用于存放标注过程中产生的日志文件。每个日志文件最大 10MB,最多保留 5 个文件。
## 文件说明
sishu_label_local.py 为本地标注的测试代码,目的是将 dataset 里的丝束图片进行标注,并将标注结果(包括标签和矩形框)保存到 output 目录下。
sishu_label_web.py 为 Web 标注的服务提供代码,目的是根据传入的网络图像路径,按照传入的标签进行标注,并返回标注结果。
## 启动sishu_label_local的方法
python sishu_label_local.py
## 启动sishu_label_web的方法
单进程启动:
uvicorn sishu_label_web:app --reload --host '0.0.0.0' --port 8081
多进程启动:
uvicorn sishu_label_web:app --host 0.0.0.0 --port 8081 --workers 5

130
sishu_label_local.py Normal file
View File

@ -0,0 +1,130 @@
import os
from ultralytics import YOLO
import cv2
import random
def generate_label_colors(labels):
    """Assign each label one of four fixed colors (red, yellow, blue, green).

    Colors are handed out in iteration order and wrap around after the
    fourth label. The tuples are in BGR channel order because the result is
    passed directly to OpenCV drawing calls (``cv2.rectangle`` /
    ``cv2.putText``) on images loaded with ``cv2.imread``.

    :param labels: iterable of hashable label names.
    :return: dict mapping each label to a ``(B, G, R)`` color tuple.
    """
    # BGR order. The earlier RGB-ordered tuples rendered "red" as blue and
    # "yellow" as cyan once drawn by OpenCV.
    colors = {
        "red": (0, 0, 255),
        "yellow": (0, 255, 255),
        "blue": (255, 0, 0),
        "green": (0, 255, 0),
    }
    color_list = list(colors.values())
    label_colors = {}
    for i, label in enumerate(labels):
        print(f"generate_label_colors, label: {label}")
        # Cycle through the palette when there are more labels than colors.
        label_colors[label] = color_list[i % len(color_list)]
    return label_colors
def intelligent_annotation_with_filter(model_path, input_folder, output_folder, target_labels, conf_threshold=0.3):
    """Annotate every image in a local folder with a YOLO model, keeping only chosen labels.

    For each ``.jpg`` / ``.jpeg`` / ``.png`` file (extension matched
    case-insensitively) the model is run, detections whose label is in
    ``target_labels`` are drawn onto the image, and the image is written to
    ``output_folder`` when at least one such detection exists.

    :param model_path: path to the YOLO model file.
    :param input_folder: folder containing the images to annotate.
    :param output_folder: folder for annotated images (created if missing).
    :param target_labels: collection of label names to keep.
    :param conf_threshold: confidence threshold passed to the model (default 0.3).
    :return: dict mapping each processed image path to a list of annotation
        dicts with keys ``label``, ``x``, ``y``, ``width``, ``height``. ``x``/``y``
        are the box's top-left corner and ``width``/``height`` its size, all as
        percentages of the image dimensions. Images that cannot be read are
        skipped and get no entry.
    """
    image_suffixes = (".jpg", ".jpeg", ".png")

    model = YOLO(model_path)

    all_labels = model.names.values()
    print("所有可能的标签:", all_labels)
    label_colors = generate_label_colors(all_labels)
    print("所有可能的标签对应的颜色:", label_colors)

    annotation_results = {}
    os.makedirs(output_folder, exist_ok=True)

    for file_name in os.listdir(input_folder):
        # Only image files are processed; compare case-insensitively so
        # names such as "A.JPG" are not silently skipped.
        if not file_name.lower().endswith(image_suffixes):
            continue
        input_path = os.path.join(input_folder, file_name)

        # cv2.imread returns None for unreadable/corrupt files; bail out
        # before inference instead of crashing on image.shape.
        image = cv2.imread(input_path)
        if image is None:
            print(f"无法读取图像:{input_path}")
            continue
        image_height, image_width = image.shape[:2]

        results = model(input_path, task="detect", conf=conf_threshold)

        image_annotations = []
        for result in results:
            for box in result.boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                label = model.names[int(box.cls[0])]
                print(f"label:{label}")
                conf = float(box.conf[0])

                if label not in target_labels:
                    continue

                color = label_colors[label]
                cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)

                # Label and confidence are drawn just above the box.
                label_text = f"{label} {conf:.2f}"
                print(f"label_text:{label_text}")
                cv2.putText(image, label_text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 2, color, 1)

                # Store the box as percentages of the image size.
                image_annotations.append({
                    "label": label,
                    "x": x1 / image_width * 100,
                    "y": y1 / image_height * 100,
                    "width": (x2 - x1) / image_width * 100,
                    "height": (y2 - y1) / image_height * 100,
                })

        annotation_results[input_path] = image_annotations

        # Only images with at least one kept detection are written out.
        if image_annotations:
            output_path = os.path.join(output_folder, file_name)
            cv2.imwrite(output_path, image)

    print(f"智能标注完成!标注后的图像保存在:{output_folder}")
    return annotation_results
# Script configuration: model weights, input/output folders and labels to keep.
model_path = "./model/sishu_thin_24_12.onnx"  # YOLOv8 model weights
input_folder = "./dataset"  # folder with images to annotate
output_folder = "./output"  # folder for annotated images
target_labels = ["tiaojuan", "zhujiesi", "yulinwen"]  # labels to keep

annotation_results = intelligent_annotation_with_filter(model_path, input_folder, output_folder, target_labels)
print(f"标注结果annotation_results{annotation_results}")

# Spot-check one known sample. Use .get() so that a dataset without this
# file (or a platform whose path separator yields a different key) prints
# None instead of ending the run with a KeyError.
sample_image = "./dataset/1024072305_0724123933.jpg"
print(f"标注结果:{sample_image} annotation_results{annotation_results.get(sample_image)}")

169
sishu_label_web.py Normal file
View File

@ -0,0 +1,169 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
import cv2
from ultralytics import YOLO
import requests
import numpy as np
import time
from contextlib import asynccontextmanager
import logging
import os
from logging.handlers import RotatingFileHandler
from urllib.parse import urlparse
# Logging setup: application logs go to logs/app.log.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, "app.log")

logger = logging.getLogger("Sishu_AI_Annotation")
logger.setLevel(logging.DEBUG)

# Rotate at 10 MB per file and keep at most 5 backups. The encoding is
# pinned to UTF-8 because the log messages contain Chinese text and the
# platform default encoding may not be able to represent it.
handler = RotatingFileHandler(
    log_file,
    maxBytes=10 * 1024 * 1024,
    backupCount=5,
    encoding="utf-8",
)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
@asynccontextmanager
async def app_lifespan(app: FastAPI):
    """Lifespan hook for the FastAPI app: load the YOLO model before the app runs.

    The loaded model is published through the module-level ``model`` global.
    Nothing is cleaned up after the ``yield``; teardown code would go there.
    """
    global model
    weights_path = "./model/sishu_thin_24_12.onnx"
    logger.info("正在加载模型...")
    model = YOLO(weights_path)
    logger.info("模型加载完成")
    # The application runs while this generator is suspended here.
    yield


# FastAPI application, wired to the lifespan hook that loads the model.
app = FastAPI(lifespan=app_lifespan)
# Request body schema for the annotation endpoint.
class AnnotateRequest(BaseModel):
    # Network paths (URLs) of the images to annotate.
    image_urls: List[str]
    # Labels to keep in the result; defaults to None when omitted.
    target_labels: Optional[List[str]] = None
    # Confidence threshold used for model inference.
    conf_threshold: Optional[float] = 0.3
# Give every label a fixed color (red, yellow, blue, green).
def generate_label_colors(labels):
    """Assign each label one of four fixed colors, cycling when labels outnumber colors.

    The tuples are in BGR channel order because the result is passed directly
    to OpenCV drawing calls (``cv2.rectangle`` / ``cv2.putText``).

    :param labels: iterable of hashable label names.
    :return: dict mapping each label to a ``(B, G, R)`` color tuple.
    """
    # BGR order. The earlier RGB-ordered tuples rendered "red" as blue and
    # "yellow" as cyan once drawn by OpenCV.
    colors = {
        "red": (0, 0, 255),
        "yellow": (0, 255, 255),
        "blue": (255, 0, 0),
        "green": (0, 255, 0),
    }
    color_list = list(colors.values())
    return {
        label: color_list[i % len(color_list)]
        for i, label in enumerate(labels)
    }
# Load an image from a URL.
def load_image_from_url(image_url, timeout=10):
    """Download ``image_url`` and decode it into an OpenCV image.

    :param image_url: URL of the image to fetch.
    :param timeout: seconds passed to ``requests.get`` as its timeout, so an
        unresponsive server cannot block the caller indefinitely.
    :return: the image decoded with ``cv2.IMREAD_COLOR``, or ``None`` when
        the request fails, the body is empty, or the data cannot be decoded
        (each failure is logged).
    """
    # NOTE(review): the URL appears to come straight from the request body;
    # consider restricting allowed hosts/schemes if this service is exposed.
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        payload = response.content
    except requests.exceptions.RequestException as e:
        logger.error(f"请求失败: {e}")
        return None

    # An empty buffer cannot be decoded; treat it as a decode failure rather
    # than handing it to cv2.imdecode.
    if not payload:
        logger.error(f"无法解码图像:{image_url}")
        return None

    image_array = np.frombuffer(payload, np.uint8)
    image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
    if image is None:
        logger.error(f"无法解码图像:{image_url}")
        return None
    return image
# Run detection on each image and collect the annotations.
def annotate_images(model, image_urls, target_labels=None, conf_threshold=0.1):
    """Annotate a list of images with the given model.

    Each entry of ``image_urls`` is fetched over HTTP when it starts with
    ``"http"`` and otherwise read from the local filesystem. Detections are
    drawn on the image, and the annotated image is saved under ``./output``.

    :param model: loaded YOLO model (callable; exposes ``names``).
    :param image_urls: iterable of image URLs or local paths.
    :param target_labels: labels to keep; ``None`` keeps every label.
    :param conf_threshold: confidence threshold passed to the model.
    :return: dict mapping each successfully loaded ``image_url`` to a list of
        annotation dicts with keys ``label``, ``x``, ``y``, ``width``,
        ``height`` (top-left corner and size as percentages of the image
        dimensions). Images that cannot be loaded are logged and omitted.
    """
    output_dir = "./output"
    # cv2.imwrite does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    wanted = None if target_labels is None else set(target_labels)
    label_colors = generate_label_colors(model.names.values())
    annotation_results = {}

    for image_url in image_urls:
        is_remote = image_url.startswith("http")
        if is_remote:
            start_time = time.time()
            image = load_image_from_url(image_url)
            load_time = time.time() - start_time
            logger.debug(f"加载图像 {image_url} 的时间: {load_time:.3f}")
        else:
            # NOTE(review): non-HTTP values are read as local paths; confirm
            # this is intended for values supplied by API clients.
            image = cv2.imread(image_url)

        if image is None:
            logger.error(f"无法读取图像:{image_url}")
            continue

        image_height, image_width = image.shape[:2]
        results = model(image, task="detect", conf=conf_threshold)

        image_annotations = []
        for result in results:
            for box in result.boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                label = model.names[int(box.cls[0])]
                conf = float(box.conf[0])
                if wanted is not None and label not in wanted:
                    continue

                color = label_colors[label]
                cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
                label_text = f"{label} {conf:.2f}"
                cv2.putText(image, label_text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

                # Store the box as percentages of the image size.
                image_annotations.append({
                    "label": label,
                    "x": x1 / image_width * 100,
                    "y": y1 / image_height * 100,
                    "width": (x2 - x1) / image_width * 100,
                    "height": (y2 - y1) / image_height * 100,
                })

        annotation_results[image_url] = image_annotations

        # Save the annotated image. For URLs, take the file name from the
        # URL's path component so query strings and fragments do not end up
        # in the output file name.
        if is_remote:
            file_name = os.path.basename(urlparse(image_url).path)
        else:
            file_name = os.path.basename(image_url)
        if not file_name:
            logger.warning(f"无法保存标注图像:{image_url}")
            continue

        output_path = os.path.join(output_dir, file_name)
        if cv2.imwrite(output_path, image):
            logger.info("图片路径:{}".format(output_path))
        else:
            logger.warning(f"无法保存标注图像:{output_path}")

    return annotation_results
# FastAPI endpoint: image annotation.
@app.post("/annotate")
async def annotate(request: AnnotateRequest):
    """Annotate the images named in the request and return the results.

    :param request: parsed request body (image URLs, optional label filter,
        confidence threshold).
    :return: ``{"annotations": ...}`` where the value is the mapping produced
        by ``annotate_images``.
    :raises HTTPException: status 500 with the error text when processing fails.
    """
    try:
        logger.info("开始处理请求: %s", request.dict())
        annotation_results = annotate_images(
            model=model,
            image_urls=request.image_urls,
            target_labels=request.target_labels,
            conf_threshold=request.conf_threshold,
        )
        logger.info("请求处理完成")
        # The logging module formats with %-style placeholders; a "{}"
        # placeholder with an argument triggers a logging error and the
        # message is lost.
        logger.info("annotations:%s", annotation_results)
        return {"annotations": annotation_results}
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"发生错误: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Example: calling this endpoint returns the annotation results as JSON,
# keyed by image path, with the annotation details for each image.