Jiale/test2_ort/pipeline/roc_calcu.cpp

139 lines
4.9 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//
// created by wangjiale on 2024/5/9
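// description: computes similarity scores between the faces in face_folder and stores them in txt files; a Python script then processes these scores to obtain the ROC curve (and its AUC) and pick a suitable decision threshold.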
#include "lite/lite.h"
#include <cstddef>
#include <fstream>
#include <iostream>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include "dirent.h"
#include <sys/stat.h>
// One identity (a sub-folder of face_folder) together with the face embeddings
// extracted from the images inside it.
struct Node {
    std::string name;                            // identity / sub-folder name
    std::vector<std::vector<float>> embeddings;  // one embedding per accepted image
    int num_embd{0};                             // count of entries pushed into embeddings

    Node() = default;
};
int main(int argc, char* argv[]){
const std::string detect_onnx = R"(C:\Users\JIALE\Desktop\bns_proj\test2_ort\pipeline\hub\Pytorch_RetinaFace_resnet50.onnx)";
const std::string reg_onnx = R"(C:\Users\JIALE\Desktop\bns_proj\test2_ort\pipeline\hub\focal-arcface-bh-ir50-asia.onnx)";
if (argc <= 1)
{
std::cerr << "argc <= 1" <<std::endl;
return -1;
}
std::string face_folder = argv[1]; //R"(path/to/face_folder)";
lite::cv::face::detect::RetinaFace *retinaface = new lite::cv::face::detect::RetinaFace(detect_onnx); //default: Pytorch_RetinaFace_resnet50.onnx
lite::cv::faceid::FocalAsiaArcFace *focal_asia_arcface = new lite::cv::faceid::FocalAsiaArcFace(reg_onnx); //default: focal-arcface-bh-ir50-asia.onnx
std::vector<Node> nodes; //储存着人脸的内容
// get face embeddings from folder
if (face_folder.back() != '/' || face_folder.back() != '\\')
face_folder.push_back('/');
DIR* dir = opendir(face_folder.c_str());
if (dir == nullptr){
std::cerr << "Cannot open directory: " << face_folder << std::endl;
return -1;
}
struct dirent *outEntry, *entry;
while((outEntry=readdir(dir)) != nullptr)
{
std::string foldername = outEntry->d_name;
if(foldername.find('.') != std::string::npos)
continue;
std::string subfolder = face_folder + foldername + '/';
DIR* subdir = opendir(subfolder.c_str());
Node node;
while ((entry = readdir(subdir)) != nullptr)
{
std::string fileName = entry->d_name;
if(fileName == "." || fileName ==".." || fileName.find(".jpg") == std::string::npos)
continue;
std::string filePath = subfolder + fileName;
cv::Mat img_bgr = cv::imread(filePath);
lite::types::FaceContent known_face_content;
std::vector<lite::types::Boxf> detected_boxes;
retinaface->detect(img_bgr, detected_boxes);
if (detected_boxes.empty() || detected_boxes.size() > 1 || !detected_boxes[0].flag )
{
std::cout << filePath + ": has " + std::to_string(static_cast<int>(detected_boxes.size()));
continue;
}
focal_asia_arcface->detect(img_bgr(detected_boxes[0].rect()), known_face_content);
if (known_face_content.flag){
node.embeddings.push_back(known_face_content.embedding);
++ node.num_embd;
// node.name = fileName.substr(0,fileName.size()-4);
}
// std::cout << "File name: " << fileName << std::endl;
// std::cout << "File path: " << filePath << std::endl;
}
node.name = foldername;
nodes.push_back(node);
closedir(subdir);
}
closedir(dir);
// compute sim between faces
vector<int> flags;
vector<float> sims;
std::random_device rd;
std::mt19937 gen(rd());
int num_sel = 8;
for(int i = 0; i < nodes.size(); i++ )
{
//compute negative similarity
std::vector<float>& source_embedding = nodes[i].embeddings[0];
for(int j = 0; j < num_sel; j++ )
{
int idx = i;
while(idx == i)
idx = std::uniform_int_distribution<int>(0,nodes.size()-1)(gen);
int iidx = std::uniform_int_distribution<int>(0,nodes[idx].num_embd)(gen);
std::vector<float>& target_embedding = nodes[idx].embeddings[iidx];
flags.push_back(0);
sims.push_back(
lite::utils::math::cosine_similarity<float>(source_embedding, target_embedding)
)
}
//compute positive similarity
vector<vector<float>& embds = nodes[i].embeddings;
for(int a = 0; a < embds.size(); a++)
{
for(int b = a+1; b < embds.size(); b++)
{
flags.push_back(1);
sims.push_back(
lite::utils::math::cosine_similarity<float>(
embds[a], embds[b])
)
}
}
}
delete retinaface;
delete focal_asia_arcface;
std::ofstream simfile("simfile.txt");
std::ofstream intfile("intfile.txt");
for(float similarity : sims)
simfile << static_cast<float>(similarity) << std::endl;
for(int flag : flags)
intfile << static_cast<int>(flag) << std::endl;
simfile.close();
intfile.close();
return 0;
}