55 lines
2.0 KiB
Python
55 lines
2.0 KiB
Python
# -*- coding: utf-8 -*-
|
||
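"""
Face recognition feature extractor.

Input:  an aligned, cropped face image.
Output: a face feature vector.
Note:   also provides a function to compute the similarity between two features.
"""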
|
||
|
||
import cv2
|
||
import onnxruntime as ort
|
||
import numpy as np
|
||
|
||
# Only let ONNX Runtime's default logger report messages at ERROR level or
# above, so routine warnings/info do not clutter the output.
_ORT_LOG_SEVERITY_ERROR = 3  # ONNX Runtime severity 3 == ERROR
ort.set_default_logger_severity(_ORT_LOG_SEVERITY_ERROR)
|
||
|
||
class FaceRecoger():
    """Wrap an ONNX face-recognition model.

    The model turns an aligned, cropped face image into an L2-normalised
    feature vector. ``compute_sim`` compares two such vectors.
    """

    # Spatial size (height == width) the network input is reshaped to; images
    # of any other size are resized to this before inference.
    INPUT_SIZE = 248
    # Name of the model's input node, used as the key of the input feed.
    INPUT_NAME = '_input_123'

    def __init__(self, onnx_path, num_threads=1) -> None:
        """Create an ONNX Runtime session for the model at ``onnx_path``.

        Args:
            onnx_path: path to the ``.onnx`` face-recognition model.
            num_threads: value for ``SessionOptions.intra_op_num_threads``.
        """
        session_options = ort.SessionOptions()
        session_options.intra_op_num_threads = num_threads

        # Pass the SessionOptions object when creating the InferenceSession.
        # The keyword MUST be ``sess_options``: InferenceSession collects
        # unknown keywords in **kwargs and ignores them, so the previous
        # ``session_options=`` spelling silently dropped ``num_threads``.
        self.ort_session = ort.InferenceSession(onnx_path, sess_options=session_options)

        output_node_dims = [out.shape for out in self.ort_session.get_outputs()]
        # Feature length = dimension 1 of the first output's shape.
        # NOTE(review): for a model with a symbolic/dynamic dim this may be a
        # string or None rather than an int — confirm against the model.
        self.len_feat = output_node_dims[0][1]

    def inference(self, crop_img):
        """Extract an L2-normalised feature vector from one face image.

        Args:
            crop_img: BGR image of shape (H, W, 3), e.g. from ``cv2.imread``.
                If (H, W) != (INPUT_SIZE, INPUT_SIZE) it is resized.

        Returns:
            1-D numpy array with unit L2 norm (all zeros if the model output
            is all zeros).

        Raises:
            ValueError: if ``crop_img`` is None (e.g. ``cv2.imread`` failed).
        """
        if crop_img is None:
            raise ValueError("crop_img is None; was the image read successfully?")

        size = self.INPUT_SIZE
        if crop_img.shape[:2] != (size, size):
            # An alternative to resizing is cropping the central region, e.g.
            # crop_img[4:252, 4:252, ...] for a 256x256 input.
            crop_img = cv2.resize(crop_img, (size, size))

        rgb = crop_img[..., ::-1]             # BGR -> RGB
        chw = rgb.transpose((2, 0, 1))        # HWC -> CHW
        input_tensor = chw.reshape((1, 3, size, size)).astype(np.float32)

        # An empty output-name list asks ONNX Runtime for all outputs.
        pred_result = self.ort_session.run([], input_feed={self.INPUT_NAME: input_tensor})

        # NOTE(review): taking sqrt of the raw output assumes the model emits
        # non-negative values (negative entries would become NaN) — confirm.
        temp_result = np.sqrt(pred_result[0])

        # keepdims=True keeps the per-row norm broadcastable for any batch
        # size; a zero norm is replaced by 1 so the result is 0, not NaN.
        norm = np.linalg.norm(temp_result, axis=1, keepdims=True)
        normalized = temp_result / np.where(norm > 0, norm, 1)
        return normalized.flatten()

    @staticmethod
    def compute_sim(feat1, feat2):
        """Return the dot-product similarity of two feature vectors.

        For the unit-norm vectors returned by ``inference`` this equals their
        cosine similarity. Inputs of any shape (arrays or sequences) are
        flattened first.

        Raises:
            ValueError: if the two features do not have the same size.
        """
        feat1 = np.asarray(feat1).ravel()
        feat2 = np.asarray(feat2).ravel()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if feat1.shape != feat2.shape:
            raise ValueError(
                "feature size mismatch: {} vs {}".format(feat1.shape, feat2.shape))
        return np.dot(feat1, feat2)
|
||
|
||
|
||
if __name__ == "__main__":
    # Command-line demo: print the similarity of the faces in two images.
    # Usage: python <this_script> <face_img1> <face_img2>
    import sys

    # Validate arguments before paying for model loading.
    if len(sys.argv) != 3:
        sys.exit("usage: {} <face_img1> <face_img2>".format(sys.argv[0]))
    imgpath1, imgpath2 = sys.argv[1], sys.argv[2]

    img1, img2 = cv2.imread(imgpath1), cv2.imread(imgpath2)
    # cv2.imread returns None instead of raising when a file cannot be read.
    for path, img in ((imgpath1, img1), (imgpath2, img2)):
        if img is None:
            sys.exit("failed to read image: {}".format(path))

    fr = FaceRecoger(onnx_path="./checkpoints/face_recognizer.onnx", num_threads=1)
    feat1 = fr.inference(img1)
    feat2 = fr.inference(img2)

    print("sim: ", FaceRecoger.compute_sim(feat1, feat2))
|