Jiale/FaceRegWeb5.2/FaceRegWeb/models/facealign.py

185 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
This module takes the five facial landmark positions and the original image as input
and outputs the aligned, cropped face.
"""
import numpy as np
import cv2
import time
import os
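# image_data: source image
# image_width, image_height, image_channels: width, height and channel count of the source image
# src_x, src_y: for every output pixel, the source-image position to sample from
#               (src_x is the row coordinate, src_y is the column coordinate).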
def sampling(image_data, image_width, image_height, image_channels, src_x, src_y):
    """Bilinearly sample ``image_data`` at fractional source positions.

    Args:
        image_data: Source image array holding
            ``image_height * image_width * image_channels`` values laid out
            row-major as (row, column, channel).
        image_width: Number of columns in the source image.
        image_height: Number of rows in the source image.
        image_channels: Number of channels per pixel.
        src_x: Array of fractional *row* positions, one per output sample.
            Any number of dimensions (>= 1) is accepted.
        src_y: Array of fractional *column* positions, same shape as ``src_x``.

    Returns:
        ``uint8`` array of shape ``(*src_x.shape, image_channels)``. A sample
        is filled only when its whole 2x2 neighbourhood lies inside the
        image; every other sample stays zero.
    """
    image_data = np.asarray(image_data)
    src_x = np.asarray(src_x)
    src_y = np.asarray(src_y)

    # Integer coordinates of the top-left neighbour of every sample.
    row0 = np.floor(src_x).astype(int)
    col0 = np.floor(src_y).astype(int)

    pixel = np.zeros((*src_x.shape, image_channels), dtype=np.uint8)

    # row0 + 1 and col0 + 1 are read as well, hence the "- 1" upper bounds.
    valid_mask = (
        (row0 >= 0) & (row0 < image_height - 1)
        & (col0 >= 0) & (col0 < image_width - 1)
    )

    # Work on the valid samples only: no wasted arithmetic and no index
    # computation for coordinates that lie outside the image.
    rows = row0[valid_mask]
    cols = col0[valid_mask]
    frac_row = (src_x[valid_mask] - rows)[:, np.newaxis]
    frac_col = (src_y[valid_mask] - cols)[:, np.newaxis]

    flat_image = image_data.reshape(-1, image_channels)  # (height * width, channels)
    top_left_idx = rows * image_width + cols
    bottom_left_idx = top_left_idx + image_width

    top_left = flat_image[top_left_idx]
    top_right = flat_image[top_left_idx + 1]
    bottom_left = flat_image[bottom_left_idx]
    bottom_right = flat_image[bottom_left_idx + 1]

    # Blend along the column direction first, then along the row direction.
    blended_top = (1 - frac_col) * top_left + frac_col * top_right
    blended_bottom = (1 - frac_col) * bottom_left + frac_col * bottom_right
    blended = (1 - frac_row) * blended_top + frac_row * blended_bottom

    pixel[valid_mask] = np.clip(blended, 0, 255).astype(np.uint8)
    return pixel
def spatial_transform(image_data, image_width, image_height, image_channels,
                      crop_data, crop_width, crop_height, transformation,
                      pad_top=0, pad_bottom=0, pad_left=0, pad_right=0,
                      type='LINEAR', dtype='ZERO_PADDING', N=1):
    """Warp ``image_data`` into ``crop_data`` using a 2x3 affine matrix.

    For every output pixel (row, column) the source position is computed as

        source column = t[0] * column + t[1] * row + t[2]
        source row    = t[3] * column + t[4] * row + t[5]

    where ``t`` is ``transformation`` flattened, and the source image is
    sampled there with ``sampling``.

    Args:
        image_data: Source image array.
        image_width, image_height, image_channels: Dimensions of the source.
        crop_data: Output buffer, written in place. It must be able to receive
            an array of shape
            ``(crop_height + pad_top + pad_bottom,
            crop_width + pad_left + pad_right, image_channels)``.
        crop_width, crop_height: Size of the unpadded output.
        transformation: Array-like with at least six values (a 2x3 matrix).
        pad_top, pad_bottom, pad_left, pad_right: Extra output rows/columns
            around the crop; output coordinates are offset so that the crop's
            own top-left pixel is (0, 0).
        type, dtype: Accepted for interface compatibility; not used.
        N: Accepted for interface compatibility. Every repetition would write
            the same result, so the warp runs once when ``N >= 1`` and not at
            all otherwise.

    Returns:
        Always ``True``.

    Raises:
        ValueError: If ``transformation`` has fewer than six values.
    """
    if N < 1:
        return True

    dst_h = crop_height + pad_top + pad_bottom
    dst_w = crop_width + pad_left + pad_right

    theta = np.asarray(transformation).reshape(-1)
    if theta.size < 6:
        raise ValueError("transformation must contain at least 6 values (2x3 matrix)")

    # 'ij' indexing yields (dst_h, dst_w) grids: out_row varies along axis 0
    # and is offset by the top padding, out_col varies along axis 1 and is
    # offset by the left padding.
    out_row, out_col = np.meshgrid(
        np.arange(dst_h) - pad_top,
        np.arange(dst_w) - pad_left,
        indexing='ij',
    )

    src_col = theta[0] * out_col + theta[1] * out_row + theta[2]
    src_row = theta[3] * out_col + theta[4] * out_row + theta[5]

    # sampling() takes the row coordinate as src_x and the column as src_y.
    crop_data[:] = sampling(image_data, image_width, image_height, image_channels, src_row, src_col)
    return True
def transformation_maker(crop_width, crop_height, points, mean_shape, mean_shape_width, mean_shape_height):
    """Fit the similarity transform that maps template landmarks onto ``points``.

    The template landmarks are ``mean_shape`` rescaled to the crop size. The
    least-squares fit solves for ``a, b, c, d`` such that, for a template
    point ``(x, y)`` and its detected counterpart ``(u, v)``::

        u ~= a * x - b * y + c
        v ~= b * x + a * y + d

    All arithmetic is done in float64.

    Args:
        crop_width, crop_height: Size of the output crop.
        points: Detected landmarks, either nested ``[[u, v], ...]`` or flat
            ``[u0, v0, u1, v1, ...]``.
        mean_shape: Flat template coordinates ``[x0, y0, x1, y1, ...]`` with
            at least as many pairs as ``points``; extra entries are ignored.
        mean_shape_width, mean_shape_height: Size of the frame in which
            ``mean_shape`` is expressed.

    Returns:
        ``(True, matrix)`` with ``matrix`` the 2x3 float64 array
        ``[[a, -b, c], [b, a, d]]``, or ``(False, None)`` when the system is
        degenerate (e.g. no points, a single point, or an all-zero template).

    Raises:
        IndexError: If ``mean_shape`` has fewer pairs than ``points``.
        ZeroDivisionError: If ``mean_shape_width`` or ``mean_shape_height``
            is zero.
    """
    feat_points = np.asarray(points, dtype=np.float64).reshape(-1, 2)
    points_num = feat_points.shape[0]

    template = np.asarray(mean_shape, dtype=np.float64).reshape(-1)
    if template.size < 2 * points_num:
        raise IndexError("mean_shape has fewer coordinate pairs than points")
    if mean_shape_width == 0 or mean_shape_height == 0:
        raise ZeroDivisionError("mean_shape_width and mean_shape_height must be non-zero")

    # Template landmarks rescaled from the mean-shape frame to the crop frame.
    std_x = template[0:2 * points_num:2] * crop_width / mean_shape_width
    std_y = template[1:2 * points_num:2] * crop_height / mean_shape_height
    feat_u = feat_points[:, 0]
    feat_v = feat_points[:, 1]

    sum_x = float(std_x.sum())
    sum_y = float(std_y.sum())
    sum_u = float(feat_u.sum())
    sum_v = float(feat_v.sum())
    sum_xx_yy = float(np.dot(std_x, std_x) + np.dot(std_y, std_y))
    sum_ux_vy = float(np.dot(std_x, feat_u) + np.dot(std_y, feat_v))
    sum_vx_uy = float(np.dot(feat_v, std_x) - np.dot(feat_u, std_y))

    eps = np.finfo(np.float32).eps
    if sum_xx_yy <= eps:
        return False, None

    q = sum_u - sum_x * sum_ux_vy / sum_xx_yy + sum_y * sum_vx_uy / sum_xx_yy
    p = sum_v - sum_y * sum_ux_vy / sum_xx_yy - sum_x * sum_vx_uy / sum_xx_yy
    r = points_num - (sum_x ** 2 + sum_y ** 2) / sum_xx_yy
    if abs(r) <= eps:
        return False, None

    a = (sum_ux_vy - sum_x * q / r - sum_y * p / r) / sum_xx_yy
    b = (sum_vx_uy + sum_y * q / r - sum_x * p / r) / sum_xx_yy
    c = q / r
    d = p / r

    transformation = np.array([[a, -b, c], [b, a, d]], dtype=np.float64)
    return True, transformation
class FaceAlign:
    """Align a face to a fixed template and crop it to 256x256."""

    def __init__(self) -> None:
        self.crop_width, self.crop_height = 256, 256
        self.mean_shape_width, self.mean_shape_height = 256, 256
        # Landmark positions of the standard (template) face, stored flat as
        # five coordinate pairs inside the mean-shape frame.
        self.mean_face = [
            89.3095, 72.9025,
            169.3095, 72.9025,
            127.8949, 127.0441,
            96.8796, 184.8907,
            159.1065, 184.7601,
        ]
        # Example input for align():
        # landmarks5 = [
        #     [268.99814285714285, 166.26619999999997],
        #     [342.636625, 164.43359999999998],
        #     [311.5448214285714, 221.24419999999998],
        #     [272.2709642857143, 243.23539999999997],
        #     [344.2730357142857, 241.40279999999996]
        # ]

    def align(self, image, landmarks5):
        """Return the aligned, cropped face for ``image``.

        Args:
            image: Source image array of shape (height, width, channels).
            landmarks5: Five detected landmark coordinate pairs.

        Returns:
            ``uint8`` array of shape ``(crop_height, crop_width, 3)``. If no
            transformation can be computed from the landmarks, a message is
            printed and the array is returned all zeros.

        Side effects:
            On success the crop is also written to ``./images/result1.jpg``,
            or to ``./images/result2.jpg`` when ``result1.jpg`` already exists.
        """
        # The output buffer is always 3-channel, independent of the input.
        crop_data = np.zeros((self.crop_height, self.crop_width, 3), dtype=np.uint8)

        success, transformation = transformation_maker(
            self.crop_width, self.crop_height, landmarks5,
            self.mean_face, self.mean_shape_width, self.mean_shape_height,
        )
        if not success:
            # Without a transformation there is nothing to warp; hand back the
            # blank crop instead of passing None on to spatial_transform.
            print("Failed to compute transformation matrix.")
            return crop_data

        img_height, img_width, img_channels = image.shape
        success = spatial_transform(
            image, img_width, img_height, img_channels,
            crop_data, self.crop_width, self.crop_height,
            transformation,
        )
        if success:
            if os.path.exists("./images/result1.jpg"):
                cv2.imwrite("./images/result2.jpg", crop_data, [cv2.IMWRITE_JPEG_QUALITY, 100])
            else:
                cv2.imwrite("./images/result1.jpg", crop_data, [cv2.IMWRITE_JPEG_QUALITY, 100])
        else:
            print("error when spatial_transform...")
        return crop_data
if __name__ == "__main__":
    # Manual smoke test: align one sample image with hard-coded landmarks.
    fa = FaceAlign()
    landmarks5 = [
        [240.56920098163752, 111.91879640513824],
        [283.7146242409017, 93.30582481805237],
        [268.9820406889578, 129.202270021718],
        [259.51109411985107, 155.79222943184064],
        [296.34255299971073, 137.17925784475477],
    ]
    image_path = "/home/bns/seetaface6Python/seetaFace6Python/asserts/1.jpg"
    image = cv2.imread(image_path)
    # cv2.imread signals a missing or unreadable file by returning None.
    if image is None:
        raise SystemExit("Could not read image: " + image_path)
    fa.align(image=image, landmarks5=landmarks5)