Vehicle_Road_Counter/src/videoService/video_pipeline.cpp

255 lines
7.9 KiB
C++
Raw Normal View History

#include "video_pipeline.hpp"
#include <algorithm> // for std::max, std::min
#include <chrono>
// Creates the pipeline in the stopped state and loads the vehicle detection model.
// The model path is relative, so it is resolved against the process's current working
// directory (not the executable's location).
// NOTE(review): an init failure is only logged. The object is still constructed and
// processLoop() will call detector_->detect() regardless; confirm YoloDetector tolerates that.
VideoPipeline::VideoPipeline() : running_(false), next_track_id_(0) {
    detector_ = std::make_unique<YoloDetector>();

    const auto init_status = detector_->init("../models/vehicle_model.rknn");
    if (init_status == 0) {
        spdlog::info("YoloDetector initialized successfully.");
        return;
    }
    spdlog::error("Failed to initialize YoloDetector");
}
// Shuts the pipeline down via Stop() before the members (including the worker
// std::thread) are destroyed.
VideoPipeline::~VideoPipeline() { Stop(); }
// Starts processing `inputUrl` as a live stream on a background thread and pushes the
// annotated frames to `outputUrl`. In live mode the loop reconnects on read failure instead
// of rewinding. No-op if the pipeline is already running.
void VideoPipeline::Start(const std::string& inputUrl, const std::string& outputUrl) {
    if (running_)
        return;
    // A previous loop may have exited on its own: processLoop clears running_ when it fails,
    // e.g. when the input cannot be opened. Its std::thread is then still joinable, and
    // assigning a new thread over a joinable one calls std::terminate. Reap it first.
    if (processingThread_.joinable())
        processingThread_.join();
    running_ = true;
    processingThread_ = std::thread(&VideoPipeline::processLoop, this, inputUrl, outputUrl, false);
}
// Starts processing a local video file on a background thread. The file is replayed in a
// loop and paced to the loop's target FPS. Annotated frames are pushed to `outputUrl`.
// No-op if the pipeline is already running.
void VideoPipeline::StartTest(const std::string& filePath, const std::string& outputUrl) {
    if (running_)
        return;
    // A previous loop may have exited on its own: processLoop clears running_ when it fails.
    // Its std::thread is then still joinable, and assigning a new thread over a joinable one
    // calls std::terminate. Reap it first.
    if (processingThread_.joinable())
        processingThread_.join();
    running_ = true;
    processingThread_ = std::thread(&VideoPipeline::processLoop, this, filePath, outputUrl, true);
}
// Asks the processing loop to exit and waits for the worker thread to finish.
// Safe to call repeatedly.
//
// The thread is joined whenever it is joinable, not only while running_ is set. processLoop
// clears running_ itself when it fails (e.g. the input cannot be opened). Returning early in
// that case left a joinable std::thread behind, and the destructor (which calls Stop()) then
// destroyed it, which calls std::terminate.
void VideoPipeline::Stop() {
    const bool was_running = running_;
    running_ = false;
    if (processingThread_.joinable())
        processingThread_.join();
    if (was_running)
        spdlog::info("VideoPipeline Stopped.");
}
// Intersection-over-Union (IoU) of two axis-aligned rectangles.
// Returns 0 when the rectangles do not overlap, including when they only share an edge.
// For boxes with positive width/height the result lies in (0, 1].
float VideoPipeline::computeIOU(const cv::Rect& box1, const cv::Rect& box2) {
    const int left = std::max(box1.x, box2.x);
    const int top = std::max(box1.y, box2.y);
    const int right = std::min(box1.x + box1.width, box2.x + box2.width);
    const int bottom = std::min(box1.y + box1.height, box2.y + box2.height);

    const bool overlaps = (left < right) && (top < bottom);
    if (!overlaps)
        return 0.0f;

    const float inter = static_cast<float>((right - left) * (bottom - top));
    const float unionArea = static_cast<float>(box1.width * box1.height) +
                            static_cast<float>(box2.width * box2.height) - inter;
    return inter / unionArea;
}
// Core logic: associates this frame's detections with existing tracks and smooths each
// track's class over time.
//
// Each track keeps an exponential moving average (`ev_score`) of "detected as EV", so a
// vehicle whose per-frame classification flickers keeps a stable label.
//
// Matching is greedy in detection order and one-to-one: each track can be claimed by at most
// one detection per frame. Without that restriction, two overlapping detections could both
// match the same track. The EMA would then be updated twice, and the second vehicle would be
// silently dropped instead of getting its own track.
void VideoPipeline::updateTracker(const std::vector<DetectionResult>& detections) {
    constexpr float kMatchIouThreshold = 0.3f;  // min IoU to associate a detection with a track; tune as needed
    constexpr float kEmaAlpha = 0.15f;  // weight of the newest observation (15% new / 85% history); smaller = smoother but slower to react
    constexpr int kMaxMissingFrames = 10;  // drop a track after more than this many unmatched frames
    constexpr float kEvScoreThreshold = 0.5f;  // smoothed score above this => EV ("Green")

    // 1. Age every existing track. Tracks start at 0 and are only incremented or reset to 0,
    //    so after this step every pre-existing track has missing_frames >= 1. Below,
    //    missing_frames == 0 therefore means "already matched or created during this frame".
    for (auto& pair : tracks_) {
        pair.second.missing_frames++;
    }

    // 2. Match each detection to the unclaimed track with the highest IoU above the threshold.
    for (const auto& det : detections) {
        cv::Rect detBox(det.x, det.y, det.width, det.height);
        int best_match_id = -1;
        float max_iou = 0.0f;
        for (const auto& pair : tracks_) {
            if (pair.second.missing_frames == 0)
                continue;  // already claimed by another detection this frame
            float iou = computeIOU(detBox, pair.second.box);
            if (iou > kMatchIouThreshold && iou > max_iou) {
                max_iou = iou;
                best_match_id = pair.first;
            }
        }
        // Assumption: class_id == 1 is a new-energy vehicle ("Green") and class_id == 0 is a
        // fuel vehicle ("Fuel"). Adjust to match the model's actual class definitions.
        float current_is_ev = (det.class_id == 1) ? 1.0f : 0.0f;
        if (best_match_id != -1) {
            // Matched: refresh the box and fold this observation into the smoothed score.
            TrackedVehicle& track = tracks_[best_match_id];
            track.box = detBox;
            track.missing_frames = 0;
            track.ev_score = track.ev_score * (1.0f - kEmaAlpha) + current_is_ev * kEmaAlpha;
        } else {
            // Unmatched: start a new track seeded with this frame's classification.
            TrackedVehicle newTrack;
            newTrack.id = next_track_id_++;
            newTrack.box = detBox;
            newTrack.missing_frames = 0;
            newTrack.ev_score = current_is_ev;
            newTrack.last_class_id = det.class_id;
            tracks_[newTrack.id] = newTrack;
        }
    }

    // 3. Drop tracks that have been missing too long; derive the displayed class of the rest
    //    from the smoothed score.
    for (auto it = tracks_.begin(); it != tracks_.end();) {
        if (it->second.missing_frames > kMaxMissingFrames) {
            it = tracks_.erase(it);
        } else {
            int final_class = (it->second.ev_score > kEvScoreThreshold) ? 1 : 0;
            it->second.last_class_id = final_class;
            it->second.label = (final_class == 1) ? "Green" : "Fuel";  // display text
            ++it;
        }
    }
}
// Draws the tracked vehicles and a fixed banner onto `frame` in place.
// Only tracks matched in the current frame (missing_frames == 0) are drawn. Tracks that are
// still in memory but unmatched are skipped (they could optionally be shown greyed out or
// dashed). Each drawn track gets a box plus "<label> <smoothed score>", e.g. "Green 0.85".
// Colors are BGR: green for class 1 (EV), red otherwise (fuel).
void VideoPipeline::drawOverlay(cv::Mat& frame, const std::vector<TrackedVehicle>& trackedObjects) {
    const cv::Scalar kEvColor(0, 255, 0);
    const cv::Scalar kFuelColor(0, 0, 255);

    for (const TrackedVehicle& vehicle : trackedObjects) {
        if (vehicle.missing_frames > 0)
            continue;

        const cv::Scalar& color = (vehicle.last_class_id == 1) ? kEvColor : kFuelColor;
        cv::rectangle(frame, vehicle.box, color, 2);

        // Put the caption just above the box, or just inside it when that would be too close
        // to the top edge of the frame.
        const int above = vehicle.box.y - 5;
        const int captionY = (above < 10) ? vehicle.box.y + 15 : above;
        const std::string caption = fmt::format("{} {:.2f}", vehicle.label, vehicle.ev_score);
        cv::putText(frame, caption, cv::Point(vehicle.box.x, captionY), cv::FONT_HERSHEY_SIMPLEX,
                    0.6, color, 2);
    }

    cv::putText(frame, "RK3588 YOLOv8 Smooth", cv::Point(20, 50), cv::FONT_HERSHEY_SIMPLEX, 1.0,
                cv::Scalar(0, 255, 255), 2);
}
void VideoPipeline::processLoop(std::string inputUrl, std::string outputUrl, bool isFileSource) {
cv::VideoCapture cap;
cap.open(inputUrl);
if (!cap.isOpened()) {
spdlog::error("Failed to open input: {}", inputUrl);
running_ = false;
return;
}
const double TARGET_FPS = 60.0;
const double FRAME_DURATION_MS = 1000.0 / TARGET_FPS;
int width = cap.get(cv::CAP_PROP_FRAME_WIDTH);
int height = cap.get(cv::CAP_PROP_FRAME_HEIGHT);
// GStreamer pipeline string (保持原样...)
std::stringstream pipeline;
pipeline << "appsrc ! "
<< "videoconvert ! "
<< "video/x-raw,format=NV12,width=" << width << ",height=" << height
<< ",framerate=" << (int)TARGET_FPS << "/1 ! "
<< "mpph264enc ! "
<< "h264parse ! "
<< "rtspclientsink location=" << outputUrl << " protocols=tcp";
cv::VideoWriter writer;
writer.open(pipeline.str(), cv::CAP_GSTREAMER, 0, TARGET_FPS, cv::Size(width, height), true);
cv::Mat frame;
long frame_count = 0;
using Clock = std::chrono::high_resolution_clock;
using Ms = std::chrono::duration<double, std::milli>;
while (running_) {
auto loop_start = std::chrono::steady_clock::now();
auto loop_begin_tp = Clock::now();
auto t1 = Clock::now();
if (!cap.read(frame)) {
if (isFileSource) {
cap.set(cv::CAP_PROP_POS_FRAMES, 0);
continue;
} else {
std::this_thread::sleep_for(std::chrono::seconds(1));
cap.release();
cap.open(inputUrl);
continue;
}
}
if (frame.empty())
continue;
auto t2 = Clock::now();
// 1. 推理
std::vector<DetectionResult> results = detector_->detect(frame);
auto t3 = Clock::now();
// 2. [修改] 更新追踪器和平滑分数
updateTracker(results);
auto t4 = Clock::now();
// 3. [修改] 准备绘图数据
std::vector<TrackedVehicle> tracks_to_draw;
for (const auto& pair : tracks_) {
tracks_to_draw.push_back(pair.second);
}
// 4. [修改] 绘制
drawOverlay(frame, tracks_to_draw);
auto t5 = Clock::now();
// 5. 推流
if (writer.isOpened())
writer.write(frame);
auto t6 = Clock::now();
// ----------------- 计算耗时 -----------------
double ms_read = std::chrono::duration_cast<Ms>(t2 - t1).count();
double ms_infer = std::chrono::duration_cast<Ms>(t3 - t2).count();
double ms_track = std::chrono::duration_cast<Ms>(t4 - t3).count();
double ms_draw = std::chrono::duration_cast<Ms>(t5 - t4).count();
double ms_write = std::chrono::duration_cast<Ms>(t6 - t5).count();
double ms_total = std::chrono::duration_cast<Ms>(t6 - t1).count();
frame_count++;
if (frame_count % 60 == 0 || ms_total > 18.0) {
spdlog::info(
"Frame[{}] Cost: Total={:.1f}ms | Read={:.1f} Infer={:.1f} Track={:.1f} "
"Draw={:.1f} Write={:.1f}",
frame_count, ms_total, ms_read, ms_infer, ms_track, ms_draw, ms_write);
}
// FPS控制 (保持原样)
if (isFileSource) {
auto loop_end = std::chrono::steady_clock::now();
double elapsed_ms =
std::chrono::duration<double, std::milli>(loop_end - loop_start).count();
double wait_ms = FRAME_DURATION_MS - elapsed_ms;
if (wait_ms > 0)
std::this_thread::sleep_for(std::chrono::milliseconds((int)wait_ms));
}
}
cap.release();
writer.release();
}