// // Created by DefTruth on 2021/12/5. // #include "mg_matting.h" #include "lite/ort/core/ort_utils.h" #include "lite/utils.h" using ortcv::MGMatting; MGMatting::MGMatting(const std::string &_onnx_path, unsigned int _num_threads) : log_id(_onnx_path.data()), num_threads(_num_threads) { #ifdef LITE_WIN32 std::wstring _w_onnx_path(lite::utils::to_wstring(_onnx_path)); onnx_path = _w_onnx_path.data(); #else onnx_path = _onnx_path.data(); #endif ort_env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, log_id); // 0. session options Ort::SessionOptions session_options; session_options.SetIntraOpNumThreads(num_threads); session_options.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED); session_options.SetLogSeverityLevel(4); // 1. session // GPU Compatibility. #ifdef USE_CUDA OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0); // C API stable. #endif ort_session = new Ort::Session(ort_env, onnx_path, session_options); // 2. input name & input dims dynamic_image_values_handler.resize(dynamic_input_image_size); dynamic_mask_values_handler.resize(dynamic_input_mask_size); #if LITEORT_DEBUG this->print_debug_string(); #endif } MGMatting::~MGMatting() { if (ort_session) delete ort_session; ort_session = nullptr; } void MGMatting::print_debug_string() { std::cout << "LITEORT_DEBUG LogId: " << log_id << "\n"; std::cout << "=============== Inputs ==============\n"; std::cout << "Dynamic Input: " << "image" << " Init [" << 1 << "," << 3 << "," << dynamic_input_height << "," << dynamic_input_width << "]\n"; std::cout << "Dynamic Input: " << "mask" << " Init [" << 1 << "," << 1 << "," << dynamic_input_height << "," << dynamic_input_width << "]\n"; std::cout << "=============== Outputs ==============\n"; for (unsigned int i = 0; i < num_outputs; ++i) std::cout << "Dynamic Output " << i << ": " << output_node_names[i] << std::endl; } std::vector MGMatting::transform(const cv::Mat &mat, const cv::Mat &mask) { auto padded_mat = this->padding(mat); // 0-255 int8 h+2*align_val w+2*align_val auto padded_mask = this->padding(mask); // 0-1.0 float32 h+2*align_val w+2*align_val cv::Mat canvas; cv::cvtColor(padded_mat, canvas, cv::COLOR_BGR2RGB); canvas.convertTo(canvas, CV_32FC3, 1.f / 255.f, 0.f); // (0.,1.) ortcv::utils::transform::normalize_inplace(canvas, mean_vals, scale_vals); // convert to tensor. std::vector input_tensors; input_tensors.emplace_back(ortcv::utils::transform::create_tensor( canvas, dynamic_input_image_dims, memory_info_handler, dynamic_image_values_handler, ortcv::utils::transform::CHW )); // image 1x3xhxw input_tensors.emplace_back(ortcv::utils::transform::create_tensor( padded_mask, dynamic_input_mask_dims, memory_info_handler, dynamic_mask_values_handler, ortcv::utils::transform::CHW )); // mask 1x1xhxw return input_tensors; } cv::Mat MGMatting::padding(const cv::Mat &unpad_mat) { const unsigned int h = unpad_mat.rows; const unsigned int w = unpad_mat.cols; // aligned if (h % align_val == 0 && w % align_val == 0) { unsigned int target_h = h + 2 * align_val; unsigned int target_w = w + 2 * align_val; cv::Mat pad_mat(target_h, target_w, unpad_mat.type()); cv::copyMakeBorder(unpad_mat, pad_mat, align_val, align_val, align_val, align_val, cv::BORDER_REFLECT); return pad_mat; } // un-aligned else { // align & padding unsigned int align_h = align_val * ((h - 1) / align_val + 1); unsigned int align_w = align_val * ((w - 1) / align_val + 1); unsigned int pad_h = align_h - h; // >= 0 unsigned int pad_w = align_w - w; // >= 0 unsigned int target_h = h + align_val + (pad_h + align_val); unsigned int target_w = w + align_val + (pad_w + align_val); cv::Mat pad_mat(target_h, target_w, unpad_mat.type()); cv::copyMakeBorder(unpad_mat, pad_mat, align_val, pad_h + align_val, align_val, pad_w + align_val, cv::BORDER_REFLECT); return pad_mat; } } void MGMatting::update_guidance_mask(cv::Mat &mask, unsigned int guidance_threshold) { if (mask.type() != CV_32FC1) mask.convertTo(mask, CV_32FC1); const unsigned int h = mask.rows; const unsigned int w = mask.cols; if (mask.isContinuous()) { const unsigned int data_size = h * w * 1; float *mutable_data_ptr = (float *) mask.data; float guidance_threshold_ = (float) guidance_threshold; for (unsigned int i = 0; i < data_size; ++i) { if (mutable_data_ptr[i] >= guidance_threshold_) mutable_data_ptr[i] = 1.0f; else mutable_data_ptr[i] = 0.0f; } } // else { float guidance_threshold_ = (float) guidance_threshold; for (unsigned int i = 0; i < h; ++i) { float *p = mask.ptr(i); for (unsigned int j = 0; j < w; ++j) { if (p[j] >= guidance_threshold_) p[j] = 1.0; else p[j] = 0.; } } } } void MGMatting::detect(const cv::Mat &mat, cv::Mat &mask, types::MattingContent &content, bool remove_noise, unsigned int guidance_threshold) { if (mat.empty() || mask.empty()) return; const unsigned int img_height = mat.rows; const unsigned int img_width = mat.cols; this->update_dynamic_shape(img_height, img_width); this->update_guidance_mask(mask, guidance_threshold); // -> float32 hw1 0~1.0 // 1. make input tensors, image, mask std::vector input_tensors = this->transform(mat, mask); // 2. inference, fgr, pha, rxo. auto output_tensors = ort_session->Run( Ort::RunOptions{nullptr}, input_node_names.data(), input_tensors.data(), num_inputs, output_node_names.data(), num_outputs ); // 3. generate matting this->generate_matting(output_tensors, mat, content, remove_noise); } void MGMatting::generate_matting(std::vector &output_tensors, const cv::Mat &mat, types::MattingContent &content, bool remove_noise) { Ort::Value &alpha_os1 = output_tensors.at(0); // (1,1,h+?,w+?) Ort::Value &alpha_os4 = output_tensors.at(1); // (1,1,h+?,w+?) Ort::Value &alpha_os8 = output_tensors.at(2); // (1,1,h+?,w+?) const unsigned int h = mat.rows; const unsigned int w = mat.cols; // https://github.com/yucornetto/MGMatting/blob/main/code-base/infer.py auto output_dims = alpha_os1.GetTypeInfo().GetTensorTypeAndShapeInfo().GetShape(); const unsigned int out_h = output_dims.at(2); const unsigned int out_w = output_dims.at(3); float *alpha_os1_ptr = alpha_os1.GetTensorMutableData(); float *alpha_os4_ptr = alpha_os4.GetTensorMutableData(); float *alpha_os8_ptr = alpha_os8.GetTensorMutableData(); cv::Mat alpha_os1_pred(out_h, out_w, CV_32FC1, alpha_os1_ptr); cv::Mat alpha_os4_pred(out_h, out_w, CV_32FC1, alpha_os4_ptr); cv::Mat alpha_os8_pred(out_h, out_w, CV_32FC1, alpha_os8_ptr); cv::Mat alpha_pred(out_h, out_w, CV_32FC1, alpha_os8_ptr); cv::Mat weight_os4 = this->get_unknown_tensor_from_pred(alpha_pred, 30); this->update_alpha_pred(alpha_pred, weight_os4, alpha_os4_pred); cv::Mat weight_os1 = this->get_unknown_tensor_from_pred(alpha_pred, 15); this->update_alpha_pred(alpha_pred, weight_os1, alpha_os1_pred); // post process if (remove_noise) this->remove_small_connected_area(alpha_pred); cv::Mat mat_copy; mat.convertTo(mat_copy, CV_32FC3); cv::Mat pmat = alpha_pred(cv::Rect(align_val, align_val, w, h)); std::vector mat_channels; cv::split(mat_copy, mat_channels); cv::Mat bmat = mat_channels.at(0); cv::Mat gmat = mat_channels.at(1); cv::Mat rmat = mat_channels.at(2); // ref only, zero-copy. bmat = bmat.mul(pmat); gmat = gmat.mul(pmat); rmat = rmat.mul(pmat); cv::Mat rest = 1.f - pmat; cv::Mat mbmat = bmat.mul(pmat) + rest * 153.f; cv::Mat mgmat = gmat.mul(pmat) + rest * 255.f; cv::Mat mrmat = rmat.mul(pmat) + rest * 120.f; std::vector fgr_channel_mats, merge_channel_mats; fgr_channel_mats.push_back(bmat); fgr_channel_mats.push_back(gmat); fgr_channel_mats.push_back(rmat); merge_channel_mats.push_back(mbmat); merge_channel_mats.push_back(mgmat); merge_channel_mats.push_back(mrmat); content.pha_mat = pmat; cv::merge(fgr_channel_mats, content.fgr_mat); cv::merge(merge_channel_mats, content.merge_mat); content.fgr_mat.convertTo(content.fgr_mat, CV_8UC3); content.merge_mat.convertTo(content.merge_mat, CV_8UC3); content.flag = true; } // https://github.com/yucornetto/MGMatting/issues/11 // https://github.com/yucornetto/MGMatting/blob/main/code-base/utils/util.py#L225 cv::Mat MGMatting::get_unknown_tensor_from_pred(const cv::Mat &alpha_pred, unsigned int rand_width) { const unsigned int h = alpha_pred.rows; const unsigned int w = alpha_pred.cols; const unsigned int data_size = h * w; cv::Mat uncertain_area(h, w, CV_32FC1, cv::Scalar(1.0f)); // continuous const float *pred_ptr = (float *) alpha_pred.data; float *uncertain_ptr = (float *) uncertain_area.data; // threshold if (alpha_pred.isContinuous() && uncertain_area.isContinuous()) { for (unsigned int i = 0; i < data_size; ++i) if ((pred_ptr[i] < 1.0f / 255.0f) || (pred_ptr[i] > 1.0f - 1.0f / 255.0f)) uncertain_ptr[i] = 0.f; } // else { for (unsigned int i = 0; i < h; ++i) { const float *pred_row_ptr = alpha_pred.ptr(i); float *uncertain_row_ptr = uncertain_area.ptr(i); for (unsigned int j = 0; j < w; ++j) { if ((pred_row_ptr[j] < 1.0f / 255.0f) || (pred_row_ptr[j] > 1.0f - 1.0f / 255.0f)) uncertain_row_ptr[j] = 0.f; } } } // dilate unsigned int size = rand_width / 2; auto kernel = cv::getStructuringElement(cv::MORPH_ELLIPSE, cv::Size(size, size)); cv::dilate(uncertain_area, uncertain_area, kernel); // weight cv::Mat weight(h, w, CV_32FC1, uncertain_area.data); // ref only, zero copy. float *weight_ptr = (float *) weight.data; if (weight.isContinuous()) { for (unsigned int i = 0; i < data_size; ++i) if (weight_ptr[i] != 1.0f) weight_ptr[i] = 0; } // else { for (unsigned int i = 0; i < h; ++i) { float *weight_row_ptr = weight.ptr(i); for (unsigned int j = 0; j < w; ++j) if (weight_row_ptr[j] != 1.0f) weight_row_ptr[j] = 0.f; } } return weight; } void MGMatting::update_alpha_pred(cv::Mat &alpha_pred, const cv::Mat &weight, const cv::Mat &other_alpha_pred) { const unsigned int h = alpha_pred.rows; const unsigned int w = alpha_pred.cols; const unsigned int data_size = h * w; const float *weight_ptr = (float *) weight.data; float *mutable_alpha_ptr = (float *) alpha_pred.data; const float *other_alpha_ptr = (float *) other_alpha_pred.data; if (alpha_pred.isContinuous() && weight.isContinuous() && other_alpha_pred.isContinuous()) { for (unsigned int i = 0; i < data_size; ++i) if (weight_ptr[i] > 0.f) mutable_alpha_ptr[i] = other_alpha_ptr[i]; } // else { for (unsigned int i = 0; i < h; ++i) { const float *weight_row_ptr = weight.ptr(i); float *mutable_alpha_row_ptr = alpha_pred.ptr(i); const float *other_alpha_row_ptr = other_alpha_pred.ptr(i); for (unsigned int j = 0; j < w; ++j) if (weight_row_ptr[j] > 0.f) mutable_alpha_row_ptr[j] = other_alpha_row_ptr[j]; } } } // https://github.com/yucornetto/MGMatting/blob/main/code-base/utils/util.py#L208 void MGMatting::remove_small_connected_area(cv::Mat &alpha_pred) { cv::Mat gray, binary; alpha_pred.convertTo(gray, CV_8UC1, 255.f); // 255 * 0.05 ~ 13 // https://github.com/yucornetto/MGMatting/blob/main/code-base/utils/util.py#L209 cv::threshold(gray, binary, 13, 255, cv::THRESH_BINARY); // morphologyEx with OPEN operation to remove noise first. auto kernel = cv::getStructuringElement(cv::MORPH_ELLIPSE, cv::Size(3, 3), cv::Point(-1, -1)); cv::morphologyEx(binary, binary, cv::MORPH_OPEN, kernel); // Computationally connected domain cv::Mat labels = cv::Mat::zeros(alpha_pred.size(), CV_32S); cv::Mat stats, centroids; int num_labels = cv::connectedComponentsWithStats(binary, labels, stats, centroids, 8, 4); if (num_labels <= 1) return; // no noise, skip. // find max connected area, 0 is background int max_connected_id = 1; // 1,2,... int max_connected_area = stats.at(max_connected_id, cv::CC_STAT_AREA); for (int i = 1; i < num_labels; ++i) { int tmp_connected_area = stats.at(i, cv::CC_STAT_AREA); if (tmp_connected_area > max_connected_area) { max_connected_area = tmp_connected_area; max_connected_id = i; } } const int h = alpha_pred.rows; const int w = alpha_pred.cols; // remove small connected area. for (int i = 0; i < h; ++i) { int *label_row_ptr = labels.ptr(i); float *alpha_row_ptr = alpha_pred.ptr(i); for (int j = 0; j < w; ++j) { if (label_row_ptr[j] != max_connected_id) alpha_row_ptr[j] = 0.f; } } } void MGMatting::update_dynamic_shape(unsigned int img_height, unsigned int img_width) { unsigned int h = img_height; unsigned int w = img_width; // update dynamic input dims if (h % align_val == 0 && w % align_val == 0) { // aligned dynamic_input_height = h + 2 * align_val; dynamic_input_width = w + 2 * align_val; } // un-aligned else { // align first unsigned int align_h = align_val * ((h - 1) / align_val + 1); unsigned int align_w = align_val * ((w - 1) / align_val + 1); unsigned int pad_h = align_h - h; // >= 0 unsigned int pad_w = align_w - w; // >= 0 dynamic_input_height = h + align_val + (pad_h + align_val); dynamic_input_width = w + align_val + (pad_w + align_val); } // update dims info dynamic_input_image_dims.at(2) = dynamic_input_height; dynamic_input_image_dims.at(3) = dynamic_input_width; dynamic_input_mask_dims.at(2) = dynamic_input_height; dynamic_input_mask_dims.at(3) = dynamic_input_width; dynamic_input_image_size = 1 * 3 * dynamic_input_height * dynamic_input_width; dynamic_input_mask_size = 1 * 1 * dynamic_input_height * dynamic_input_width; dynamic_image_values_handler.resize(dynamic_input_image_size); dynamic_mask_values_handler.resize(dynamic_input_mask_size); }