重构AI算法部分

This commit is contained in:
GuanYuankai 2025-10-29 03:12:53 +00:00
parent f9a48f07b7
commit b6315d1790
11 changed files with 627 additions and 386 deletions

View File

@ -68,6 +68,7 @@ add_library(edge_proxy_lib STATIC
src/rknn/preprocess.cc
src/rknn/postprocess.cc
src/videoServiceManager/video_service_manager.cc
src/algorithm/IntrusionModule.cc
)
target_include_directories(edge_proxy_lib PUBLIC

View File

@ -13,38 +13,81 @@
12345
],
"video_service": {
"enabled": true,
"model_path": "/app/edge-proxy/models/RK3588/yolov5s-640-640.rknn",
"streams": [
{
"enabled": true,
"id": "ch1301",
"input_url": "rtsp://admin:hzx12345@192.168.1.10:554/Streaming/Channels/1301",
"output_rtsp": "rtsp://127.0.0.1:8554/ch1301",
"rknn_thread_num": 1
},
{
"enabled": true,
"id": "ch1101",
"input_url": "rtsp://admin:hzx12345@192.168.1.10:554/Streaming/Channels/1101",
"output_rtsp": "rtsp://127.0.0.1:8554/ch1101",
"rknn_thread_num": 1
},
{
"enabled": true,
"id": "ch1401",
"input_url": "rtsp://admin:hzx12345@192.168.1.10:554/Streaming/Channels/1401",
"output_rtsp": "rtsp://127.0.0.1:8554/ch1401",
"rknn_thread_num": 1
},
{
"enabled": true,
"id": "ch1501",
"input_url": "rtsp://admin:hzx12345@192.168.1.10:554/Streaming/Channels/1501",
"output_rtsp": "rtsp://127.0.0.1:8554/ch1501",
"rknn_thread_num": 1
}
]
"enabled": true
},
"video_streams": [
{
"enabled": true,
"id": "cam_01_intrusion",
"input_url": "rtsp://admin:hzx12345@192.168.1.10:554/Streaming/Channels/1301",
"module_config": {
"intrusion_zone": [
100,
100,
300,
300
],
"model_path": "/app/edge-proxy/models/RK3588/yolov5s-640-640.rknn",
"rknn_thread_num": 3,
"time_threshold_sec": 3.0
},
"module_type": "intrusion_detection",
"output_rtsp": "rtsp://127.0.0.1:8554/ch1301"
},
{
"enabled": true,
"id": "cam_02_intrusion",
"input_url": "rtsp://admin:hzx12345@192.168.1.10:554/Streaming/Channels/1101",
"module_config": {
"intrusion_zone": [
100,
100,
300,
300
],
"model_path": "/app/edge-proxy/models/RK3588/yolov5s-640-640.rknn",
"rknn_thread_num": 3,
"time_threshold_sec": 3.0
},
"module_type": "intrusion_detection",
"output_rtsp": "rtsp://127.0.0.1:8554/ch1101"
},
{
"enabled": true,
"id": "cam_03_intrusion",
"input_url": "rtsp://admin:hzx12345@192.168.1.10:554/Streaming/Channels/1501",
"module_config": {
"intrusion_zone": [
100,
100,
300,
300
],
"model_path": "/app/edge-proxy/models/RK3588/yolov5s-640-640.rknn",
"rknn_thread_num": 3,
"time_threshold_sec": 3.0
},
"module_type": "intrusion_detection",
"output_rtsp": "rtsp://127.0.0.1:8554/ch1501"
},
{
"enabled": true,
"id": "cam_041_intrusion",
"input_url": "rtsp://admin:hzx12345@192.168.1.10:554/Streaming/Channels/1401",
"module_config": {
"intrusion_zone": [
100,
100,
300,
300
],
"model_path": "/app/edge-proxy/models/RK3588/yolov5s-640-640.rknn",
"rknn_thread_num": 3,
"time_threshold_sec": 3.0
},
"module_type": "intrusion_detection",
"output_rtsp": "rtsp://127.0.0.1:8554/ch1401"
}
],
"web_server_port": 8080
}

View File

@ -0,0 +1,27 @@
// IAnalysisModule.h
#pragma once
#include <opencv2/core/core.hpp>
#include <string>
/**
* @brief AI分析模块的抽象基类
*
* AI视频分析模块
* VideoService AI模块交互
*/
class IAnalysisModule {
public:
virtual ~IAnalysisModule() = default;
/**
* @brief ()
*/
virtual bool init() = 0;
/**
* @brief ()
* @param frame [in/out]
*/
virtual bool process(cv::Mat& frame) = 0;
};

View File

