// Jiale/test2_ort/pipeline/faceDetectReg.cpp
// (source-viewer metadata: 80 lines, 3.2 KiB, C++)

//
// created by wangjiale on 2024/5/9
//
#include <algorithm>
#include <cstdio>
#include <exception>
#include <string>
#include <vector>

#include "lite/lite.h"
// Splits a file-system path into its components.
//
// Backslashes in `str` are first normalised to '/', so Windows and POSIX
// separators are treated alike; passing either '\\' or '/' as `split` splits
// on both kinds of separator. Every segment is appended to `res` in order,
// including the final one (typically the file name) and any empty segments
// produced by leading, trailing or doubled delimiters. `res` is appended to,
// not cleared. An empty `str` appends nothing.
void Pathsplit(std::string str, const char split, std::vector<std::string>& res)
{
    if (str.empty()) return;
    std::replace(str.begin(), str.end(), '\\', '/');
    // After normalisation no '\\' remains, so a backslash delimiter must be
    // mapped to '/' or nothing would ever be found.
    const char sep = (split == '\\') ? '/' : split;

    std::string::size_type start = 0;
    for (;;)
    {
        // find() returns npos when there is no further delimiter.
        const std::string::size_type pos = str.find(sep, start);
        if (pos == std::string::npos)
        {
            res.push_back(str.substr(start)); // trailing segment
            return;
        }
        res.push_back(str.substr(start, pos - start));
        start = pos + 1; // continue after the delimiter without re-copying the rest
    }
}
static void faceReg(const std::string& detect_onnx, const std::string& reg_onnx, const std::string& test_img_path1, const std::string& test_img_path2){
lite::cv::face::detect::RetinaFace *retinaface = new lite::cv::face::detect::RetinaFace(detect_onnx); //default: Pytorch_RetinaFace_resnet50.onnx
lite::cv::faceid::FocalAsiaArcFace *focal_asia_arcface = new lite::cv::faceid::FocalAsiaArcFace(reg_onnx); //default: focal-arcface-bh-ir50-asia.onnx
lite::types::FaceContent known_face_content;
lite::types::FaceContent unknown_face_content;
std::vector<lite::types::Boxf> detected_boxes1;
std::vector<cv::Mat> bgr_faces;
cv::Mat img_bgr1 = cv::imread(test_img_path1);
retinaface->detect(img_bgr1, detected_boxes1);
if (detected_boxes1.empty() || detected_boxes1.size() > 1 || !detected_boxes1[0].flag ){
throw "known_img have no/many face";
}
cv::Mat known_bgr_face = img_bgr1(detected_boxes1[0].rect());
focal_asia_arcface->detect(known_bgr_face, known_face_content);
cv::Mat img_bgr2 = cv::imread(test_img_path2);
std::vector<lite::types::Boxf> detected_boxes2;
std::vector<float> sims;
retinaface->detect(img_bgr2, detected_boxes2);
for (const auto &box: detected_boxes2)
{
if (box.flag)
{
cv::Mat cropped_img_bgr = img_bgr2(box.rect());
bgr_faces.push_back(cropped_img_bgr);
focal_asia_arcface->detect(cropped_img_bgr, unknown_face_content);
float sim = lite::utils::math::cosine_similarity<float>(
known_face_content.embedding, unknown_face_content.embedding) ;
cv::rectangle(img_bgr2, box.rect(), cv::Scalar(255,255,0),2);
cv::putText(img_bgr2, std::to_string(sim).substr(0,5), box.tl(), cv::FONT_HERSHEY_SIMPLEX, 0.6f, cv::Scalar(0,255,0), 2);
}
}
std::string target_img_path = test_img_path2;
size_t pos = target_img_path.find("sources");
target_img_path.replace(target_img_path.begin() + pos, target_img_path.begin() + pos + 7, "log");
cv::imwrite(target_img_path, img_bgr2);
delete retinaface;
delete focal_asia_arcface;
}
// Runs the face detection + recognition demo.
//
// Usage: faceDetectReg [detect_onnx reg_onnx known_img test_img]
// With no arguments the hard-coded development paths below are used. Exactly
// four arguments override them in that order. Any other count prints usage.
// Returns 0 on success, 1 on bad usage or when faceReg reports an error
// (it throws C strings; OpenCV errors derive from std::exception).
int main(int argc, char** argv){
    std::string detect_onnx = R"(C:\Users\JIALE\Desktop\bns_proj\test2_ort\pipeline\hub\Pytorch_RetinaFace_resnet50.onnx)";
    std::string reg_onnx = R"(C:\Users\JIALE\Desktop\bns_proj\test2_ort\pipeline\hub\focal-arcface-bh-ir50-asia.onnx)";
    std::string test_img_path1 = R"(C:\Users\JIALE\Desktop\bns_proj\test2_ort\pipeline\sources\jiangyongqi.jpg)";
    std::string test_img_path2 = R"(C:\Users\JIALE\Desktop\bns_proj\test2_ort\pipeline\sources\wangjiale.jpg)";

    if (argc == 5) {
        detect_onnx = argv[1];
        reg_onnx = argv[2];
        test_img_path1 = argv[3];
        test_img_path2 = argv[4];
    } else if (argc > 1) {  // argc > 1 guarantees argv[0] is valid
        std::fprintf(stderr, "usage: %s [detect_onnx reg_onnx known_img test_img]\n", argv[0]);
        return 1;
    }

    try {
        faceReg(detect_onnx, reg_onnx, test_img_path1, test_img_path2);
    } catch (const char* msg) {
        std::fprintf(stderr, "faceReg failed: %s\n", msg);
        return 1;
    } catch (const std::exception& e) {
        std::fprintf(stderr, "faceReg failed: %s\n", e.what());
        return 1;
    }
    return 0;
}