// // Created by DefTruth on 2021/9/20. // #include "rvm.h" #include "lite/ort/core/ort_utils.h" #include "lite/utils.h" using ortcv::RobustVideoMatting; RobustVideoMatting::RobustVideoMatting(const std::string &_onnx_path, unsigned int _num_threads) : log_id(_onnx_path.data()), num_threads(_num_threads) { #ifdef LITE_WIN32 std::wstring _w_onnx_path(lite::utils::to_wstring(_onnx_path)); onnx_path = _w_onnx_path.data(); #else onnx_path = _onnx_path.data(); #endif ort_env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, log_id); // 0. session options Ort::SessionOptions session_options; session_options.SetIntraOpNumThreads(num_threads); session_options.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED); session_options.SetLogSeverityLevel(4); // 1. session // GPU Compatibility. #ifdef USE_CUDA OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0); // C API stable. #endif ort_session = new Ort::Session(ort_env, onnx_path, session_options); #if LITEORT_DEBUG std::cout << "Load " << onnx_path << " done!" << std::endl; #endif } RobustVideoMatting::~RobustVideoMatting() { if (ort_session) delete ort_session; ort_session = nullptr; } int64_t RobustVideoMatting::value_size_of(const std::vector &dims) { if (dims.empty()) return 0; int64_t value_size = 1; for (const auto &size: dims) value_size *= size; return value_size; } std::vector RobustVideoMatting::transform(const cv::Mat &mat) { cv::Mat src = mat.clone(); const unsigned int img_height = mat.rows; const unsigned int img_width = mat.cols; std::vector &src_dims = dynamic_input_node_dims.at(0); // (1,3,h,w) // update src height and width src_dims.at(2) = img_height; src_dims.at(3) = img_width; // assume that rxi's dims and value_handler was updated by last step in a while loop. std::vector &r1i_dims = dynamic_input_node_dims.at(1); // (1,?,?h,?w) std::vector &r2i_dims = dynamic_input_node_dims.at(2); // (1,?,?h,?w) std::vector &r3i_dims = dynamic_input_node_dims.at(3); // (1,?,?h,?w) std::vector &r4i_dims = dynamic_input_node_dims.at(4); // (1,?,?h,?w) std::vector &dsr_dims = dynamic_input_node_dims.at(5); // (1) int64_t src_value_size = this->value_size_of(src_dims); // (1*3*h*w) int64_t r1i_value_size = this->value_size_of(r1i_dims); // (1*?*?h*?w) int64_t r2i_value_size = this->value_size_of(r2i_dims); // (1*?*?h*?w) int64_t r3i_value_size = this->value_size_of(r3i_dims); // (1*?*?h*?w) int64_t r4i_value_size = this->value_size_of(r4i_dims); // (1*?*?h*?w) int64_t dsr_value_size = this->value_size_of(dsr_dims); // 1 dynamic_src_value_handler.resize(src_value_size); // normalize & RGB cv::cvtColor(src, src, cv::COLOR_BGR2RGB); // (h,w,3) src.convertTo(src, CV_32FC3, 1.0f / 255.0f, 0.f); // 0.~1. // convert to tensor. std::vector input_tensors; input_tensors.emplace_back(ortcv::utils::transform::create_tensor( src, src_dims, memory_info_handler, dynamic_src_value_handler, ortcv::utils::transform::CHW )); input_tensors.emplace_back(Ort::Value::CreateTensor( memory_info_handler, dynamic_r1i_value_handler.data(), r1i_value_size, r1i_dims.data(), r1i_dims.size() )); input_tensors.emplace_back(Ort::Value::CreateTensor( memory_info_handler, dynamic_r2i_value_handler.data(), r2i_value_size, r2i_dims.data(), r2i_dims.size() )); input_tensors.emplace_back(Ort::Value::CreateTensor( memory_info_handler, dynamic_r3i_value_handler.data(), r3i_value_size, r3i_dims.data(), r3i_dims.size() )); input_tensors.emplace_back(Ort::Value::CreateTensor( memory_info_handler, dynamic_r4i_value_handler.data(), r4i_value_size, r4i_dims.data(), r4i_dims.size() )); input_tensors.emplace_back(Ort::Value::CreateTensor( memory_info_handler, dynamic_dsr_value_handler.data(), dsr_value_size, dsr_dims.data(), dsr_dims.size() )); return input_tensors; } void RobustVideoMatting::detect(const cv::Mat &mat, types::MattingContent &content, float downsample_ratio, bool video_mode) { if (mat.empty()) return; // 0. set dsr at runtime. dynamic_dsr_value_handler.at(0) = downsample_ratio; // 1. make input tensors, src, rxi, dsr std::vector input_tensors = this->transform(mat); // 2. inference, fgr, pha, rxo. auto output_tensors = ort_session->Run( Ort::RunOptions{nullptr}, input_node_names.data(), input_tensors.data(), num_inputs, output_node_names.data(), num_outputs ); // 3. generate matting this->generate_matting(output_tensors, content); // 4. update context (needed for video detection.) if (video_mode) { context_is_update = false; // init state. this->update_context(output_tensors); } } void RobustVideoMatting::detect_video(const std::string &video_path, const std::string &output_path, std::vector &contents, bool save_contents, float downsample_ratio, unsigned int writer_fps) { // 0. init video capture cv::VideoCapture video_capture(video_path); const unsigned int width = video_capture.get(cv::CAP_PROP_FRAME_WIDTH); const unsigned int height = video_capture.get(cv::CAP_PROP_FRAME_HEIGHT); const unsigned int frame_count = video_capture.get(cv::CAP_PROP_FRAME_COUNT); if (!video_capture.isOpened()) { std::cout << "Can not open video: " << video_path << "\n"; return; } // 1. init video writer cv::VideoWriter video_writer(output_path, cv::VideoWriter::fourcc('m', 'p', '4', 'v'), writer_fps, cv::Size(width, height)); if (!video_writer.isOpened()) { std::cout << "Can not open writer: " << output_path << "\n"; return; } // 2. matting loop cv::Mat mat; unsigned int i = 0; while (video_capture.read(mat)) { i += 1; types::MattingContent content; this->detect(mat, content, downsample_ratio, true); // video_mode true // 3. save contents and writing out. if (content.flag) { if (save_contents) contents.push_back(content); if (!content.merge_mat.empty()) video_writer.write(content.merge_mat); } // 4. check context states. if (!context_is_update) break; #ifdef LITEORT_DEBUG std::cout << i << "/" << frame_count << " done!" << "\n"; #endif } // 5. release video_capture.release(); video_writer.release(); } void RobustVideoMatting::generate_matting(std::vector &output_tensors, types::MattingContent &content) { Ort::Value &fgr = output_tensors.at(0); // fgr (1,3,h,w) 0.~1. Ort::Value &pha = output_tensors.at(1); // pha (1,1,h,w) 0.~1. auto fgr_dims = fgr.GetTypeInfo().GetTensorTypeAndShapeInfo().GetShape(); auto pha_dims = pha.GetTypeInfo().GetTensorTypeAndShapeInfo().GetShape(); const unsigned int height = fgr_dims.at(2); // output height const unsigned int width = fgr_dims.at(3); // output width const unsigned int channel_step = height * width; // fast assign & channel transpose(CHW->HWC). float *fgr_ptr = fgr.GetTensorMutableData(); float *pha_ptr = pha.GetTensorMutableData(); cv::Mat rmat(height, width, CV_32FC1, fgr_ptr); cv::Mat gmat(height, width, CV_32FC1, fgr_ptr + channel_step); cv::Mat bmat(height, width, CV_32FC1, fgr_ptr + 2 * channel_step); cv::Mat pmat(height, width, CV_32FC1, pha_ptr); rmat *= 255.; bmat *= 255.; gmat *= 255.; cv::Mat rest = 1. - pmat; cv::Mat mbmat = bmat.mul(pmat) + rest * 153.; cv::Mat mgmat = gmat.mul(pmat) + rest * 255.; cv::Mat mrmat = rmat.mul(pmat) + rest * 120.; std::vector fgr_channel_mats, merge_channel_mats; fgr_channel_mats.push_back(bmat); fgr_channel_mats.push_back(gmat); fgr_channel_mats.push_back(rmat); merge_channel_mats.push_back(mbmat); merge_channel_mats.push_back(mgmat); merge_channel_mats.push_back(mrmat); content.pha_mat = pmat; cv::merge(fgr_channel_mats, content.fgr_mat); cv::merge(merge_channel_mats, content.merge_mat); content.fgr_mat.convertTo(content.fgr_mat, CV_8UC3); content.merge_mat.convertTo(content.merge_mat, CV_8UC3); content.flag = true; } void RobustVideoMatting::update_context(std::vector &output_tensors) { // 0. update context for video matting. Ort::Value &r1o = output_tensors.at(2); // fgr (1,?,?h,?w) Ort::Value &r2o = output_tensors.at(3); // pha (1,?,?h,?w) Ort::Value &r3o = output_tensors.at(4); // pha (1,?,?h,?w) Ort::Value &r4o = output_tensors.at(5); // pha (1,?,?h,?w) auto r1o_dims = r1o.GetTypeInfo().GetTensorTypeAndShapeInfo().GetShape(); auto r2o_dims = r2o.GetTypeInfo().GetTensorTypeAndShapeInfo().GetShape(); auto r3o_dims = r3o.GetTypeInfo().GetTensorTypeAndShapeInfo().GetShape(); auto r4o_dims = r4o.GetTypeInfo().GetTensorTypeAndShapeInfo().GetShape(); // 1. update rxi's shape according to last rxo dynamic_input_node_dims.at(1) = r1o_dims; dynamic_input_node_dims.at(2) = r2o_dims; dynamic_input_node_dims.at(3) = r3o_dims; dynamic_input_node_dims.at(4) = r4o_dims; // 2. update rxi's value according to last rxo int64_t new_r1i_value_size = this->value_size_of(r1o_dims); // (1*?*?h*?w) int64_t new_r2i_value_size = this->value_size_of(r2o_dims); // (1*?*?h*?w) int64_t new_r3i_value_size = this->value_size_of(r3o_dims); // (1*?*?h*?w) int64_t new_r4i_value_size = this->value_size_of(r4o_dims); // (1*?*?h*?w) dynamic_r1i_value_handler.resize(new_r1i_value_size); dynamic_r2i_value_handler.resize(new_r2i_value_size); dynamic_r3i_value_handler.resize(new_r3i_value_size); dynamic_r4i_value_handler.resize(new_r4i_value_size); float *new_r1i_value_ptr = r1o.GetTensorMutableData(); float *new_r2i_value_ptr = r2o.GetTensorMutableData(); float *new_r3i_value_ptr = r3o.GetTensorMutableData(); float *new_r4i_value_ptr = r4o.GetTensorMutableData(); std::memcpy(dynamic_r1i_value_handler.data(), new_r1i_value_ptr, new_r1i_value_size * sizeof(float)); std::memcpy(dynamic_r2i_value_handler.data(), new_r2i_value_ptr, new_r2i_value_size * sizeof(float)); std::memcpy(dynamic_r3i_value_handler.data(), new_r3i_value_ptr, new_r3i_value_size * sizeof(float)); std::memcpy(dynamic_r4i_value_handler.data(), new_r4i_value_ptr, new_r4i_value_size * sizeof(float)); context_is_update = true; }