@ -0,0 +1,222 @@
// IntrusionModule.cc
#include "IntrusionModule.h"
#include "spdlog/spdlog.h"
#include "opencv2/imgproc/imgproc.hpp"
#include <stdio.h>
IntrusionModule::IntrusionModule(std::string model_path,
int thread_num,
cv::Rect intrusion_zone,
double intrusion_time_threshold)
: model_path_(model_path),
thread_num_(thread_num),
intrusion_zone_(intrusion_zone),
intrusion_time_threshold_(intrusion_time_threshold),
next_track_id_(1) //
{
spdlog::info("[IntrusionModule] Created. Model: {}, Threads: {}", model_path_, thread_num_);
if (intrusion_zone_.width <= 0 || intrusion_zone_.height <= 0) {
spdlog::warn("[IntrusionModule] Warning: Intrusion zone is invalid (0,0,0,0). It will be set at runtime.");
}
}
//
// init() 函数: 从 video_service.cc 的 start() 移动而来
//
bool IntrusionModule::init() {
rknn_pool_ = std::make_unique<rknnPool<rkYolov5s, cv::Mat, detect_result_group_t>>(model_path_.c_str(), thread_num_);
if (rknn_pool_->init() != 0) {
spdlog::error("[IntrusionModule] rknnPool init fail!");
return false;
}
spdlog::info("[IntrusionModule] rknnPool init success.");
return true;
}
//
// process() 函数: 从 video_service.cc 的 processing_loop() 移动而来
//
bool IntrusionModule::process(cv::Mat& frame) {
if (frame.empty()) {
return false;
}
// 1. 图像预处理 (来自)
cv::Mat model_input_image;
cv::resize(frame, model_input_image, cv::Size(640, 640));
if (!model_input_image.isContinuous()) {
model_input_image = model_input_image.clone();
}
// 2. RKNN 推理 (来自)
if (rknn_pool_->put(model_input_image) != 0) {
spdlog::error("[IntrusionModule] Failed to put frame into rknnPool.");
return false;
}
detect_result_group_t detection_results;
if (rknn_pool_->get(detection_results) != 0) {
spdlog::error("[IntrusionModule] Failed to get frame from rknnPool.");
return false;
}
// 3. 跟踪与报警 (来自)
this->update_tracker(detection_results, frame.size());
// 4. 绘制结果 (来自)
this->draw_results(frame); // 直接在传入的 frame 上绘制
return true;
}
//
// 以下所有函数均从 video_service.cc 完整剪切而来
//
void IntrusionModule::trigger_alarm(int person_id, const cv::Rect& box) {
printf("[ALARM] Intrusion detected! Person ID: %d at location (%d, %d, %d, %d)\n",
person_id, box.x, box.y, box.width, box.height);
// TODO: 在这里实现真正的报警逻辑,例如发送网络消息、写入数据库等。
}
double IntrusionModule::get_current_time_seconds() {
return std::chrono::duration_cast<std::chrono::duration<double>>(
std::chrono::high_resolution_clock::now().time_since_epoch()
).count();
}
void IntrusionModule::update_tracker(detect_result_group_t &detect_result_group, const cv::Size& frame_size)
{
// 如果入侵区域无效,则设置为帧中心的 1/4 区域 (基于原始帧大小)
if (intrusion_zone_.width <= 0 || intrusion_zone_.height <= 0) {
intrusion_zone_ = cv::Rect(frame_size.width / 4, frame_size.height / 4, frame_size.width / 2, frame_size.height / 2);
}
// --- 缩放比例计算 ---
const float model_input_width = 640.0f;
const float model_input_height = 640.0f;
float scale_x = (float)frame_size.width / model_input_width;
float scale_y = (float)frame_size.height / model_input_height;
std::vector<cv::Rect> current_detections;
for (int i = 0; i < detect_result_group.count; i++) {
detect_result_t *det_result = &(detect_result_group.results[i]);
if (strcmp(det_result->name, "person") == 0) {
int original_left = static_cast<int>(det_result->box.left * scale_x);
int original_top = static_cast<int>(det_result->box.top * scale_y);
int original_right = static_cast<int>(det_result->box.right * scale_x);
int original_bottom = static_cast<int>(det_result->box.bottom * scale_y);
original_left = std::max(0, std::min(original_left, frame_size.width - 1));
original_top = std::max(0, std::min(original_top, frame_size.height - 1));
original_right = std::max(original_left, std::min(original_right, frame_size.width));
original_bottom = std::max(original_top, std::min(original_bottom, frame_size.height));
if (original_right > original_left && original_bottom > original_top) {
current_detections.push_back(cv::Rect(
original_left, original_top,
original_right - original_left,
original_bottom - original_top
));
}
}
}
for (auto it = tracked_persons_.begin(); it != tracked_persons_.end(); ++it) {
it->second.frames_unseen++;
}
std::vector<int> matched_track_ids;
for (const auto& det_box : current_detections) {
int best_match_id = -1;
double max_iou_threshold = 0.3;
double best_iou = 0.0;
for (auto const& [id, person] : tracked_persons_) {
bool already_matched = false;
for(int matched_id : matched_track_ids) {
if (id == matched_id) {
already_matched = true;
break;
}
}
if (already_matched) {
continue;
}
double iou = (double)(det_box & person.box).area() / (double)(det_box | person.box).area();
if (iou > best_iou && iou >= max_iou_threshold) {
best_iou = iou;
best_match_id = id;
}
}
if (best_match_id != -1) {
tracked_persons_[best_match_id].box = det_box;
tracked_persons_[best_match_id].frames_unseen = 0;
matched_track_ids.push_back(best_match_id);
} else {
TrackedPerson new_person;
new_person.id = next_track_id_++;
new_person.box = det_box;
new_person.entry_time = 0;
new_person.is_in_zone = false;
new_person.alarm_triggered = false;
new_person.frames_unseen = 0;
tracked_persons_[new_person.id] = new_person;
}
}
double current_time = get_current_time_seconds();
for (auto it = tracked_persons_.begin(); it != tracked_persons_.end(); ++it) {
TrackedPerson& person = it->second;
bool currently_in_zone = (intrusion_zone_ & person.box).area() > 0;
if (currently_in_zone) {
if (!person.is_in_zone) {
person.is_in_zone = true;
person.entry_time = current_time;
person.alarm_triggered = false;
} else {
if (!person.alarm_triggered && (current_time - person.entry_time) >= intrusion_time_threshold_) {
person.alarm_triggered = true;
trigger_alarm(person.id, person.box);
}
}
} else {
if (person.is_in_zone) {
person.is_in_zone = false;
person.entry_time = 0;
person.alarm_triggered = false;
}
}
}
for (auto it = tracked_persons_.begin(); it != tracked_persons_.end(); /* 无自增 */) {
if (it->second.frames_unseen > 50) {
it = tracked_persons_.erase(it);
} else {
++it;
}
}
}
void IntrusionModule::draw_results(cv::Mat& frame)
{
cv::rectangle(frame, this->intrusion_zone_, cv::Scalar(255, 255, 0), 2); // 黄色
for (auto const& [id, person] : this->tracked_persons_) {
cv::Scalar box_color = person.alarm_triggered ? cv::Scalar(0, 0, 255) : cv::Scalar(0, 255, 0);
int line_thickness = person.alarm_triggered ? 3 : 2;
cv::rectangle(frame, person.box, box_color, line_thickness);
std::string label = "Person " + std::to_string(id);
if (person.is_in_zone) {
label += " (In Zone)";
}
cv::putText(frame, label, cv::Point(person.box.x, person.box.y - 10),
cv::FONT_HERSHEY_SIMPLEX, 0.5, box_color, 2);
}
}

