// bonus-edge-proxy/src/rknn/rkYolov5s.cc
//
// NOTE: web-page scrape artifacts removed from the top of this file
// ("414 lines / 12 KiB / C++ / Raw Blame History" and the GitHub
// ambiguous-Unicode-characters notice) — they were UI chrome, not source.

#include <stdio.h>
#include <stdlib.h> // malloc / calloc / free
#include <string.h> // memset / strcmp

#include <algorithm> // std::min / std::max
#include <chrono>    // timing
#include <mutex>
#include <string>
#include <vector>

#include "rkYolov5s.hpp"

#include "opencv2/core/core.hpp"
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"
#include "postprocess.h"
#include "preprocess.h"
#include "rknn/coreNum.hpp"
#include "rknn/rknn_api.h"
// Alarm hook: currently only logs the event to stdout.
// TODO: wire up a real alarm sink here (network message, database write, ...).
void trigger_alarm(int person_id, const cv::Rect& box) {
    const int x = box.x;
    const int y = box.y;
    const int w = box.width;
    const int h = box.height;
    printf("[ALARM] Intrusion detected! Person ID: %d at location (%d, %d, %d, %d)\n",
           person_id, x, y, w, h);
}
// Returns the current time as fractional seconds since the clock's epoch.
double get_current_time_seconds() {
    using clock = std::chrono::high_resolution_clock;
    const std::chrono::duration<double> since_epoch = clock::now().time_since_epoch();
    return since_epoch.count();
}
// Logs a tensor attribute's index and shape.
// Fixes: the original built `shape_str` and then discarded it, making this
// "dump" function a silent no-op; the loop also compared a signed index
// against the (unsigned) n_dims field.
static void dump_tensor_attr(rknn_tensor_attr *attr)
{
    std::string shape_str = attr->n_dims < 1 ? "" : std::to_string(attr->dims[0]);
    for (uint32_t i = 1; i < attr->n_dims; ++i)
    {
        shape_str += ", " + std::to_string(attr->dims[i]);
    }
    printf("  tensor index=%d, dims=[%s]\n", (int)attr->index, shape_str.c_str());
}
// Reads `sz` bytes starting at byte offset `ofst` from `fp`.
// Returns a malloc'd buffer the caller must free(), or NULL on any failure
// (NULL fp, seek failure, allocation failure, or short read).
static unsigned char *load_data(FILE *fp, size_t ofst, size_t sz)
{
    if (NULL == fp)
    {
        return NULL;
    }
    if (fseek(fp, ofst, SEEK_SET) != 0)
    {
        printf("blob seek failure.\n");
        return NULL;
    }
    unsigned char *data = (unsigned char *)malloc(sz);
    if (data == NULL)
    {
        printf("buffer malloc failure.\n");
        return NULL;
    }
    // The original ignored fread's return value, so a short read silently
    // handed back a partially-filled buffer.
    size_t nread = fread(data, 1, sz, fp);
    if (nread != sz)
    {
        printf("blob read failure (%zu of %zu bytes).\n", nread, sz);
        free(data);
        return NULL;
    }
    return data;
}
// Loads the entire file at `filename` into a malloc'd buffer.
// On success returns the buffer (caller frees) and writes its size to
// *model_size; on failure returns NULL and writes 0 to *model_size.
// Fixes: the original never checked ftell() or the load_data() result, and
// wrote a possibly-bogus size to *model_size even on failure.
static unsigned char *load_model(const char *filename, int *model_size)
{
    *model_size = 0;
    FILE *fp = fopen(filename, "rb");
    if (NULL == fp)
    {
        printf("Open file %s failed.\n", filename);
        return NULL;
    }
    fseek(fp, 0, SEEK_END);
    long size = ftell(fp);
    if (size < 0)
    {
        printf("ftell %s failed.\n", filename);
        fclose(fp);
        return NULL;
    }
    unsigned char *data = load_data(fp, 0, (size_t)size);
    fclose(fp);
    if (data == NULL)
    {
        return NULL; // load_data already printed the reason
    }
    *model_size = (int)size;
    return data;
}
// Writes `element_size` floats, one per line with 6 decimal places, to
// `file_name`. Returns 0 on success, -1 if the file cannot be opened.
// Fixes: the original dereferenced a NULL FILE* when fopen failed.
static int saveFloat(const char *file_name, float *output, int element_size)
{
    FILE *fp = fopen(file_name, "w");
    if (fp == NULL)
    {
        printf("Open file %s failed.\n", file_name);
        return -1;
    }
    for (int i = 0; i < element_size; i++)
    {
        fprintf(fp, "%.6f\n", output[i]);
    }
    fclose(fp);
    return 0;
}
// Construct a detector bound to one RKNN model file. Detection thresholds
// come from the postprocess defaults; the intrusion zone starts as an
// all-zero (invalid) rect and is sized from the first frame inside infer().
rkYolov5s::rkYolov5s(const std::string &model_path)
{
    this->model_path = model_path;
    this->nms_threshold = NMS_THRESH;
    this->box_conf_threshold = BOX_THRESH;
    // Tracker / intrusion-detection state.
    this->next_track_id = 1;
    this->intrusion_time_threshold = 3.0; // seconds in zone before alarming
    this->intrusion_zone = cv::Rect(0, 0, 0, 0); // invalid until first frame
}
// Replaces the rectangular intrusion zone; serialized with infer() via mtx.
void rkYolov5s::set_intrusion_zone(const cv::Rect& zone) {
    const std::lock_guard<std::mutex> guard(mtx);
    intrusion_zone = zone;
}
// Initializes the RKNN runtime for this detector instance.
//   ctx_in:       existing context duplicated when share_weight is true
//                 (weight sharing between workers); otherwise the model file
//                 is loaded and a fresh context created.
//   share_weight: selects rknn_dup_context vs rknn_init.
// Returns 0 on success, -1 on failure.
// Fixes vs original: NULL model blob is no longer passed to rknn_init;
// core_mask is never left uninitialized; output-attr query errors are checked.
int rkYolov5s::init(rknn_context *ctx_in, bool share_weight)
{
    printf("Loading model...\n");
    int model_data_size = 0;
    model_data = load_model(model_path.c_str(), &model_data_size);
    if (!share_weight && model_data == NULL)
    {
        // The blob is only needed for a fresh rknn_init.
        printf("load model %s failed!\n", model_path.c_str());
        return -1;
    }
    if (share_weight == true)
        ret = rknn_dup_context(ctx_in, &ctx);
    else
        ret = rknn_init(&ctx, model_data, model_data_size, 0, NULL);
    if (ret < 0)
    {
        printf("rknn_init error ret=%d\n", ret);
        return -1;
    }
    // Pin this context to one NPU core, chosen round-robin by get_core_num().
    rknn_core_mask core_mask;
    switch (get_core_num())
    {
    case 0:
        core_mask = RKNN_NPU_CORE_0;
        break;
    case 1:
        core_mask = RKNN_NPU_CORE_1;
        break;
    case 2:
        core_mask = RKNN_NPU_CORE_2;
        break;
    default:
        // Original left core_mask uninitialized for unexpected values (UB).
        core_mask = RKNN_NPU_CORE_0;
        break;
    }
    ret = rknn_set_core_mask(ctx, core_mask);
    if (ret < 0)
    {
        printf("rknn_init core error ret=%d\n", ret);
        return -1;
    }
    rknn_sdk_version version;
    ret = rknn_query(ctx, RKNN_QUERY_SDK_VERSION, &version, sizeof(rknn_sdk_version));
    if (ret < 0)
    {
        printf("rknn_init error ret=%d\n", ret);
        return -1;
    }
    printf("sdk version: %s driver version: %s\n", version.api_version, version.drv_version);
    ret = rknn_query(ctx, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
    if (ret < 0)
    {
        printf("rknn_init error ret=%d\n", ret);
        return -1;
    }
    printf("model input num: %d, output num: %d\n", io_num.n_input, io_num.n_output);
    // Query per-tensor attributes; output scale/zp are used later in infer().
    input_attrs = (rknn_tensor_attr *)calloc(io_num.n_input, sizeof(rknn_tensor_attr));
    for (uint32_t i = 0; i < io_num.n_input; i++)
    {
        input_attrs[i].index = i;
        ret = rknn_query(ctx, RKNN_QUERY_INPUT_ATTR, &(input_attrs[i]), sizeof(rknn_tensor_attr));
        if (ret < 0)
        {
            printf("rknn_init error ret=%d\n", ret);
            return -1;
        }
        dump_tensor_attr(&(input_attrs[i]));
    }
    output_attrs = (rknn_tensor_attr *)calloc(io_num.n_output, sizeof(rknn_tensor_attr));
    for (uint32_t i = 0; i < io_num.n_output; i++)
    {
        output_attrs[i].index = i;
        ret = rknn_query(ctx, RKNN_QUERY_OUTPUT_ATTR, &(output_attrs[i]), sizeof(rknn_tensor_attr));
        if (ret < 0)
        {
            // Original ignored query failures for output attrs.
            printf("rknn_query output attr error ret=%d\n", ret);
            return -1;
        }
        dump_tensor_attr(&(output_attrs[i]));
    }
    // Derive the model's input geometry from its layout.
    if (input_attrs[0].fmt == RKNN_TENSOR_NCHW)
    {
        printf("model is NCHW input fmt\n");
        channel = input_attrs[0].dims[1];
        height = input_attrs[0].dims[2];
        width = input_attrs[0].dims[3];
    }
    else
    {
        printf("model is NHWC input fmt\n");
        height = input_attrs[0].dims[1];
        width = input_attrs[0].dims[2];
        channel = input_attrs[0].dims[3];
    }
    printf("model input height=%d, width=%d, channel=%d\n", height, width, channel);
    // Static description of the single UINT8 NHWC input buffer.
    memset(inputs, 0, sizeof(inputs));
    inputs[0].index = 0;
    inputs[0].type = RKNN_TENSOR_UINT8;
    inputs[0].size = width * height * channel;
    inputs[0].fmt = RKNN_TENSOR_NHWC;
    inputs[0].pass_through = 0;
    return 0;
}
// Exposes the underlying rknn context, e.g. for weight-sharing duplication.
rknn_context *rkYolov5s::get_pctx() { return &ctx; }
// Updates the greedy IoU tracker with this frame's detections and evaluates
// the intrusion rule for every tracked person:
//   - "person" detections match an existing track when IoU > 0.3;
//     unmatched detections spawn new tracks.
//   - A person continuously overlapping `intrusion_zone` for more than
//     `intrusion_time_threshold` seconds triggers the alarm once per visit.
//   - Tracks unseen for more than 20 frames are dropped.
// Fixes vs original: IoU divided by (a | b).area(), but OpenCV's operator|
// yields the minimal ENCLOSING rect, not the union — that systematically
// underestimated IoU. The true union (A + B - intersection) is used now.
// The unused local `is_matched` was also removed.
void rkYolov5s::update_tracker(detect_result_group_t &detect_result_group)
{
    // Collect this frame's person boxes.
    std::vector<cv::Rect> current_detections;
    for (int i = 0; i < detect_result_group.count; i++) {
        detect_result_t *det_result = &(detect_result_group.results[i]);
        if (strcmp(det_result->name, "person") == 0) {
            current_detections.push_back(cv::Rect(
                det_result->box.left, det_result->box.top,
                det_result->box.right - det_result->box.left,
                det_result->box.bottom - det_result->box.top));
        }
    }
    // 1. Age every existing track; matched tracks are reset below.
    for (auto it = tracked_persons.begin(); it != tracked_persons.end(); ++it) {
        it->second.frames_unseen++;
    }
    // 2. Greedily match detections against existing tracks by IoU.
    for (const auto& det_box : current_detections) {
        int best_match_id = -1;
        double max_iou = 0.3; // minimum IoU to accept a match
        for (auto const& [id, person] : tracked_persons) {
            double inter = (double)(det_box & person.box).area();
            double uni = (double)det_box.area() + (double)person.box.area() - inter;
            double iou = uni > 0.0 ? inter / uni : 0.0;
            if (iou > max_iou) {
                max_iou = iou;
                best_match_id = id;
            }
        }
        if (best_match_id != -1) {
            // Matched: refresh the track.
            tracked_persons[best_match_id].box = det_box;
            tracked_persons[best_match_id].frames_unseen = 0;
        } else {
            // No match: start a new track.
            TrackedPerson new_person;
            new_person.id = next_track_id++;
            new_person.box = det_box;
            new_person.entry_time = 0;
            new_person.is_in_zone = false;
            new_person.alarm_triggered = false;
            new_person.frames_unseen = 0;
            tracked_persons[new_person.id] = new_person;
        }
    }
    // 3. Evaluate the intrusion rule per track.
    double current_time = get_current_time_seconds();
    for (auto it = tracked_persons.begin(); it != tracked_persons.end(); ++it) {
        TrackedPerson& person = it->second;
        // "In the zone" as soon as the boxes overlap at all.
        bool currently_in_zone = (intrusion_zone & person.box).area() > 0;
        if (currently_in_zone) {
            if (!person.is_in_zone) {
                // Just entered: start the dwell timer.
                person.is_in_zone = true;
                person.entry_time = current_time;
            } else if (!person.alarm_triggered &&
                       (current_time - person.entry_time) > intrusion_time_threshold) {
                // Dwell exceeded: alarm exactly once per continuous visit.
                person.alarm_triggered = true;
                trigger_alarm(person.id, person.box);
            }
        } else {
            // Left the zone: reset so a re-entry restarts the timer.
            person.is_in_zone = false;
            person.entry_time = 0;
            person.alarm_triggered = false;
        }
    }
    // 4. Drop tracks unseen for more than 20 frames.
    for (auto it = tracked_persons.begin(); it != tracked_persons.end(); ) {
        if (it->second.frames_unseen > 20) {
            it = tracked_persons.erase(it);
        } else {
            ++it;
        }
    }
}
// Runs one detection pass on orig_img (BGR), updates the tracker, and draws
// the intrusion zone plus per-person boxes/labels directly onto orig_img.
// Returns the (annotated) input image. Serialized with set_intrusion_zone
// via mtx. Fixes vs original: rknn_run / rknn_outputs_get return codes were
// ignored, so a failed inference fed garbage buffers into post_process.
cv::Mat rkYolov5s::infer(cv::Mat &orig_img)
{
    std::lock_guard<std::mutex> lock(mtx);
    cv::Mat img;
    cv::cvtColor(orig_img, img, cv::COLOR_BGR2RGB);
    img_width = img.cols;
    img_height = img.rows;
    BOX_RECT pads;
    memset(&pads, 0, sizeof(BOX_RECT));
    cv::Size target_size(width, height);
    cv::Mat resized_img(target_size.height, target_size.width, CV_8UC3);
    // Independent per-axis scale factors (no letterboxing; pads stay zero).
    float scale_w = (float)target_size.width / img.cols;
    float scale_h = (float)target_size.height / img.rows;
    if (img_width != width || img_height != height)
    {
        rga_buffer_t src;
        rga_buffer_t dst;
        memset(&src, 0, sizeof(src));
        memset(&dst, 0, sizeof(dst));
        ret = resize_rga(src, dst, img, resized_img, target_size);
        if (ret != 0)
        {
            fprintf(stderr, "resize with rga error\n");
        }
        inputs[0].buf = resized_img.data;
    }
    else
    {
        inputs[0].buf = img.data;
    }
    rknn_inputs_set(ctx, io_num.n_input, inputs);
    rknn_output outputs[io_num.n_output];
    memset(outputs, 0, sizeof(outputs));
    for (int i = 0; i < io_num.n_output; i++)
    {
        outputs[i].want_float = 0; // keep int8; dequantized inside post_process
    }
    ret = rknn_run(ctx, NULL);
    if (ret < 0)
    {
        fprintf(stderr, "rknn_run error ret=%d\n", ret);
        return orig_img;
    }
    ret = rknn_outputs_get(ctx, io_num.n_output, outputs, NULL);
    if (ret < 0)
    {
        fprintf(stderr, "rknn_outputs_get error ret=%d\n", ret);
        return orig_img;
    }
    detect_result_group_t detect_result_group;
    std::vector<float> out_scales;
    std::vector<int32_t> out_zps;
    for (int i = 0; i < io_num.n_output; ++i)
    {
        out_scales.push_back(output_attrs[i].scale);
        out_zps.push_back(output_attrs[i].zp);
    }
    post_process((int8_t *)outputs[0].buf, (int8_t *)outputs[1].buf, (int8_t *)outputs[2].buf, height, width,
                 box_conf_threshold, nms_threshold, pads, scale_w, scale_h, out_zps, out_scales, &detect_result_group);
    // Lazily initialize the intrusion zone to a centered half-size rect on
    // the first frame, once the stream resolution is known.
    if (intrusion_zone.width == 0 || intrusion_zone.height == 0) {
        intrusion_zone = cv::Rect(orig_img.cols / 4, orig_img.rows / 4, orig_img.cols / 2, orig_img.rows / 2);
    }
    update_tracker(detect_result_group);
    // Draw the intrusion zone outline. NOTE(review): (255,255,0) in BGR is
    // cyan, though the original comment called it yellow — confirm intent.
    cv::rectangle(orig_img, intrusion_zone, cv::Scalar(255, 255, 0), 2);
    // Draw every tracked person; alarmed tracks are red and thicker (BGR).
    for (auto const& [id, person] : tracked_persons) {
        cv::Scalar box_color = person.alarm_triggered ? cv::Scalar(0, 0, 255) : cv::Scalar(0, 255, 0);
        int line_thickness = person.alarm_triggered ? 3 : 2;
        cv::rectangle(orig_img, person.box, box_color, line_thickness);
        std::string label = "Person " + std::to_string(id);
        if (person.is_in_zone) {
            label += " (In Zone)";
        }
        cv::putText(orig_img, label, cv::Point(person.box.x, person.box.y - 10),
                    cv::FONT_HERSHEY_SIMPLEX, 0.5, box_color, 2);
    }
    ret = rknn_outputs_release(ctx, io_num.n_output, outputs);
    return orig_img;
}
// Tears down post-processing state, destroys the RKNN context, and releases
// the model blob plus the input/output tensor-attribute arrays.
rkYolov5s::~rkYolov5s()
{
    deinitPostProcess();
    ret = rknn_destroy(ctx);
    // free(NULL) is a no-op, so these guards are purely belt-and-braces.
    if (model_data != nullptr)
        free(model_data);
    if (input_attrs != nullptr)
        free(input_attrs);
    if (output_attrs != nullptr)
        free(output_attrs);
}