Vehicle_Road_Counter/src/videoService/video_pipeline.cpp

307 lines
9.3 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "video_pipeline.hpp"
#include <algorithm> // for std::max, std::min
#include <chrono>
const int YoloDetector::NPU_CORE_CNT;
VideoPipeline::VideoPipeline() : running_(false), next_track_id_(0) {
detector_ = std::make_unique<YoloDetector>();
// 模型路径
if (detector_->init("../models/vehicle_model.rknn") != 0) {
spdlog::error("Failed to initialize YoloDetector");
} else {
spdlog::info("YoloDetector initialized successfully.");
}
}
VideoPipeline::~VideoPipeline() {
Stop();
}
void VideoPipeline::Start(const std::string& inputUrl, const std::string& outputUrl) {
if (running_)
return;
running_ = true;
processingThread_ = std::thread(&VideoPipeline::processLoop, this, inputUrl, outputUrl, false);
}
void VideoPipeline::StartTest(const std::string& filePath, const std::string& outputUrl) {
if (running_)
return;
running_ = true;
processingThread_ = std::thread(&VideoPipeline::processLoop, this, filePath, outputUrl, true);
}
void VideoPipeline::Stop() {
if (!running_)
return;
running_ = false;
if (processingThread_.joinable())
processingThread_.join();
spdlog::info("VideoPipeline Stopped.");
}
void VideoPipeline::inferenceWorker() {
while (running_) {
FrameData data;
{
std::unique_lock<std::mutex> lock(input_mtx_);
// 等待有数据或者停止信号
input_cv_.wait(lock, [this] { return !input_queue_.empty() || !running_; });
if (!running_ && input_queue_.empty())
return;
data = input_queue_.front();
input_queue_.pop();
// 通知主线程队列有空位了(流控)
input_cv_.notify_all();
}
// === 核心:这里是并行执行的 ===
// YoloDetector::detect 现在是线程安全的,会自动分配 NPU 核心
data.results = detector_->detect(data.original_frame);
// 将结果放入输出缓冲区
{
std::lock_guard<std::mutex> lock(output_mtx_);
// 使用 move 减少拷贝
output_buffer_[data.frame_id] = std::move(data);
}
// 通知主线程有结果了
output_cv_.notify_one();
}
}
// [新增] 计算两个矩形的交并比 (IoU)
float VideoPipeline::computeIOU(const cv::Rect& box1, const cv::Rect& box2) {
int x1 = std::max(box1.x, box2.x);
int y1 = std::max(box1.y, box2.y);
int x2 = std::min(box1.x + box1.width, box2.x + box2.width);
int y2 = std::min(box1.y + box1.height, box2.y + box2.height);
if (x1 >= x2 || y1 >= y2)
return 0.0f;
float intersection = (float)((x2 - x1) * (y2 - y1));
float area1 = (float)(box1.width * box1.height);
float area2 = (float)(box2.width * box2.height);
return intersection / (area1 + area2 - intersection);
}
// [新增] 核心逻辑:追踪与平滑
void VideoPipeline::updateTracker(const std::vector<DetectionResult>& detections) {
// 1. 标记所有现有轨迹为丢失 (missing_frames + 1)
for (auto& pair : tracks_) {
pair.second.missing_frames++;
}
// 2. 匹配当前帧检测结果与现有轨迹
for (const auto& det : detections) {
cv::Rect detBox(det.x, det.y, det.width, det.height);
int best_match_id = -1;
float max_iou = 0.0f;
// 寻找 IoU 最大的匹配
for (auto& pair : tracks_) {
float iou = computeIOU(detBox, pair.second.box);
if (iou > 0.3f && iou > max_iou) { // 阈值 0.3 可根据需要调整
max_iou = iou;
best_match_id = pair.first;
}
}
// 假设: class_id == 1 是新能源(Green), class_id == 0 是油车(Fuel)
// 请根据你模型的实际定义修改这里的 ID
float current_is_ev = (det.class_id == 1) ? 1.0f : 0.0f;
if (best_match_id != -1) {
// === 匹配成功:更新平滑分数 ===
TrackedVehicle& track = tracks_[best_match_id];
track.box = detBox;
track.missing_frames = 0; // 重置丢失计数
// [核心算法] 指数移动平均 (EMA)
// alpha = 0.1 表示新结果占10%权重历史占90%,数值越小越平滑,反应越慢
float alpha = 0.15f;
track.ev_score = track.ev_score * (1.0f - alpha) + current_is_ev * alpha;
} else {
// === 未匹配:创建新轨迹 ===
TrackedVehicle newTrack;
newTrack.id = next_track_id_++;
newTrack.box = detBox;
newTrack.missing_frames = 0;
newTrack.ev_score = current_is_ev; // 初始分数为当前检测结果
newTrack.last_class_id = det.class_id;
tracks_[newTrack.id] = newTrack;
}
}
// 3. 移除长时间丢失的轨迹 & 更新显示状态
for (auto it = tracks_.begin(); it != tracks_.end();) {
if (it->second.missing_frames > 10) { // 超过10帧未检测到则移除
it = tracks_.erase(it);
} else {
// 根据平滑后的分数决定最终类别
// 阈值 0.5: 分数 > 0.5 判定为新能源
int final_class = (it->second.ev_score > 0.5f) ? 1 : 0;
it->second.last_class_id = final_class;
it->second.label = (final_class == 1) ? "Green" : "Fuel"; // 显示文本
++it;
}
}
}
// [修改] 绘制函数使用 TrackedVehicle
void VideoPipeline::drawOverlay(cv::Mat& frame, const std::vector<TrackedVehicle>& trackedObjects) {
for (const auto& trk : trackedObjects) {
// 如果丢失了几帧但还在内存里,用虚线或者灰色表示(可选),这里保持原样
if (trk.missing_frames > 0)
continue; // 暂时只画当前帧存在的
// 颜色:新能源用绿色,油车用红色
cv::Scalar color = (trk.last_class_id == 1) ? cv::Scalar(0, 255, 0) : cv::Scalar(0, 0, 255);
cv::rectangle(frame, trk.box, color, 2);
// 显示 类别 + 平滑后的分数
// 例如: Green 0.85
std::string text = fmt::format("{} {:.2f}", trk.label, trk.ev_score);
int y_pos = trk.box.y - 5;
if (y_pos < 10)
y_pos = trk.box.y + 15;
cv::putText(frame, text, cv::Point(trk.box.x, y_pos), cv::FONT_HERSHEY_SIMPLEX, 0.6, color,
2);
}
cv::putText(frame, "RK3588 YOLOv8 Smooth", cv::Point(20, 50), cv::FONT_HERSHEY_SIMPLEX, 1.0,
cv::Scalar(0, 255, 255), 2);
}
void VideoPipeline::processLoop(std::string inputUrl, std::string outputUrl, bool isFileSource) {
cv::VideoCapture cap;
cap.open(inputUrl);
if (!cap.isOpened()) {
spdlog::error("Failed to open input: {}", inputUrl);
running_ = false;
return;
}
// 1. 启动 3 个工作线程 (对应 3 个 NPU 核心)
// 你的 NPU 利用率将在这里被填满
for (int i = 0; i < 3; ++i) {
worker_threads_.emplace_back(&VideoPipeline::inferenceWorker, this);
}
int width = cap.get(cv::CAP_PROP_FRAME_WIDTH);
int height = cap.get(cv::CAP_PROP_FRAME_HEIGHT);
const double TARGET_FPS = 30.0; // 提升目标 FPS
// ... GStreamer pipeline string 设置保持不变 ...
std::stringstream pipeline;
pipeline << "appsrc ! "
<< "videoconvert ! "
<< "video/x-raw,format=NV12,width=" << width << ",height=" << height
<< ",framerate=" << (int)TARGET_FPS << "/1 ! "
<< "mpph264enc ! "
<< "h264parse ! "
<< "rtspclientsink location=" << outputUrl << " protocols=tcp";
cv::VideoWriter writer;
writer.open(pipeline.str(), cv::CAP_GSTREAMER, 0, TARGET_FPS, cv::Size(width, height), true);
long read_frame_idx = 0; // 读取计数
long write_frame_idx = 0; // 写入计数
while (running_) {
// === 阶段 A: 读取并分发任务 (生产者) ===
// 限制预读数量,防止内存爆满 (例如最多预读 5 帧)
{
std::unique_lock<std::mutex> lock(input_mtx_);
// 如果输入队列满了,等待工作线程处理
if (input_queue_.size() < MAX_INPUT_QUEUE_SIZE) {
cv::Mat frame;
if (cap.read(frame) && !frame.empty()) {
FrameData data;
data.frame_id = read_frame_idx++;
data.original_frame = frame; // 拷贝一份 (必须,因为 cv::Mat 是引用计数)
input_queue_.push(data);
input_cv_.notify_one(); // 唤醒一个工作线程
} else {
if (isFileSource) {
// 文件读完了的处理...
running_ = false;
} else {
// 网络流断线的处理...
}
}
}
}
// === 阶段 B: 按顺序收集结果并处理 (消费者) ===
FrameData current_data;
bool has_data = false;
{
std::unique_lock<std::mutex> lock(output_mtx_);
// 检查输出缓冲区里是否有我们期待的下一帧 (write_frame_idx)
// 因为多线程处理第5帧可能比第4帧先处理完必须等待第4帧
auto it = output_buffer_.find(write_frame_idx);
if (it != output_buffer_.end()) {
current_data = std::move(it->second);
output_buffer_.erase(it);
has_data = true;
write_frame_idx++;
}
}
if (has_data) {
// 注意updateTracker 和 drawOverlay 必须在主线程串行执行
// 因为它们依赖 tracks_ 状态,且必须按时间顺序更新
// 1. 追踪 (CPU)
updateTracker(current_data.results);
// 2. 准备绘图数据 (CPU)
std::vector<TrackedVehicle> tracks_to_draw;
for (const auto& pair : tracks_) {
tracks_to_draw.push_back(pair.second);
}
// 3. 绘图 (CPU)
drawOverlay(current_data.original_frame, tracks_to_draw);
// 4. 推流 (IO)
if (writer.isOpened()) {
writer.write(current_data.original_frame);
}
// 简单的 FPS 打印
if (write_frame_idx % 60 == 0) {
spdlog::info("Processed Frame ID: {}", write_frame_idx);
}
} else {
// 如果没有等到当前帧,稍微休眠一下避免死循环占满 CPU
// 但不要睡太久,否则延迟高
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
}
// 清理:通知线程退出并 Join
input_cv_.notify_all();
for (auto& t : worker_threads_) {
if (t.joinable())
t.join();
}
worker_threads_.clear();
cap.release();
writer.release();
}