commit eae2961d5ba24fe4d1bc44ee424a11c40f5cf029 Author: guanyuankai Date: Thu Jul 31 09:10:03 2025 +0800 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4c15d14 --- /dev/null +++ b/.gitignore @@ -0,0 +1,209 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock +#poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. +# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +#pdm.lock +#pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +#pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ + +# Streamlit +.streamlit/secrets.toml + +#项目文件 +Safety_Detection_Project + +#模型 +*.pt diff --git a/auto_run_yolo_onnx.py b/auto_run_yolo_onnx.py new file mode 100644 index 0000000..b7cfb27 --- /dev/null +++ b/auto_run_yolo_onnx.py @@ -0,0 +1,123 @@ +import yaml +from ultralytics import YOLO +import os +import glob + + +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +def find_latest_run_dir(project_path='runs/detect/default_project'): + """ + 在指定的项目路径下,根据文件夹的修改时间找到最新的 'train' 目录。 + """ + if not os.path.exists(project_path): + return None + train_dirs = [d for d in glob.glob(os.path.join(project_path, 'train*')) if os.path.isdir(d)] + if not train_dirs: + return None + latest_dir = max(train_dirs, key=os.path.getmtime) + return latest_dir + +def run_pipeline(config_path='config.yaml'): + """ + 读取配置文件并执行YOLO训练和导出流程。 + """ + # 1. --- 读取配置文件 --- + try: + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + print("✅ 配置文件加载成功!") + print(f"项目名称: {config['project_name']}") + except Exception as e: + print(f"❌ 加载配置文件时发生错误: {e}") + return + + # 提取核心配置 + base_model = config['base_model'] + data_yaml_path = config['data_yaml'] + project_name = config['project_name'] + + model = YOLO(base_model) + print(f"✅ 模型 '{base_model}' 初始化成功。") + + best_model_path = None + + # 2. --- 执行训练 --- + if config.get('run_training', False): + print("\n🚀 开始YOLO训练...") + try: + results = model.train( + data=data_yaml_path, + epochs=config['epochs'], + imgsz=config['imgsz'], + batch=config['batch_size'], + workers=config['workers'], + project=project_name, + name='train' + ) + + final_results = results[0] if isinstance(results, (list, tuple)) else results + + assert final_results is not None and hasattr(final_results, 'save_dir'), \ + "训练未产生一个包含保存目录的有效结果对象。" + + best_model_path = os.path.join(final_results.save_dir, 'weights/best.pt') + + print(f"✅ 训练完成!最佳模型已保存在: {best_model_path}") + except Exception as e: + print(f"❌ 训练过程中发生错误: {e}") + return + else: + print("\n⏩ 根据配置,跳过训练步骤。") + if config.get('run_export', False): + print(" 正在查找最新的训练结果...") + project_path = os.path.join(ROOT_DIR, project_name) + latest_run_dir = find_latest_run_dir(project_path) + + if latest_run_dir: + potential_model_path = os.path.join(latest_run_dir, 'weights', 'best.pt') + if os.path.exists(potential_model_path): + best_model_path = potential_model_path + print(f"✅ 成功找到最新的模型: {best_model_path}") + else: + print(f"❌ 在最新的训练目录 '{latest_run_dir}' 中未找到 'best.pt' 文件。") + else: + print(f"❌ 在项目 '{project_name}' 中未找到任何过往的训练结果。") + + # 3. --- 执行导出 --- + print(f"root_dir:{ROOT_DIR}") + # print(f"run_export:{config.get('run_export', False)}") + print(f"best model path:{best_model_path}") + if config.get('run_export', False) and best_model_path: + + # --- 修改开始:增加最终的安全检查 --- + # 确保 best_model_path 是一个字符串,而不是元组或列表 + if isinstance(best_model_path, (list, tuple)): + print(f"⚠️ 检测到模型路径为序列类型,自动提取第一个元素。原始值: {best_model_path}") + best_model_path = best_model_path[0] + # --- 修改结束 --- + + print(f"\n🚀 开始将模型 '{best_model_path}' 导出为 {config['export_format']} 格式...") + try: + model_to_export = YOLO(best_model_path) + + model_to_export.export( + format=config['export_format'], + imgsz=config['imgsz'], + half=config.get('half_precision', False) + ) + + exported_file_name = os.path.basename(best_model_path).replace('.pt', f".{config['export_format']}") + exported_file_path = os.path.join(os.path.dirname(best_model_path), exported_file_name) + print(f"✅ 导出成功!文件已保存在: {exported_file_path}") + except Exception as e: + print(f"❌ 导出过程中发生错误: {e}") + elif config.get('run_export', False): + print("\n⏩ 跳过导出步骤,因为未找到有效的模型路径。") + else: + print("\n⏩ 根据配置,跳过导出步骤。") + + print("\n🎉 自动化流程执行完毕!") + + +if __name__ == '__main__': + run_pipeline() \ No newline at end of file diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..a23fdd8 --- /dev/null +++ b/config.yaml @@ -0,0 +1,32 @@ +# =================================================================== +# 自动化训练与导出配置文件 +# =================================================================== + +# --- 项目与模型设置 --- +# 项目名称,所有训练结果将保存在 runs/detect/{project_name} 文件夹下 +project_name: 'Safety_Detection_Project' + +# 基础模型:可以是官方的预训练模型 (如 yolov8n.pt), 也可以是你自己的 .pt 文件路径 +base_model: 'yolo11n.pt' + +# 数据集配置文件:【非常重要】请务必修改为你自己的 data .yaml 文件的绝对或相对路径 +data_yaml: 'coco128.yaml' + +# --- 流程控制 --- +# 是否执行训练步骤 +run_training: false +# 是否执行导出步骤 +run_export: true + +# --- 训练参数 --- +epochs: 100 +imgsz: 640 +batch_size: 16 +# 使用多少个CPU核心进行数据加载,0表示只用主进程 +workers: 8 + +# --- 导出参数 --- +# 目标格式 (e.g., onnx, tensorrt, coreml) +export_format: 'onnx' +# 是否使用半精度(FP16)导出,可以提速并减小文件大小 +half_precision: true \ No newline at end of file