first commit
This commit is contained in:
commit
1dadf70601
|
|
@ -0,0 +1,16 @@
|
|||
*.log
|
||||
*.log.*
|
||||
*.bak
|
||||
logs
|
||||
**/target
|
||||
.idea/
|
||||
*.class
|
||||
*.manifest
|
||||
*.spec
|
||||
.DS_Store
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 4.6 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 4.4 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 4.0 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 4.4 MiB |
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,28 @@
|
|||
## 简介
|
||||
本项目主要用于实现智能标注,支持本地标注的测试和web标注的服务提供。’
|
||||
## model
|
||||
model 目录用于存放训练好的模型,以及模型对应的配置文件。
|
||||
|
||||
## dataset
|
||||
dataset 目录用于存放测试数据集。
|
||||
|
||||
## output
|
||||
output 目录用于存放含有标注结果的图片。
|
||||
|
||||
## logs
|
||||
logs 目录用于存放标注过程中产生的日志文件。每个日志文件最大10MB,最多保留5个文件。
|
||||
|
||||
## 文件说明
|
||||
sishu_label_local.py 为本地标注的测试代码,目的是将dataset里的丝束的图片进行标注,并将标注结果包括标签和矩形框保存到output目录下。
|
||||
sishu_label_web.py 为web标注的服务提供代码,目的是根据传入的网络图像路径按照传入的标签进行标注,返回标注结果。
|
||||
|
||||
|
||||
## 启动sishu_label_local的方法
|
||||
python sishu_label_local.py
|
||||
|
||||
## 启动sishu_label_web的方法
|
||||
单进程启动:
|
||||
uvicorn sishu_label_web:app --reload --host '0.0.0.0' --port 8081
|
||||
|
||||
多进程启动:
|
||||
uvicorn sishu_label_web:app --host 0.0.0.0 --port 8081 --workers 5
|
||||
|
|
@ -0,0 +1,130 @@
|
|||
import os
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import random
|
||||
|
||||
def generate_label_colors(labels):
|
||||
"""为每个标签生成固定颜色(红色,黄色,蓝色,绿色)"""
|
||||
colors = {
|
||||
"red": (255, 0, 0),
|
||||
"yellow": (255, 255, 0),
|
||||
"blue": (0, 0, 255),
|
||||
"green": (0, 255, 0),
|
||||
}
|
||||
# 为标签分配颜色,确保颜色在红、黄、蓝、绿之间循环
|
||||
label_colors = {}
|
||||
color_list = list(colors.values())
|
||||
for i, label in enumerate(labels):
|
||||
print(f"generate_label_colors, label: {label}")
|
||||
label_colors[label] = color_list[i % len(color_list)] # 循环分配颜色
|
||||
|
||||
return label_colors
|
||||
|
||||
def intelligent_annotation_with_filter(model_path, input_folder, output_folder, target_labels, conf_threshold=0.3):
|
||||
"""
|
||||
使用 YOLOv8 模型对本地文件夹图像进行智能标注,并按指定标签过滤输出。
|
||||
|
||||
:param model_path: YOLOv8 模型文件路径
|
||||
:param input_folder: 输入文件夹路径,包含待标注的图像
|
||||
:param output_folder: 输出文件夹路径,存储标注后的图像
|
||||
:param target_labels: 用户指定需要标注的目标类别列表
|
||||
:param conf_threshold: 置信度阈值,默认 0.5
|
||||
"""
|
||||
# 加载 YOLO 模型
|
||||
model = YOLO(model_path)
|
||||
|
||||
# 获取所有可能的标签
|
||||
all_labels = model.names.values()
|
||||
print("所有可能的标签:", all_labels)
|
||||
label_colors = generate_label_colors(all_labels)
|
||||
print("所有可能的标签对应的颜色:", label_colors)
|
||||
|
||||
# 保存标注结果的字典
|
||||
annotation_results = {}
|
||||
|
||||
# 创建输出文件夹(如果不存在)
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
# 遍历输入文件夹中的所有图像
|
||||
for file_name in os.listdir(input_folder):
|
||||
input_path = os.path.join(input_folder, file_name)
|
||||
|
||||
# 检查文件是否是图像格式
|
||||
if not (file_name.endswith('.jpg') or file_name.endswith('.png') or file_name.endswith('.jpeg')):
|
||||
continue
|
||||
|
||||
# 模型预测
|
||||
results = model(input_path, task="detect", conf=conf_threshold)
|
||||
|
||||
image_annotations = [] # 当前图像的标注信息
|
||||
|
||||
# 读取原始图像
|
||||
image = cv2.imread(input_path)
|
||||
|
||||
# 获取图像尺寸
|
||||
image_height, image_width = image.shape[:2]
|
||||
|
||||
# 标注的标志,检查是否有符合的目标
|
||||
has_annotation = False
|
||||
|
||||
# 遍历预测结果
|
||||
for result in results:
|
||||
for box in result.boxes:
|
||||
# 提取边界框信息和分类标签
|
||||
x1, y1, x2, y2 = map(int, box.xyxy[0]) # 坐标转换为整数
|
||||
label = model.names[int(box.cls[0])] # 类别名称
|
||||
print(f"label:" + label)
|
||||
conf = box.conf[0] # 置信度
|
||||
|
||||
# 检查是否为目标标签
|
||||
if label in target_labels:
|
||||
color = label_colors[label] # 获取对应标签的颜色
|
||||
has_annotation = True
|
||||
|
||||
# 绘制边界框
|
||||
# color = (0, 255, 0) # 绿色
|
||||
# 随机生成颜色
|
||||
# color = tuple(random.randint(0, 255) for _ in range(3)) # 生成随机颜色
|
||||
cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
|
||||
|
||||
# 在边界框上方绘制标签和置信度
|
||||
label_text = f"{label} {conf:.2f}"
|
||||
print(f"label_text:" + label_text)
|
||||
cv2.putText(image, label_text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 2, color, 1)
|
||||
|
||||
# 转换为归一化坐标并保存
|
||||
normalized_x = x1 / image_width
|
||||
normalized_y = y1 / image_height
|
||||
normalized_width = (x2 - x1) / image_width
|
||||
normalized_height = (y2 - y1) / image_height
|
||||
|
||||
annotation = {
|
||||
"label": label,
|
||||
"x": normalized_x * 100,
|
||||
"y": normalized_y * 100,
|
||||
"width": normalized_width * 100,
|
||||
"height": normalized_height * 100,
|
||||
}
|
||||
image_annotations.append(annotation)
|
||||
|
||||
# 保存标注信息到结果字典
|
||||
annotation_results[input_path] = image_annotations if image_annotations else []
|
||||
|
||||
# 如果存在符合的标注目标,保存标注后的图像
|
||||
if has_annotation:
|
||||
output_path = os.path.join(output_folder, file_name)
|
||||
cv2.imwrite(output_path, image)
|
||||
|
||||
print(f"智能标注完成!标注后的图像保存在:{output_folder}")
|
||||
return annotation_results
|
||||
|
||||
model_path = "./model/sishu_thin_24_12.onnx" # 你的YOLOv8模型权重路径
|
||||
input_folder = "./dataset" # 输入文件夹路径
|
||||
output_folder = "./output" # 输出文件夹路径
|
||||
target_labels = ["tiaojuan", "zhujiesi", "yulinwen"] # 用户指定的目标标签
|
||||
annotation_results = intelligent_annotation_with_filter(model_path, input_folder, output_folder, target_labels)
|
||||
print(f"标注结果:annotation_results:{annotation_results}")
|
||||
|
||||
print(f"标注结果:./dataset/1024072305_0724123933.jpg annotation_results:{annotation_results['./dataset/1024072305_0724123933.jpg']}")
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,169 @@
|
|||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
import cv2
|
||||
from ultralytics import YOLO
|
||||
import requests
|
||||
import numpy as np
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
import os
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from urllib.parse import urlparse
|
||||
|
||||
# 设置日志配置
|
||||
log_dir = "logs"
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# 配置日志
|
||||
log_file = os.path.join(log_dir, "app.log")
|
||||
logger = logging.getLogger("Sishu_AI_Annotation")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# 设置文件回滚,每个日志文件最大10MB,最多保留5个文件
|
||||
handler = RotatingFileHandler(log_file, maxBytes=10 * 1024 * 1024, backupCount=5)
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
@asynccontextmanager
|
||||
async def app_lifespan(app: FastAPI):
|
||||
"""定义应用的生命周期事件"""
|
||||
global model
|
||||
logger.info("正在加载模型...")
|
||||
model = YOLO("./model/sishu_thin_24_12.onnx") # 加载YOLO模型
|
||||
logger.info("模型加载完成")
|
||||
yield # 生命周期的运行阶段
|
||||
# 在这里可以执行任何必要的清理操作
|
||||
|
||||
# FastAPI应用
|
||||
app = FastAPI(lifespan=app_lifespan)
|
||||
|
||||
# 定义请求参数模型
|
||||
class AnnotateRequest(BaseModel):
|
||||
image_urls: List[str] # 图片的网络路径列表
|
||||
target_labels: Optional[List[str]] = None # 标签列表
|
||||
conf_threshold: Optional[float] = 0.3 # 模型推理的阈值
|
||||
|
||||
# 为每个标签生成固定颜色(红色,黄色,蓝色,绿色)
|
||||
def generate_label_colors(labels):
|
||||
colors = {
|
||||
"red": (255, 0, 0),
|
||||
"yellow": (255, 255, 0),
|
||||
"blue": (0, 0, 255),
|
||||
"green": (0, 255, 0),
|
||||
}
|
||||
label_colors = {}
|
||||
color_list = list(colors.values())
|
||||
for i, label in enumerate(labels):
|
||||
label_colors[label] = color_list[i % len(color_list)] # 循环分配颜色
|
||||
|
||||
return label_colors
|
||||
|
||||
|
||||
# 通过URL加载图像
|
||||
def load_image_from_url(image_url):
|
||||
try:
|
||||
response = requests.get(image_url)
|
||||
response.raise_for_status()
|
||||
image_array = np.frombuffer(response.content, np.uint8)
|
||||
image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
||||
|
||||
if image is None:
|
||||
logger.error(f"无法解码图像:{image_url}")
|
||||
return None
|
||||
|
||||
return image
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"请求失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# 对图像进行智能标注
|
||||
def annotate_images(model, image_urls, target_labels=None, conf_threshold=0.1):
|
||||
annotation_results = {}
|
||||
all_labels = model.names.values()
|
||||
label_colors = generate_label_colors(all_labels)
|
||||
|
||||
for image_url in image_urls:
|
||||
# 假设 image_path 可能是一个 URL 或者本地路径
|
||||
if image_url.startswith("http"):
|
||||
# 如果是URL,加载在线图片
|
||||
start_time = time.time() # 记录开始时间
|
||||
image = load_image_from_url(image_url)
|
||||
end_time = time.time() # 记录结束时间
|
||||
load_time = end_time - start_time
|
||||
logger.debug(f"加载图像 {image_url} 的时间: {load_time:.3f} 秒")
|
||||
else:
|
||||
# 如果是本地路径,使用cv2读取
|
||||
image = cv2.imread(image_url)
|
||||
|
||||
if image is None:
|
||||
logger.error(f"无法读取图像:{image_url}")
|
||||
continue
|
||||
|
||||
image_height, image_width = image.shape[:2]
|
||||
results = model(image, task="detect", conf=conf_threshold)
|
||||
|
||||
image_annotations = []
|
||||
|
||||
path = urlparse(image_url).path
|
||||
file_name = os.path.basename(path)
|
||||
|
||||
for result in results:
|
||||
for box in result.boxes:
|
||||
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
||||
label = model.names[int(box.cls[0])]
|
||||
conf = box.conf[0]
|
||||
|
||||
if target_labels is None or label in target_labels:
|
||||
color = label_colors[label]
|
||||
|
||||
cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
|
||||
label_text = f"{label} {conf:.2f}"
|
||||
cv2.putText(image, label_text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
|
||||
|
||||
normalized_x = x1 / image_width
|
||||
normalized_y = y1 / image_height
|
||||
normalized_width = (x2 - x1) / image_width
|
||||
normalized_height = (y2 - y1) / image_height
|
||||
|
||||
annotation = {
|
||||
"label": label,
|
||||
"x": normalized_x * 100,
|
||||
"y": normalized_y * 100,
|
||||
"width": normalized_width * 100,
|
||||
"height": normalized_height * 100,
|
||||
}
|
||||
image_annotations.append(annotation)
|
||||
|
||||
annotation_results[image_url] = image_annotations if image_annotations else [] # 保存标注后的图像
|
||||
output_path = os.path.join("./output", os.path.basename(image_url))
|
||||
cv2.imwrite(output_path, image)
|
||||
logger.info("图片路径:{}".format(output_path))
|
||||
|
||||
|
||||
return annotation_results
|
||||
|
||||
# FastAPI端点 - 图像标注接口
|
||||
@app.post("/annotate")
|
||||
async def annotate(request: AnnotateRequest):
|
||||
try:
|
||||
logger.info("开始处理请求: %s", request.dict())
|
||||
annotation_results = annotate_images(
|
||||
model=model,
|
||||
image_urls=request.image_urls,
|
||||
target_labels=request.target_labels,
|
||||
conf_threshold=request.conf_threshold
|
||||
)
|
||||
logger.info("请求处理完成")
|
||||
logger.info("annotations:{}", annotation_results)
|
||||
return {"annotations": annotation_results}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发生错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# 示例用法: 调用接口时会返回JSON格式的标注结果,包含图片的路径和标注信息。
|
||||
Loading…
Reference in New Issue