View File

@ -0,0 +1,63 @@
// IntrusionModule.h
#pragma once
#include "IAnalysisModule.h"
#include "rknn/postprocess.h"
#include "rknn/rkYolov5s.hpp"
#include "rknn/rknnPool.hpp"
#include <opencv2/core/core.hpp>
#include <map>
#include <string>
#include <memory>
#include <chrono>
#include <algorithm>
struct TrackedPerson
{
int id;
cv::Rect box;
double entry_time;
bool is_in_zone;
bool alarm_triggered;
int frames_unseen;
};
class IntrusionModule : public IAnalysisModule {
public:
/**
* @brief
* @param model_path rknn模型文件路径
* @param thread_num rknn线程池数量
* @param intrusion_zone
* @param intrusion_time_threshold
*/
IntrusionModule(std::string model_path,
int thread_num,
cv::Rect intrusion_zone,
double intrusion_time_threshold);
virtual ~IntrusionModule() = default;
// --- 实现 IAnalysisModule 接口 ---
virtual bool init() override;
virtual bool process(cv::Mat& frame) override;
private:
// --- 以下函数从 video_service.cc 移动到这里 ---
void update_tracker(detect_result_group_t &detect_result_group, const cv::Size& frame_size);
void draw_results(cv::Mat& frame);
void trigger_alarm(int person_id, const cv::Rect& box);
double get_current_time_seconds();
// --- 以下成员变量从 video_service.h 移动到这里 ---
std::string model_path_;
int thread_num_;
std::unique_ptr<rknnPool<rkYolov5s, cv::Mat, detect_result_group_t>> rknn_pool_;
cv::Rect intrusion_zone_;
std::map<int, TrackedPerson> tracked_persons_;
int next_track_id_;
double intrusion_time_threshold_;
};

View File

