// // Created by DefTruth on 2021/9/20. // #ifndef LITE_AI_TOOLKIT_ORT_CV_RVM_H #define LITE_AI_TOOLKIT_ORT_CV_RVM_H #include "lite/ort/core/ort_core.h" namespace ortcv { class LITE_EXPORTS RobustVideoMatting { private: Ort::Env ort_env; Ort::Session *ort_session = nullptr; // CPU MemoryInfo Ort::AllocatorWithDefaultOptions allocator; Ort::MemoryInfo memory_info_handler = Ort::MemoryInfo::CreateCpu( OrtArenaAllocator, OrtMemTypeDefault); // hardcode input node names unsigned int num_inputs = 6; std::vector input_node_names = { "src", "r1i", "r2i", "r3i", "r4i", "downsample_ratio" }; // init dynamic input dims std::vector> dynamic_input_node_dims = { {1, 3, 1280, 720}, // src (b=1,c,h,w) {1, 1, 1, 1}, // r1i {1, 1, 1, 1}, // r2i {1, 1, 1, 1}, // r3i {1, 1, 1, 1}, // r4i {1} // downsample_ratio dsr }; // (1, 16, ?h, ?w) for inner loop rxi // hardcode output node names unsigned int num_outputs = 6; std::vector output_node_names = { "fgr", "pha", "r1o", "r2o", "r3o", "r4o" }; const LITEORT_CHAR *onnx_path = nullptr; const char *log_id = nullptr; bool context_is_update = false; // input values handler & init std::vector dynamic_src_value_handler; std::vector dynamic_r1i_value_handler = {0.0f}; // init 0. with shape (1,1,1,1) std::vector dynamic_r2i_value_handler = {0.0f}; std::vector dynamic_r3i_value_handler = {0.0f}; std::vector dynamic_r4i_value_handler = {0.0f}; std::vector dynamic_dsr_value_handler = {0.25f}; // downsample_ratio with shape (1) protected: const unsigned int num_threads; // initialize at runtime. public: explicit RobustVideoMatting(const std::string &_onnx_path, unsigned int _num_threads = 1); ~RobustVideoMatting(); protected: RobustVideoMatting(const RobustVideoMatting &) = delete; // RobustVideoMatting(RobustVideoMatting &&) = delete; // RobustVideoMatting &operator=(const RobustVideoMatting &) = delete; // RobustVideoMatting &operator=(RobustVideoMatting &&) = delete; // private: // return normalized src, rxi, dsr Tensors std::vector transform(const cv::Mat &mat); int64_t value_size_of(const std::vector &dims); // get value size void generate_matting(std::vector &output_tensors, types::MattingContent &content); void update_context(std::vector &output_tensors); public: /** * Image Matting Using RVM(https://github.com/PeterL1n/RobustVideoMatting) * @param mat: cv::Mat BGR HWC * @param content: types::MattingContent to catch the detected results. * @param downsample_ratio: 0.25 by default. * @param video_mode: false by default. * See https://github.com/PeterL1n/RobustVideoMatting/blob/master/documentation/inference_zh_Hans.md */ void detect(const cv::Mat &mat, types::MattingContent &content, float downsample_ratio = 0.25f, bool video_mode = false); /** * Video Matting Using RVM(https://github.com/PeterL1n/RobustVideoMatting) * @param video_path: eg. xxx/xxx/input.mp4 * @param output_path: eg. xxx/xxx/output.mp4 * @param contents: vector of MattingContent to catch the detected results. * @param save_contents: false by default, whether to save MattingContent. * @param downsample_ratio: 0.25 by default. * See https://github.com/PeterL1n/RobustVideoMatting/blob/master/documentation/inference_zh_Hans.md * @param writer_fps: FPS for VideoWriter, 20 by default. */ void detect_video(const std::string &video_path, const std::string &output_path, std::vector &contents, bool save_contents = false, float downsample_ratio = 0.25f, unsigned int writer_fps = 20); }; } #endif //LITE_AI_TOOLKIT_ORT_CV_RVM_H