Face_reg_app/FaceFeatureExtractorAPI/models/facerecoger.py

55 lines
2.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
输入aligned cropped face
输出face feature
注意: 提供计算两个feature的相似度的函数
"""
import cv2
import onnxruntime as ort
import numpy as np
# Silence ONNX Runtime log messages below ERROR severity (level 3 = ERROR).
_ORT_LOG_SEVERITY_ERROR = 3
ort.set_default_logger_severity(_ORT_LOG_SEVERITY_ERROR)
class FaceRecoger():
    """Extract normalised face features from aligned, cropped face images.

    Wraps an ONNX face-recognition model run with ONNX Runtime, and provides
    ``compute_sim`` to score the similarity of two extracted features.
    """

    # Height/width the input image is resized to; the model input tensor is
    # built as (1, 3, INPUT_SIZE, INPUT_SIZE).
    INPUT_SIZE = 248
    # Input tensor name used by the original model export; preferred when the
    # loaded model exposes it, otherwise the model's first input is used.
    _DEFAULT_INPUT_NAME = '_input_123'

    def __init__(self, onnx_path, num_threads=1, providers=None) -> None:
        """Load the ONNX model.

        Args:
            onnx_path: Path to the ``.onnx`` face-recognition model.
            num_threads: Number of intra-op threads ONNX Runtime may use.
            providers: Optional execution-provider list forwarded unchanged to
                ``ort.InferenceSession`` (``None`` keeps ORT's default).
        """
        session_options = ort.SessionOptions()
        session_options.intra_op_num_threads = num_threads
        # ONNX Runtime's keyword is ``sess_options``. InferenceSession accepts
        # extra ``**kwargs``, so a misspelled keyword (e.g. ``session_options``)
        # is not applied and ``num_threads`` would silently have no effect.
        self.ort_session = ort.InferenceSession(
            onnx_path, sess_options=session_options, providers=providers)
        output_node_dims = [out.shape for out in self.ort_session.get_outputs()]
        # Feature length is the second dim of the first output
        # (assumes output shape is (batch, len_feat) - TODO confirm).
        self.len_feat = output_node_dims[0][1]
        input_names = [inp.name for inp in self.ort_session.get_inputs()]
        self.input_name = (self._DEFAULT_INPUT_NAME
                           if self._DEFAULT_INPUT_NAME in input_names
                           else input_names[0])

    def inference(self, crop_img):
        """Compute the normalised feature vector for one face crop.

        Args:
            crop_img: (H, W, 3) BGR image, e.g. from ``cv2.imread``. It is
                resized to INPUT_SIZE x INPUT_SIZE when its size differs.

        Returns:
            1-D float array of length ``len_feat`` with unit L2 norm.
        """
        size = self.INPUT_SIZE
        if crop_img.shape[:2] != (size, size):
            # Another option instead of resizing: slice the crop, e.g.
            # crop_img[4:252, 4:252, ...].
            crop_img = cv2.resize(crop_img, (size, size))
        rgb = crop_img[..., ::-1]            # BGR -> RGB
        chw = rgb.transpose((2, 0, 1))       # HWC -> CHW
        input_data = chw.reshape((1, 3, size, size)).astype(np.float32)
        pred_result = self.ort_session.run(
            [], input_feed={self.input_name: input_data})
        # NOTE(review): sqrt of the raw model output; negative outputs would
        # become NaN - confirm the model's output is non-negative.
        temp_result = np.sqrt(pred_result[0])
        # keepdims=True divides each row by its own norm (correct for any batch).
        norm = temp_result / np.linalg.norm(temp_result, axis=1, keepdims=True)
        return norm.flatten()  # normalised feature

    @staticmethod
    def compute_sim(feat1, feat2):
        """Return the dot-product similarity of two features.

        For unit-normalised features (as returned by ``inference``) this is the
        cosine similarity. Inputs may be arrays or array-likes of any shape;
        they are flattened first.

        Raises:
            ValueError: If the flattened features have different shapes.
        """
        feat1 = np.asarray(feat1).flatten()
        feat2 = np.asarray(feat2).flatten()
        # Explicit check instead of ``assert``: asserts are stripped under -O,
        # and numpy broadcasting could then silently accept a size-1 feature.
        if feat1.shape != feat2.shape:
            raise ValueError(
                f"feature shapes differ: {feat1.shape} vs {feat2.shape}")
        return np.dot(feat1, feat2)
if __name__ == "__main__":
    # Usage: python facerecoger.py <image1> <image2>
    # Prints the similarity score of the features of the two face images.
    import sys

    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} <image1> <image2>")
    imgpath1, imgpath2 = sys.argv[1], sys.argv[2]
    img1, img2 = cv2.imread(imgpath1), cv2.imread(imgpath2)
    # cv2.imread returns None (it does not raise) when a file cannot be read.
    for path, img in ((imgpath1, img1), (imgpath2, img2)):
        if img is None:
            sys.exit(f"failed to read image: {path}")
    fr = FaceRecoger(onnx_path = "./checkpoints/face_recognizer.onnx", num_threads=1)
    feat1 = fr.inference(img1)
    feat2 = fr.inference(img2)
    print("sim: ", FaceRecoger.compute_sim(feat1, feat2))