@ -1,11 +1,14 @@
// config_manager.cc (修改后)
#include "config_manager.h"
#include <fstream>
ConfigManager& ConfigManager::getInstance() {
static ConfigManager instance;
return instance;
}
json ConfigManager::createDefaultConfig() {
// --- 修改: 添加新的 video_service 和 video_streams 默认值 ---
return json {
{"device_id", "default-edge-proxy-01"},
{"config_base_path", "/app/config/"},
@ -18,9 +21,31 @@ json ConfigManager::createDefaultConfig() {
{"log_level", "debug"},
{"alarm_rules_path", "alarms.json"},
{"piper_executable_path", "/usr/bin/piper"},
{"piper_model_path", "/app/models/model.onnx"}
{"piper_model_path", "/app/models/model.onnx"},
// **新特性**:您可以在这里添加任何新的默认值
// --- 新增: 视频服务配置 ---
{
"video_service", {
{"enabled", false}
}
},
{
"video_streams", {
{
{"id", "cam_01_example"},
{"enabled", false},
{"input_url", "rtsp://your_camera_stream"},
{"output_rtsp", "rtsp://localhost:8554/cam_01_out"},
{"module_type", "intrusion_detection"},
{"module_config", {
{"model_path", "/app/models/yolov5s.rknn"},
{"rknn_thread_num", 3},
{"intrusion_zone", {0, 0, 1920, 1080}},
{"time_threshold_sec", 5.0}
}}
}
}
}
};
}
@ -58,10 +83,10 @@ bool ConfigManager::load(const std::string& configFilePath) {
try {
ifs >> m_config_json;
// **重要**:合并默认值。确保JSON文件中缺失的键被默认值补全。
// **重要**:合并默认值。
json defaults = createDefaultConfig();
defaults.merge_patch(m_config_json);
m_config_json = defaults;
defaults.merge_patch(m_config_json); // <-- 关键: m_config_json 会覆盖 defaults
m_config_json = defaults; // 将合并后的结果存回
spdlog::info("Successfully loaded config from '{}'. Device ID: {}", m_configFilePath, m_config_json.value("device_id", "N/A"));
@ -76,67 +101,52 @@ bool ConfigManager::load(const std::string& configFilePath) {
}
}
// ... (save, getDeviceID, ... getPiperModelPath 保持不变) ...
bool ConfigManager::save() {
std::unique_lock<std::shared_mutex> lock(m_mutex);
return save_unlocked();
}
std::string ConfigManager::getDeviceID() {
return get<std::string>("device_id", "default-edge-proxy-01");
}
std::string ConfigManager::getConfigBasePath() {
return get<std::string>("config_base_path", "/app/config/");
}
std::string ConfigManager::getMqttBroker() {
return get<std::string>("mqtt_broker", "tcp://localhost:1883");
}
std::string ConfigManager::getMqttClientID() {
return get<std::string>("mqtt_client_id_prefix", "edge-proxy-") + getDeviceID();
}
std::string ConfigManager::getDataStorageDbPath() {
return getConfigBasePath() + get<std::string>("data_storage_db_path", "edge_proxy_data.db");
}
std::string ConfigManager::getDataCacheDbPath() {
return getConfigBasePath() + get<std::string>("data_cache_db_path", "edge_data_cache.db");
}
std::string ConfigManager::getDevicesConfigPath() {
return getConfigBasePath() + "devices.json";
}
int ConfigManager::getWebServerPort() {
return get<int>("web_server_port", 8080);
}
std::vector<uint16_t> ConfigManager::getTcpServerPorts() {
return get<std::vector<uint16_t>>("tcp_server_ports", {12345});
}
std::string ConfigManager::getLogLevel() {
return get<std::string>("log_level", "debug");
}
std::string ConfigManager::getAlarmRulesPath() {
return getConfigBasePath() + get<std::string>("alarm_rules_path", "alarms.json");
}
std::string ConfigManager::getPiperExecutablePath() {
return get<std::string>("piper_executable_path", "/usr/bin/piper");
}
std::string ConfigManager::getPiperModelPath() {
return get<std::string>("piper_model_path", "/app/models/model.onnx");
}
// --- (getIsVideoServiceEnabled 保持不变) ---
bool ConfigManager::getIsVideoServiceEnabled() const {
std::shared_lock<std::shared_mutex> lock(m_mutex);
try {
@ -149,44 +159,47 @@ bool ConfigManager::getIsVideoServiceEnabled() const {
return false;
}
std::string ConfigManager::getVideoModelPath() const {
std::shared_lock<std::shared_mutex> lock(m_mutex);
try {
if (m_config_json.contains("video_service")) {
return m_config_json["video_service"].value("model_path", "");
}
} catch (const json::type_error& e) {
spdlog::warn("Config type mismatch for key 'video_service.model_path'. Error: {}", e.what());
}
return "";
}
// --- 移除: getVideoModelPath ---
// std::string ConfigManager::getVideoModelPath() const { ... }
// --- 修改: getVideoStreamConfigs ---
std::vector<ConfigManager::VideoStreamConfig> ConfigManager::getVideoStreamConfigs() const {
std::vector<VideoStreamConfig> configs;
std::shared_lock<std::shared_mutex> lock(m_mutex);
try {
if (m_config_json.contains("video_service") &&
m_config_json["video_service"].contains("streams") &&
m_config_json["video_service"]["streams"].is_array())
// --- 修改: 路径变为顶层的 "video_streams" ---
if (m_config_json.contains("video_streams") &&
m_config_json["video_streams"].is_array())
{
for (const auto& stream_json : m_config_json["video_service"]["streams"]) {
for (const auto& stream_json : m_config_json["video_streams"]) {
VideoStreamConfig cfg;
cfg.id = stream_json.value("id", "");
cfg.enabled = stream_json.value("enabled", false);
cfg.input_url = stream_json.value("input_url", "");
cfg.output_rtsp = stream_json.value("output_rtsp", "");
cfg.rknn_thread_num = stream_json.value("rknn_thread_num", 1);
// --- 移除 ---
// cfg.rknn_thread_num = stream_json.value("rknn_thread_num", 1);
// --- 新增 ---
cfg.module_type = stream_json.value("module_type", "");
cfg.module_config = stream_json.value("module_config", json::object()); // 传递整个json对象
if (cfg.module_type.empty()) {
spdlog::warn("Video stream '{}' has no 'module_type' defined. It may fail to start.", cfg.id);
}
configs.push_back(cfg);
}
} else {
spdlog::warn("Config key 'video_service.streams' not found or is not an array.");
spdlog::warn("Config key 'video_streams' not found or is not an array.");
}
} catch (const json::exception& e) {
// 捕获所有可能的 JSON (Source 2) 解析异常
spdlog::error("Error parsing 'video_service.streams': {}", e.what());
// 捕获所有可能的 JSON 解析异常
spdlog::error("Error parsing 'video_streams': {}", e.what());
}
return configs;

View File

@ -1,3 +1,4 @@
// config_manager.h (修改后)
#pragma once
#include <string>
@ -11,25 +12,24 @@ using json = nlohmann::json;
class ConfigManager {
public:
//
// --- 核心修改 ---
//
struct VideoStreamConfig {
std::string id;
bool enabled;
std::string input_url;
std::string output_rtsp;
int rknn_thread_num;
// int rknn_thread_num; // <-- 移除 (已移动到 module_config)
std::string module_type; // <-- 新增: "intrusion_detection" 或 "face_recognition"
json module_config; // <-- 新增: 传递模块的特定配置 (最灵活的方式)
};
static ConfigManager& getInstance();
bool load(const std::string& configFilePath);
bool save();
/**
* @brief [] 线 GET
* @tparam T (e.g., std::string, int, bool)
* @param key JSON中的键
* @param default_value
* @return T
*/
template<typename T>
T get(const std::string& key, const T& default_value) {
std::shared_lock<std::shared_mutex> lock(m_mutex);
@ -51,12 +51,6 @@ public:
}
}
/**
* @brief [] 线 SET
* @tparam T
* @param key
* @param value
*/
template<typename T>
void set(const std::string& key, const T& value) {
{
@ -67,7 +61,6 @@ public:
save();
// **特殊处理**: 某些配置需要立即生效
if (key == "log_level") {
spdlog::set_level(spdlog::level::from_str(value));
}
@ -89,8 +82,7 @@ public:
std::string getPiperModelPath();
bool getIsVideoServiceEnabled() const;
std::string getVideoModelPath() const;
std::vector<VideoStreamConfig> getVideoStreamConfigs() const;
std::vector<VideoStreamConfig> getVideoStreamConfigs() const; // (签名不变, 实现改变)
private:
ConfigManager() = default;
@ -104,7 +96,4 @@ private:
std::string m_configFilePath;
json m_config_json;
mutable std::shared_mutex m_mutex;
};

View File

@ -1,39 +1,33 @@
// video_service.cpp
// video_service.cc (修改后)
#include "video_service.h"
#include <stdio.h>
#include "opencv2/imgproc/imgproc.hpp"
#include "rknn/rkYolov5s.hpp"
#include "rknn/rknnPool.hpp"
// #include "rknn/rkYolov5s.hpp" // <-- 移除
// #include "rknn/rknnPool.hpp" // <-- 移除
#include "spdlog/spdlog.h"
#include <chrono>
#include <algorithm>
void VideoService::trigger_alarm(int person_id, const cv::Rect& box) {
printf("[ALARM] Intrusion detected! Person ID: %d at location (%d, %d, %d, %d)\n",
person_id, box.x, box.y, box.width, box.height);
// TODO: 在这里实现真正的报警逻辑,例如发送网络消息、写入数据库等。
}
// #include <chrono> // <-- 移除 (已移至 IntrusionModule)
// #include <algorithm> // <-- 移除 (已移至 IntrusionModule)
double VideoService::get_current_time_seconds() {
return std::chrono::duration_cast<std::chrono::duration<double>>(
std::chrono::high_resolution_clock::now().time_since_epoch()
).count();
}
//
// !!! 关键: trigger_alarm, get_current_time_seconds, update_tracker, draw_results
// !!! 所有这些函数 都已被剪切并移动到 IntrusionModule.cc
//
VideoService::VideoService(std::string model_path,
int thread_num,
// 构造函数:修改为接收 module
VideoService::VideoService(std::unique_ptr<IAnalysisModule> module,
std::string input_url,
std::string output_rtsp_url)
: model_path_(model_path),
thread_num_(thread_num),
: module_(std::move(module)), // <-- 关键:接收模块所有权
input_url_(input_url),
output_rtsp_url_(output_rtsp_url),
running_(false)
{
log_prefix_ = "[VideoService: " + input_url + "]";
next_track_id_ = 1;
intrusion_time_threshold_ = 3.0; // 3秒
intrusion_zone_ = cv::Rect(0, 0, 0, 0); // 默认无效
// !!! 移除所有AI相关的初始化 !!!
// next_track_id_ = 1;
// intrusion_time_threshold_ = 3.0;
// intrusion_zone_ = cv::Rect(0, 0, 0, 0);
spdlog::info("{} Created. Input: {}, Output: {}", log_prefix_, input_url_.c_str(), output_rtsp_url_.c_str());
}
@ -46,21 +40,24 @@ VideoService::~VideoService() {
bool VideoService::start() {
rknn_pool_ = std::make_unique<rknnPool<rkYolov5s, cv::Mat, detect_result_group_t>>(model_path_.c_str(), thread_num_);
if (rknn_pool_->init() != 0) {
printf("rknnPool init fail!\n");
// 1. (修改) 初始化AI模块
if (!module_ || !module_->init()) {
spdlog::error("{} Failed to initialize analysis module!", log_prefix_);
return false;
}
printf("rknnPool init success.\n");
spdlog::info("{} Analysis module initialized successfully.", log_prefix_);
// setenv("OPENCV_FFMPEG_CAPTURE_OPTIONS", "rtsp_transport;tcp", 1);
// printf("Set RTSP transport protocol to TCP\n");
// 2. (移除) RKNN Pool 初始化
// rknn_pool_ = std::make_unique<...>();
// ...
// 3. (不变) GStreamer 输入管线初始化
std::string gst_input_pipeline =
"rtspsrc location=" + input_url_ + " latency=0 protocols=tcp ! "
"rtph265depay ! "
"h265parse ! "
"mppvideodec format=16 ! "
"video/x-raw,format=BGR ! " // <-- 关键:直接请求 mppvideodec 输出 BGR 格式
"video/x-raw,format=BGR ! "
"appsink";
spdlog::info("Try to Open RTSP Stream");
@ -78,16 +75,43 @@ bool VideoService::start() {
frame_fps_ = capture_.get(cv::CAP_PROP_FPS);
if (frame_fps_ <= 0) frame_fps_ = 25.0;
if (frame_width_ == 0 || frame_height_ == 0) {
spdlog::error("{} Failed to get valid frame width or height from GStreamer pipeline (got {}x{}).",
log_prefix_, frame_width_, frame_height_);
spdlog::error("{} This usually means the RTSP stream is unavailable or the GStreamer input pipeline (mppvideodec?) failed.",
log_prefix_);
cv::Mat test_frame;
if (capture_.read(test_frame) && !test_frame.empty()) {
frame_width_ = test_frame.cols;
frame_height_ = test_frame.rows;
spdlog::info("{} Successfully got frame size by reading first frame: {}x{}",
log_prefix_, frame_width_, frame_height_);
{
std::lock_guard<std::mutex> lock(frame_mutex_);
latest_frame_ = test_frame;
new_frame_available_ = true;
}
frame_cv_.notify_one();
} else {
spdlog::error("{} Failed to read first frame to determine size. Aborting.", log_prefix_);
capture_.release();
return false; // 提前中止
}
}
printf("RTSP stream opened successfully! (%dx%d @ %.2f FPS)\n", frame_width_, frame_height_, frame_fps_);
// 4. (不变) GStreamer 输出管线初始化
std::string gst_pipeline =
"appsrc ! "
"queue max-size-buffers=2 leaky=downstream ! "
"video/x-raw,format=BGR ! " // OpenCV VideoWriter 输入 BGR 数据
"videoconvert ! " // <-- 使用 CPU 将 BGR 转换为 NV12
"video/x-raw,format=NV12 ! " // 明确指定 videoconvert 输出 NV12
"mpph265enc gop=25 rc-mode=fixqp qp-init=26 ! " // 硬件编码器接收 NV12 数据
"video/x-raw,format=BGR ! "
"videoconvert ! "
"video/x-raw,format=NV12 ! "
"mpph265enc gop=25 rc-mode=fixqp qp-init=26 ! "
"h265parse ! "
"rtspclientsink location=" + output_rtsp_url_ + " latency=0 protocols=tcp";
@ -95,8 +119,7 @@ bool VideoService::start() {
writer_.open(gst_pipeline,
cv::CAP_GSTREAMER,
0,
frame_fps_,
0, frame_fps_,
cv::Size(frame_width_, frame_height_),
true);
@ -107,6 +130,7 @@ bool VideoService::start() {
}
printf("VideoWriter opened successfully.\n");
// 5. (不变) 启动线程
running_ = true;
reading_thread_ = std::thread(&VideoService::reading_loop, this);
processing_thread_ = std::thread(&VideoService::processing_loop, this);
@ -120,6 +144,9 @@ void VideoService::stop() {
printf("Stopping VideoService...\n");
running_ = false;
// 唤醒可能在 frame_cv_.wait() 处等待的线程
frame_cv_.notify_all();
if (reading_thread_.joinable()) {
reading_thread_.join();
}
@ -136,145 +163,16 @@ void VideoService::stop() {
writer_.release();
}
// (可选) 确保模块资源被释放 (虽然unique_ptr析构时会自动处理)
module_.reset();
printf("VideoService stopped.\n");
}
void VideoService::update_tracker(detect_result_group_t &detect_result_group, const cv::Size& frame_size)
{
// 如果入侵区域无效,则设置为帧中心的 1/4 区域 (基于原始帧大小)
if (intrusion_zone_.width <= 0 || intrusion_zone_.height <= 0) {
intrusion_zone_ = cv::Rect(frame_size.width / 4, frame_size.height / 4, frame_size.width / 2, frame_size.height / 2);
}
// --- 缩放比例计算 ---
// !!! 重要: 请确保这里的 640.0f 与您 OpenCV resize 时的目标尺寸一致 !!!
const float model_input_width = 640.0f;
const float model_input_height = 640.0f;
float scale_x = (float)frame_size.width / model_input_width;
float scale_y = (float)frame_size.height / model_input_height;
// --- 结束缩放比例计算 ---
std::vector<cv::Rect> current_detections; // 存储当前帧检测到的、已缩放到原始尺寸的框
for (int i = 0; i < detect_result_group.count; i++) {
detect_result_t *det_result = &(detect_result_group.results[i]);
// 只处理 "person" 类别
if (strcmp(det_result->name, "person") == 0) {
// --- 将模型输出坐标按比例缩放回原始帧坐标 ---
int original_left = static_cast<int>(det_result->box.left * scale_x);
int original_top = static_cast<int>(det_result->box.top * scale_y);
int original_right = static_cast<int>(det_result->box.right * scale_x);
int original_bottom = static_cast<int>(det_result->box.bottom * scale_y);
// --- 边界检查与修正 ---
// 确保坐标不会超出原始图像边界
original_left = std::max(0, std::min(original_left, frame_size.width - 1));
original_top = std::max(0, std::min(original_top, frame_size.height - 1));
// 确保 right >= left, bottom >= top
original_right = std::max(original_left, std::min(original_right, frame_size.width));
original_bottom = std::max(original_top, std::min(original_bottom, frame_size.height));
// --- 结束边界检查 ---
// 只有当框有效时宽度和高度大于0才添加到检测列表
if (original_right > original_left && original_bottom > original_top) {
current_detections.push_back(cv::Rect(
original_left, original_top,
original_right - original_left, // width
original_bottom - original_top // height
));
}
// --- 结束坐标缩放 ---
}
}
for (auto it = tracked_persons_.begin(); it != tracked_persons_.end(); ++it) {
it->second.frames_unseen++;
}
std::vector<int> matched_track_ids; // 记录本帧已匹配到的跟踪ID防止一个跟踪ID匹配多个检测框
// 尝试将当前检测框与已跟踪目标进行匹配
for (const auto& det_box : current_detections) {
int best_match_id = -1;
double max_iou_threshold = 0.3; // IoU 匹配阈值
double best_iou = 0.0; // 用于寻找最佳匹配
for (auto const& [id, person] : tracked_persons_) {
// 检查该跟踪ID是否已在本帧匹配过
bool already_matched = false;
for(int matched_id : matched_track_ids) {
if (id == matched_id) {
already_matched = true;
break;
}
}
if (already_matched) {
continue; // 跳过已匹配的跟踪目标
}
// 计算 IoU (现在 det_box 和 person.box 都是原始坐标系)
double iou = (double)(det_box & person.box).area() / (double)(det_box | person.box).area();
if (iou > best_iou && iou >= max_iou_threshold) { // 必须大于等于阈值才能成为候选
best_iou = iou;
best_match_id = id;
}
}
if (best_match_id != -1) {
// 找到匹配,更新跟踪信息
tracked_persons_[best_match_id].box = det_box;
tracked_persons_[best_match_id].frames_unseen = 0;
matched_track_ids.push_back(best_match_id); // 记录已匹配
} else {
// 没有找到匹配,创建新的跟踪目标
TrackedPerson new_person;
new_person.id = next_track_id_++; // 分配新ID
new_person.box = det_box;
new_person.entry_time = 0;
new_person.is_in_zone = false;
new_person.alarm_triggered = false;
new_person.frames_unseen = 0;
tracked_persons_[new_person.id] = new_person;
}
}
// 更新每个跟踪目标的区域状态和报警状态
double current_time = get_current_time_seconds();
for (auto it = tracked_persons_.begin(); it != tracked_persons_.end(); ++it) {
TrackedPerson& person = it->second;
// 检查人与入侵区域的交集 (现在都是原始坐标系)
bool currently_in_zone = (intrusion_zone_ & person.box).area() > 0;
if (currently_in_zone) {
if (!person.is_in_zone) { // 刚进入区域
person.is_in_zone = true;
person.entry_time = current_time;
person.alarm_triggered = false; // 重置报警状态
} else { // 持续在区域内
// 检查是否达到报警时间阈值且尚未报警
if (!person.alarm_triggered && (current_time - person.entry_time) >= intrusion_time_threshold_) {
person.alarm_triggered = true;
trigger_alarm(person.id, person.box); // 触发报警
}
}
} else { // 当前不在区域内
if (person.is_in_zone) { // 刚离开区域
person.is_in_zone = false;
person.entry_time = 0; // 重置进入时间
person.alarm_triggered = false; // 重置报警状态
}
}
}
for (auto it = tracked_persons_.begin(); it != tracked_persons_.end(); /* 无自增 */) {
if (it->second.frames_unseen > 50) { // 超过 50 帧未见则移除
it = tracked_persons_.erase(it); // erase 返回下一个有效迭代器
} else {
++it; // 手动移动到下一个
}
}
}
//
// reading_loop() 函数: 完全不变
//
void VideoService::reading_loop() {
cv::Mat frame;
spdlog::info("Reading thread started.");
@ -299,35 +197,21 @@ void VideoService::reading_loop() {
frame_cv_.notify_one();
}
frame_cv_.notify_all();
frame_cv_.notify_all(); // 确保 processing_loop 也会退出
spdlog::info("Reading loop finished.");
}
void VideoService::draw_results(cv::Mat& frame)
{
cv::rectangle(frame, this->intrusion_zone_, cv::Scalar(255, 255, 0), 2); // 黄色
for (auto const& [id, person] : this->tracked_persons_) {
cv::Scalar box_color = person.alarm_triggered ? cv::Scalar(0, 0, 255) : cv::Scalar(0, 255, 0);
int line_thickness = person.alarm_triggered ? 3 : 2;
cv::rectangle(frame, person.box, box_color, line_thickness);
std::string label = "Person " + std::to_string(id);
if (person.is_in_zone) {
label += " (In Zone)";
}
cv::putText(frame, label, cv::Point(person.box.x, person.box.y - 10),
cv::FONT_HERSHEY_SIMPLEX, 0.5, box_color, 2);
}
}
//
// processing_loop() 函数: 极大简化 (关键修改)
//
void VideoService::processing_loop() {
cv::Mat frame;
detect_result_group_t detection_results;
// detect_result_group_t detection_results; // <-- 移除
while (running_) {
{
// 1. (不变) 获取帧
std::unique_lock<std::mutex> lock(frame_mutex_);
frame_cv_.wait(lock, [&]{
@ -338,55 +222,29 @@ void VideoService::processing_loop() {
break;
}
// 确认是有新帧
frame = latest_frame_.clone();
new_frame_available_ = false;
}
if (frame.empty()) {
continue;
}
cv::Mat model_input_image;
cv::resize(frame, model_input_image, cv::Size(640, 640));
if (!model_input_image.isContinuous()) {
model_input_image = model_input_image.clone();
}
if (rknn_pool_->put(model_input_image) != 0) {
spdlog::error("VideoService: Failed to put frame into rknnPool. Stopping.");
running_ = false;
break;
// 2. (关键修改) 调用AI模块处理
// --- 移除所有 resize, put, get, update_tracker, draw_results ---
if (!module_->process(frame)) {
// 模块报告处理失败
spdlog::warn("{} Module failed to process frame. Skipping.", log_prefix_);
}
// 此时 'frame' 已经被 module_->process() 修改(例如,绘制了框)
if (rknn_pool_->get(detection_results) != 0) {
spdlog::error("VideoService: Failed to get frame from rknnPool. Stopping.");
running_ = false;
break;
}
// auto t_infer = std::chrono::high_resolution_clock::now();
this->update_tracker(detection_results, frame.size());
this->draw_results(frame);
auto t_track_draw = std::chrono::high_resolution_clock::now();
// 3. (不变) 写入输出流
if (writer_.isOpened()) {
writer_.write(frame);
}
// auto t_write = std::chrono::high_resolution_clock::now();
// [保留] 性能日志
// double read_ms = std::chrono::duration_cast<std::chrono::duration<double, std::milli>>(t_read - t_start).count();
// double infer_ms = std::chrono::duration_cast<std::chrono::duration<double, std::milli>>(t_infer - t_read).count();
// double track_ms = std::chrono::duration_cast<std::chrono::duration<double, std::milli>>(t_track_draw - t_infer).count();
// double write_ms = std::chrono::duration_cast<std::chrono::duration<double, std::milli>>(t_write - t_track_draw).count();
// double total_ms = std::chrono::duration_cast<std::chrono::duration<double, std::milli>>(t_write - t_start).count();
// printf("Loop time: Total=%.1fms [Read=%.1f, Infer=%.1f, Track=%.1f, Write=%.1f]\n",
// total_ms, read_ms, infer_ms, track_ms, write_ms);
}
spdlog::info("VideoService: Processing loop finished.");
}
}

View File

@ -1,36 +1,35 @@
// video_service.h
// video_service.h (修改后)
#pragma once
#include <string>
#include <thread>
#include <atomic>
#include <memory>
#include <memory> // <--- 需要
#include <map>
#include <mutex>
#include <condition_variable>
#include <opencv2/core/core.hpp>
#include <opencv2/videoio.hpp>
#include "postprocess.h"
#include "algorithm/IAnalysisModule.h" // <--- 关键:包含新接口
// 向前声明
template<typename T, typename IN, typename OUT>
class rknnPool;
class rkYolov5s;
//
// !!! 关键: TrackedPerson 结构体 已被移除, 移至 IntrusionModule.h !!!
//
// 向前声明 (已不再需要)
// template<typename T, typename IN, typename OUT>
// class rknnPool;
// class rkYolov5s;
struct TrackedPerson
{
int id;
cv::Rect box;
double entry_time;
bool is_in_zone;
bool alarm_triggered;
int frames_unseen;
};
class VideoService {
public:
VideoService(std::string model_path,
int thread_num,
/**
* @brief
* model_path thread_num
* AI
*/
VideoService(std::unique_ptr<IAnalysisModule> module,
std::string input_url,
std::string output_rtsp_url);
@ -41,40 +40,27 @@ public:
private:
void processing_loop();
void reading_loop(); // [新增] 读取线程的循环函数
void reading_loop(); //
void update_tracker(detect_result_group_t &detect_result_group, const cv::Size& frame_size);
void draw_results(cv::Mat& frame); // 绘图辅助函数
void trigger_alarm(int person_id, const cv::Rect& box);
double get_current_time_seconds();
std::string model_path_;
int thread_num_;
std::unique_ptr<IAnalysisModule> module_; // <--- 关键持有AI模块的指针
std::string input_url_;
std::string output_rtsp_url_;
int frame_width_ = 0;
int frame_height_ = 0;
double frame_fps_ = 0.0;
std::unique_ptr<rknnPool<rkYolov5s, cv::Mat, detect_result_group_t>> rknn_pool_;
cv::VideoCapture capture_;
cv::VideoWriter writer_;
std::thread processing_thread_;
std::thread reading_thread_;
std::atomic<bool> running_{false};
std::mutex frame_mutex_;
std::condition_variable frame_cv_;
cv::Mat latest_frame_;
bool new_frame_available_{false};
cv::Rect intrusion_zone_;
std::map<int, TrackedPerson> tracked_persons_;
int next_track_id_;
double intrusion_time_threshold_;
std::string log_prefix_;
};

View File

@ -1,58 +1,97 @@
// video_service_manager.cc
// video_service_manager.cc (修改后)
#include "video_service_manager.h"
#include "spdlog/spdlog.h"
VideoServiceManager::~VideoServiceManager() {
// 确保在析构时停止所有服务
stop_all();
}
void VideoServiceManager::load_and_start(ConfigManager& config) {
if (!config.getIsVideoServiceEnabled()) {
if (!config.getIsVideoServiceEnabled()) { //
spdlog::warn("VideoService is disabled in configuration. No streams will be started.");
return;
}
model_path_ = config.getVideoModelPath();
if (model_path_.empty()) {
spdlog::error("Video model path is not set in configuration. Cannot start video services.");
return;
}
auto stream_configs = config.getVideoStreamConfigs();
auto stream_configs = config.getVideoStreamConfigs(); //
spdlog::info("Found {} video stream configurations.", stream_configs.size());
for (const auto& sc : stream_configs) {
if (!sc.enabled) {
if (!sc.enabled) { //
spdlog::info("Video stream '{}' (input: {}) is disabled in config, skipping.", sc.id, sc.input_url);
continue;
}
if (sc.rknn_thread_num <= 0) {
spdlog::warn("Video stream '{}' has invalid rknn_thread_num ({}). Defaulting to 1.", sc.id, sc.rknn_thread_num);
}
int threads_for_this_stream = (sc.rknn_thread_num > 0) ? sc.rknn_thread_num : 1;
//
// --- 关键:工厂逻辑 ---
//
std::unique_ptr<IAnalysisModule> module = nullptr;
// auto module_type = "instrustion_detection"; // module_type = sc.module_type
try
{
// 1. 根据配置创建 "AI模块" (策略)
if (sc.module_type == "intrusion_detection") {
std::string module_model_path = sc.module_config.value("model_path", "");
int module_threads = sc.module_config.value("rknn_thread_num", 1);
double threshold = sc.module_config.value("time_threshold_sec", 3.0);
try {
// 为每个流创建一个独立的 VideoService 实例
std::vector<int> zone_array;
if (sc.module_config.contains("intrusion_zone") && sc.module_config["intrusion_zone"].is_array()) {
// 使用 .get<T>() 将 json 数组转换为 std::vector
zone_array = sc.module_config["intrusion_zone"].get<std::vector<int>>();
}
if (module_threads <= 0) {
spdlog::warn("Video stream '{}' has invalid rknn_thread_num. Defaulting to 1.", sc.id);
module_threads = 1;
}
if (zone_array.size() != 4) {
spdlog::warn("Video stream '{}' intrusion_zone invalid. Defaulting to (0,0,0,0).", sc.id);
}
cv::Rect zone = (zone_array.size() == 4) ?
cv::Rect(zone_array[0], zone_array[1], zone_array[2], zone_array[3]) :
cv::Rect(0, 0, 0, 0); // 默认无效区
// 创建具体的入侵检测模块
module = std::make_unique<IntrusionModule>(
module_model_path,
module_threads,
zone,
threshold
);
} else if (sc.module_type == "face_recognition") {
// (未来的实现)
// std::string db_path = sc.module_config.getString("face_db_path");
// module = std::make_unique<FaceRecognitionModule>(db_path);
spdlog::warn("Module type 'face_recognition' for stream '{}' is not yet implemented.", sc.id);
continue; // 跳过此流
} else {
spdlog::error("Unknown module_type '{}' for stream '{}'. Skipping.", sc.module_type, sc.id);
continue; // 跳过此流
}
// 2. 创建 "管线" (VideoService) 并注入模块
auto service = std::make_unique<VideoService>(
model_path_,
threads_for_this_stream,
std::move(module), // <-- 依赖注入!
sc.input_url,
sc.output_rtsp
);
if (service->start()) {
spdlog::info("Successfully started video service for stream '{}' [Input: {}]. Output is [{}].",
sc.id, sc.input_url, sc.output_rtsp);
services_.push_back(std::move(service));
// 3. 启动服务 (逻辑不变)
if (service->start()) { //
spdlog::info("Successfully started video service for stream '{}' [Module: {}]. Output is [{}].",
sc.id, sc.module_type, sc.output_rtsp);
services_.push_back(std::move(service)); //
} else {
spdlog::error("Failed to start video service for stream '{}' [Input: {}].",
spdlog::error("Failed to start video service for stream '{}' [Input: {}].", //
sc.id, sc.input_url);
}
} catch (const std::exception& e) {
}
catch (const std::exception &e)
{ //
spdlog::error("Exception while creating VideoService for stream '{}' [{}]: {}",
sc.id, sc.input_url, e.what());
}
@ -61,10 +100,9 @@ void VideoServiceManager::load_and_start(ConfigManager& config) {
spdlog::info("VideoServiceManager finished setup. {} streams are now running.", services_.size());
}
void VideoServiceManager::stop_all() {
void VideoServiceManager::stop_all() {
spdlog::info("Stopping all video services ({})...", services_.size());
// 按顺序停止所有服务
for (auto& service : services_) {
if (service) {
service->stop();

View File

@ -1,8 +1,10 @@
// video_service_manager.h
// video_service_manager.h (修改后)
#pragma once
#include "config/config_manager.h" // 需要访问配置
#include "rknn/video_service.h" // 需要创建 VideoService
#include "config/config_manager.h"
#include "rknn/video_service.h" // (保持不变)
#include "algorithm/IAnalysisModule.h"
#include "algorithm/IntrusionModule.h"
#include <vector>
#include <memory>
#include <string>
@ -29,5 +31,4 @@ public:
private:
std::vector<std::unique_ptr<VideoService>> services_;
std::string model_path_;
};