Jiale/FaceRegWeb/process.py

588 lines
27 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# created on 2024/6/12
# modified on 2024/6/12
# description: tool file (.py)
import cv2
import os
import hashlib
import pickle
import requests
import base64
import logging
import numpy as np
from datetime import datetime
import redis
from redisLock import RedisReadWriteLock
import onnxruntime
import time
import yaml
from models import FaceRecoger, FaceBoxesV2, Landmark5er, FaceAlign, QualityOfClarity, QualityOfPose, QualityChecker
# Session options passed to onnxruntime.InferenceSession in this file.
so = onnxruntime.SessionOptions()
# Severity scale: 0=VERBOSE, 1=INFO, 2=WARNING, 3=ERROR, 4=FATAL -> errors only.
so.log_severity_level = 3
# Number of worker processes (mandatory): one Redis "worker_<i>" slot per worker.
if "NUM_WORKERS" not in os.environ:
    raise RuntimeError("Environment variable NUM_WORKERS is required but not set.")
NUM_WORKERS = int(os.environ["NUM_WORKERS"])
# Passed as max_readers to RedisReadWriteLock below.
max_readers = int(os.getenv("MAX_READERS", 60))
# Connect to Redis; every setting can be overridden from the environment.
redis_host = str(os.getenv("REDIS_HOST", 'localhost'))
redis_port = int(os.getenv("REDIS_PORT", 2012))
# SECURITY NOTE(review): a real-looking password is hard-coded as the fallback.
# It should come only from the environment / a secret store; kept here so
# existing deployments that rely on the default keep working.
redis_password = str(os.getenv("REDIS_PASSWORD", 'Xjsfzb@Redis123!'))
redis_client = redis.Redis(host=redis_host, port=redis_port, password=redis_password, db=0)

# Claim the first free worker slot. SETNX creates "worker_<i>" only if it does
# not exist yet, so concurrently starting processes get distinct indices. The
# stored value doubles as this worker's sync flag (0 = db dirty, 1 = in sync).
PID_id = None
for i in range(NUM_WORKERS):
    if redis_client.setnx(f"worker_{i}", 0):  # 0 = dirty
        PID_id = i
        break
if PID_id is None:
    # Previously this failed silently; every later lookup of "worker_None"
    # then crashed with int(None). Typical cause: worker_<i> keys surviving a
    # restart, or more processes than NUM_WORKERS.
    print(f"WARNING: no free worker slot among worker_0..worker_{NUM_WORKERS - 1} in Redis; "
          f"stale keys from a previous run may need to be cleared.")

# Read/write lock shared by all workers to guard the face database.
rw_lock = RedisReadWriteLock(redis_client, max_readers=max_readers)
# Error/status code -> message returned to clients. The exception classes
# below format their message from this table.
ErrorMsg = {
    "101": "no face in the database",
    "102": "invalid file path",
    "103": "invaild file suffix",
    # Raised by load_img for an empty base64 payload; it was missing here, so
    # constructing FileError("104") itself failed with KeyError.
    "104": "base64 is empty",
    "201": "input file type is not support, only support: base64, url, local-path",
    "202": "this updateDB type is not support",
    "300": "deal the image successfully.",
    "301": "no face in the image.",
    "302": "too many face in the image.",
    "303": "bad bright of face in the image.",
    "3041": "face shifted left/right, partially not captured.",
    "3042": "face shifted top/bottom, partially not captured.",
    "3051": "rightFace in the image.",
    "3052": "leftFace in the image.",
    "3053": "upFace in the image.",
    "3054": "downFace in the image.",
    "306": "too small resolution of face in the image.",
    "307": "bad clarity of face in the image.",
    "401": "identity has exist. to pretect the db, reject opt of this time ",
}
class _CodedError(Exception):
    """Base for this module's errors: an exception tagged with an ErrorMsg code.

    ``code`` holds the code string; the single exception argument is
    "<ClassName> <code>: <message>". A code missing from ErrorMsg gets a
    generic message instead of raising KeyError while the exception is built.
    """

    def __init__(self, arg: str):
        self.code = arg
        message = ErrorMsg.get(arg, "unknown error code")
        super().__init__(f"{type(self).__name__} {str(arg)}: {message}")


class FileError(_CodedError):
    """Input-file or database problems (codes 101-104)."""


class NotImpltError(_CodedError):
    """Unsupported input type or update operation (codes 201, 202)."""


class FaceError(_CodedError):
    """Face detection / quality-check failures (codes 301-307)."""


class UpdatedbError(_CodedError):
    """Rejected database updates (code 401)."""
# Logging: each import of this module logs to <dir of this file>/log/<start time>.log.
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_dir = f"{os.path.dirname(os.path.abspath(__file__))}/log"
os.makedirs(log_dir, exist_ok=True)
logging.basicConfig(
    filename=f'{log_dir}/{current_time}.log',
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
# Module logger used by every function and method in this file.
logger = logging.getLogger(__name__)
print(log_dir)
def list_images(path: str):
    """Collect every .jpg/.jpeg/.png file (case-insensitive) below ``path``.

    The tree is walked recursively with os.walk, following symlinks; the file
    extension alone decides whether a file counts as an image.

    Args:
        path (str): root directory to search.
    Returns:
        list[str]: full paths of the matching files, in os.walk order.
    """
    allowed_suffixes = (".jpg", ".jpeg", ".png")
    return [
        os.path.join(root, name)
        for root, _subdirs, names in os.walk(path, followlinks=True)
        for name in names
        if os.path.splitext(name)[1].lower() in allowed_suffixes
    ]
def find_image_hash(file_path: str) -> str:
    """Fingerprint a file cheaply from its metadata instead of its content.

    Hashing the pixel data would be costly, so the SHA-1 is taken over the
    string "<st_size>-<st_ctime>-<st_mtime>" from os.stat. Note that st_ctime
    is the creation time only on some platforms (on Unix it is the last
    metadata change). Any change to these values yields a different digest.

    Args:
        file_path (str): path of the image file.
    Returns:
        str: hex SHA-1 digest of the metadata string.
    """
    st = os.stat(file_path)
    fingerprint = "-".join(str(value) for value in (st.st_size, st.st_ctime, st.st_mtime))
    return hashlib.sha1(fingerprint.encode("utf-8")).hexdigest()
# Load an image from a URL, a local path or a base64 data URI; returns a numpy image.
def load_img(img_path: str):
    """Load an image from one of the supported sources.

    Supported inputs:
      * URL: starts with "http" or "www" (a bare "www..." gets "http://"
        prepended, since requests needs a scheme);
      * local path: starts with "./", "/", "C:", "D:", "E:", ".\\" or names an
        existing file; it must exist and end in .jpg/.jpeg/.png
        (case-insensitive);
      * base64 data URI: starts with "data:" and contains "base64"; the
        payload is the text after the last comma.

    Returns:
        The image decoded with cv2.IMREAD_COLOR. NOTE(review): cv2.imread /
        cv2.imdecode return None for undecodable data and that None is passed
        through unchanged; callers should be prepared for it.

    Raises:
        FileError("102"): local path does not exist.
        FileError("103"): local file has an unsupported suffix.
        FileError("104"): base64 payload is empty.
        NotImpltError("201"): input is none of the supported kinds.
        requests exceptions for failed downloads.
    Every exception is logged before being re-raised.
    """
    try:
        if img_path.startswith(("http", "www")):  # url
            url = img_path if img_path.startswith("http") else "http://" + img_path
            # response.content (unlike response.raw) undoes gzip/deflate
            # Content-Encoding; the `with` releases the connection afterwards.
            with requests.get(url=url, timeout=60, proxies={"http": None, "https": None}) as response:
                response.raise_for_status()
                image_array = np.frombuffer(response.content, dtype=np.uint8)
            return cv2.imdecode(image_array, cv2.IMREAD_COLOR)

        if img_path.startswith(("./", "/", "C:", "D:", "E:", ".\\")) or os.path.isfile(img_path):  # local-path
            if not os.path.isfile(img_path):
                raise FileError("102")  # invalid file path
            if not img_path.lower().endswith((".jpg", ".jpeg", ".png")):
                raise FileError("103")  # invalid file suffix
            return cv2.imread(img_path)

        if img_path.startswith("data:") and "base64" in img_path:  # base64
            encoded_data = img_path.split(",")[-1]
            # str.split always yields at least one element, so the old
            # `len(parts) <= 0` check never fired; test the payload itself.
            if not encoded_data.strip():
                raise FileError("104")  # base64 is empty
            nparr = np.frombuffer(base64.b64decode(encoded_data), np.uint8)
            return cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        raise NotImpltError("201")  # input file type is not supported
    except Exception as e:
        logger.info(f"{e}")
        raise
def encoder_img2base64(img: np.ndarray):
    """Encode an image as a PNG data URI ("data:image/png;base64,<payload>").

    Args:
        img: image accepted by cv2.imencode.
    Returns:
        The data-URI string, or None if OpenCV fails to encode the image.
    """
    success, encoded_img = cv2.imencode('.png', img)
    if not success:
        return None
    img_base64 = base64.b64encode(encoded_img).decode("utf-8")
    # The payload is PNG, so label it image/png. It used to be mislabelled
    # "image/jpg", which is also not a registered MIME type.
    return ",".join(["data:image/png;base64", img_base64])
# from seetaface.api import *
class FaceHelper:
    """Face registration / recognition over a database shared by several workers.

    The database lives in ``db_dir``: one ``<identity>.jpg`` per person plus a
    pickle cache (``seetaface6.pkl``) holding, per image, its path
    ("identity"), a metadata hash, the face embedding and a base64 preview of
    the aligned face. Writers take the Redis write lock ``rw_lock`` and
    readers the read lock. The Redis key ``worker_<pid>`` is this worker's
    sync flag: 0 means the in-memory copy is stale and must be reloaded from
    the pickle, 1 means it is in sync.
    """

    def __init__(self, db_dir, config_path='./config.yaml'):
        """Load models and thresholds from the YAML config, then sync the database.

        Args:
            db_dir: directory holding the face images and the pickle cache.
            config_path: YAML file with ``faceReg``, ``ck_paths``,
                ``brightness``, ``resolution``, ``pose`` and ``clarity`` sections.
        """
        self.db_dir = os.path.abspath(db_dir)
        self.pid = PID_id
        self.db_embeddings = None
        self.db_identities = None
        # Initialise everything from the YAML file at config_path.
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        ck_paths = config['ck_paths']
        self.sim_threshold = config['faceReg']['sim_threshold']  # e.g. 0.7
        # Rotation classifier, e.g. "./checkpoints/model_gray_mobilenetv2_rotcls.onnx".
        self.rotclsifer = onnxruntime.InferenceSession(ck_paths['rotifer'], so)
        # Built from the absolute db_dir and without the old `.lower()` on the
        # whole path, which broke any db_dir containing uppercase characters
        # on case-sensitive filesystems.
        self.db_path = os.path.join(self.db_dir, "seetaface6.pkl")
        # e.g. "./checkpoints/faceboxesv2-640x640.onnx", 4 threads
        self.fd = FaceBoxesV2(ck_paths['FcBx'], ck_paths['num_threads'])
        self.ld5er = Landmark5er(onnx_path1=ck_paths['landmk1'],  # "./checkpoints/face_landmarker_pts5_net1.onnx"
                                 onnx_path2=ck_paths['landmk2'],  # "./checkpoints/face_landmarker_pts5_net2.onnx"
                                 num_threads=ck_paths['num_threads'])
        self.fa = FaceAlign()
        # e.g. "./checkpoints/face_recognizer.onnx"
        self.fr = FaceRecoger(onnx_path=ck_paths['FcReg'], num_threads=ck_paths['num_threads'])
        brightness = config['brightness']  # e.g. v0=70.0, v1=100.0, v2=210.0, v3=230.0
        self.qc = QualityChecker(brightness['v0'], brightness['v1'],
                                 brightness['v2'], brightness['v3'],
                                 hw=(config['resolution']['height'], config['resolution']['width']))
        pose = config['pose']
        self.qpose = QualityOfPose(yaw_thrd=pose['yaw_thrd'], pitch_thrd=pose['pitch_thrd'],
                                   var_onnx_path=pose['var_onnx_path'],    # './checkpoints/fsanet-var.onnx'
                                   conv_onnx_path=pose['conv_onnx_path'])  # './checkpoints/fsanet-conv.onnx'
        self.qclarity = QualityOfClarity(low_thresh=config['clarity']['low_thrd'],
                                         high_thresh=config['clarity']['high_thrd'])
        # The initial refresh is best-effort (an empty database, for instance,
        # raises FileError("101")). Log the reason instead of dropping it silently.
        try:
            self.updateDB(None, None, None, Onlyrefresh=True)
        except Exception as e:
            logger.info(f"initial database refresh failed: {e}")
        print(f"db_dir: {self.db_dir}; PID: {self.pid}")
        logger.info(f"db_dir: {self.db_dir} ; PID: {self.pid}")

    def _load_representations(self, representations):
        """Replace the in-memory embeddings/identities; an empty list clears both to None."""
        if not representations:
            self.db_embeddings, self.db_identities = None, None
        else:
            self.db_embeddings = np.array([rep["embedding"] for rep in representations], dtype=np.float32)
            # Identity = image file name without directory or extension.
            self.db_identities = [os.path.splitext(os.path.basename(rep["identity"]))[0]
                                  for rep in representations]

    def _sync_from_disk(self):
        """Reload the pickle when this worker's Redis flag says the memory copy is stale."""
        if int(redis_client.get(f"worker_{self.pid}")) == 0:
            with open(self.db_path, "rb") as f:
                representations = pickle.load(f)
            self._load_representations(representations)
            redis_client.set(f"worker_{self.pid}", 1)  # in sync again

    @staticmethod
    def _largest_box(detect_result):
        """Return (x1, y1, x2, y2) of the first detection with the largest area()."""
        best = max(detect_result, key=lambda box: box.area())
        return (best.x1, best.y1, best.x2, best.y2)  # top-left, bottom-right

    # Read operation
    def faceRecognition(self, img_path: str):
        """Identify every face in an image against the database.

        Args:
            img_path: image source accepted by load_img.
        Returns:
            dict ``{'code', 'msg', 'data'}``. For a single face, ``data`` is the
            matched identity with code "300", or "unknown" with the first
            failing check_face code ("300" if the face passes every check).
            For several faces the code is "302" and ``data`` is the list of
            per-face dicts.
        Raises:
            FileError("101"): the database is empty.
            FaceError("301"): no face detected.
            Anything raised by load_img or the models. All errors are logged.
        """
        rw_lock.acquire_read()
        try:
            # The resync now sits inside the try: a failure here (e.g. a
            # corrupt pickle) used to leave the read lock held forever.
            self._sync_from_disk()
            if self.db_embeddings is None:
                raise FileError("101")  # no face in the database
            image = load_img(img_path)  # BGR numpy image
            start = time.time()
            image = self.rotadjust(image)  # undo a 90/180/270-degree rotation
            boxes = [(box.x1, box.y1, box.x2, box.y2) for box in self.fd.detect(image)]
            if not boxes:
                logger.info(f"{img_path[:200]}: no face in the image")
                print(f"{img_path[:200]}: no face in the image")
                raise FaceError("301")  # no face in the image

            embeddings = []
            for facebox in boxes:
                # 5 landmarks: left eye, right eye, nose, left / right mouth corner.
                landmarks5 = [[ld5[0], ld5[1]] for ld5 in self.ld5er.inference(image, facebox)]
                cropped_face = self.fa.align(image, landmarks5=landmarks5)
                embeddings.append(self.fr.inference(cropped_face))
            # Similarity of every detected face against every database entry.
            scores = np.dot(np.vstack(embeddings), self.db_embeddings.T)
            best_idxs = np.argmax(scores, axis=1)
            best_values = np.max(scores, axis=1)

            ret_data = []
            for facebox, idx, value in zip(boxes, best_idxs, best_values):
                if value > self.sim_threshold:
                    ret_data.append({'code': "300", 'msg': ErrorMsg["300"], 'data': self.db_identities[idx]})
                else:
                    # Unknown face: report why it may not match (pose, blur, ...).
                    code = self.check_face(img_path[:200], image, facebox, prefix='facereg')
                    ret_data.append({'code': code, 'msg': ErrorMsg[code], 'data': "unknown"})
            if len(ret_data) != 1:
                ret_data = {'code': "302", 'msg': ErrorMsg["302"], 'data': ret_data}
            else:
                ret_data = ret_data[0]
            print("facereg runtime:", time.time() - start)
            return ret_data
        except Exception as e:
            logger.info(f"{e}")
            raise
        finally:
            rw_lock.release_read()

    def _save_identity_image(self, img_path, opt, identity):
        """Validate the image at img_path and store it as ``<db_dir>/<identity>.jpg``.

        Raises:
            UpdatedbError("401"): opt is "add" and the identity already exists.
            FaceError("301"): no face detected.
            FaceError(<code>): the largest face fails check_face.
        """
        if opt == "add" and self.db_identities is not None and identity in self.db_identities:
            raise UpdatedbError("401")  # identity exists; refuse to overwrite it via "add"
        img = self.rotadjust(load_img(img_path))
        detect_result = self.fd.detect(img)
        if len(detect_result) == 0:
            logger.info(f"{img_path[:200]}: when update, no face in the image")
            print(f"{img_path[:200]}: when update, no face in the image")
            raise FaceError("301")  # no face in the image
        # Keep only the largest face and quality-check it.
        facebox = self._largest_box(detect_result)
        code = self.check_face(img_path=img_path[:200], img=img, facebox=facebox, prefix='update')
        if code != "300":
            raise FaceError(code)
        # Overwrites any existing file of the same name.
        # NOTE(review): identity is used verbatim as a file name. If it can come
        # from untrusted input, values such as "../x" would write outside
        # db_dir; validate it upstream.
        cv2.imwrite(os.path.join(self.db_dir, identity + '.jpg'), img, [cv2.IMWRITE_JPEG_QUALITY, 100])

    # opt is one of 'add', 'delete', 'replace'. identity is the database key;
    # img_path only supplies the image to store.
    # Write operation
    def updateDB(self, img_path: str, opt: str, identity: str, Onlyrefresh=False):
        """Add, replace or delete an identity under the write lock, then refresh the db.

        Args:
            img_path: image source accepted by load_img (used by add/replace only).
            opt: "add" (rejected with UpdatedbError("401") if the identity
                exists), "replace" or "delete".
            identity: database key; stored as ``<db_dir>/<identity>.jpg``.
            Onlyrefresh: skip the edit and only rescan db_dir; in this mode new
                images must also pass check_face to be indexed.
        Returns:
            0 on success.
        Raises:
            NotImpltError("202") for an unknown opt, plus the errors of
            _save_identity_image / refresh_database. All errors are logged.
        """
        rw_lock.acquire_write()
        print("come in the updatedb")
        try:
            if not Onlyrefresh:
                self._sync_from_disk()
                if opt in ("add", "replace"):
                    # The image is only loaded when it is actually needed; "delete"
                    # used to require a decodable image as well.
                    self._save_identity_image(img_path, opt, identity)
                elif opt == "delete":
                    try:
                        os.remove(os.path.join(self.db_dir, identity + '.jpg'))
                    except FileNotFoundError:
                        pass  # deleting a missing identity is a no-op
                else:
                    raise NotImpltError("202")  # unsupported updateDB type
            print("end the updateDB")
            logger.info("end the updateDB")
            # Rescan db_dir and mark other workers dirty. add/replace images were
            # already checked above, so check_face only runs in refresh-only mode.
            self.refresh_database(check=Onlyrefresh)
        except Exception as e:
            logger.info(f"{e}")
            raise
        finally:
            rw_lock.release_write()
        return 0

    @staticmethod
    def _ensure_jpg(img_path):
        """Return the .jpg path for img_path, transcoding non-.jpg sources.

        The .jpg copy is (re)written only when it is missing or older than the
        source. Rewriting it on every refresh changed its mtime and therefore
        its hash, so the image was re-embedded, the pickle rewritten and all
        workers marked dirty each time, and a freshly "replace"d image was
        overwritten by the old source. Returns None if the source cannot be
        decoded.
        """
        base_path, ext = os.path.splitext(img_path)
        if ext == '.jpg':
            return img_path
        jpg_path = base_path + '.jpg'
        if os.path.exists(jpg_path) and os.path.getmtime(jpg_path) >= os.path.getmtime(img_path):
            return jpg_path
        image = cv2.imread(img_path)
        if image is None:
            logger.info(f"{img_path}: when refresh, cannot decode the image, skip")
            return None
        cv2.imwrite(jpg_path, image, [cv2.IMWRITE_JPEG_QUALITY, 100])
        return jpg_path

    def _represent(self, image_path, check):
        """Build the pickle record for the largest face in image_path.

        Returns None (image not indexed) when the file cannot be read, no face
        is found, or ``check`` is true and check_face rejects the face.
        """
        image = cv2.imread(image_path)
        if image is None:
            logger.info(f"{image_path}: when refresh, cannot read the image, skip")
            return None
        image = self.rotadjust(image)  # undo a 90/180/270-degree rotation
        detect_result = self.fd.detect(image)
        if len(detect_result) == 0:
            logger.info(f"{image_path}: when refresh, no face in the image, delete")
            print(f"{image_path}: when refresh, no face in the image, delete")
            return None
        if len(detect_result) > 1:
            logger.info(f"{image_path}: find too many face, get and extract the biggest face in them")
        else:
            logger.info(f"{image_path}: find one face, perfect!")
        facebox = self._largest_box(detect_result)
        if check and self.check_face(img_path=image_path[:200], img=image,
                                     facebox=facebox, prefix='refreshdb') != "300":
            return None
        # 5 landmarks: left eye, right eye, nose, left / right mouth corner.
        landmarks5 = [[ld5[0], ld5[1]] for ld5 in self.ld5er.inference(image, facebox)]
        cropped_face = self.fa.align(image, landmarks5=landmarks5)
        return {
            "identity": image_path,
            "hash": find_image_hash(image_path),
            "embedding": self.fr.inference(cropped_face),
            "detected_face_base64": encoder_img2base64(cropped_face),
        }

    def refresh_database(self, check=True):
        """Bring the pickle cache in line with the images stored in db_dir.

        Non-.jpg images are transcoded to .jpg. Records are dropped for images
        that disappeared or whose metadata hash changed, and computed for new
        or changed images (largest face only; with ``check`` the face must also
        pass check_face). Images without a usable face stay on disk but are not
        indexed. If anything changed, the pickle is rewritten and every worker's
        sync flag is reset to 0; if this worker's flag is 0 its in-memory copy
        is reloaded.

        Raises:
            FileError("101"): this worker reloaded and the database is empty.
        """
        # Make sure the directory and an (empty) pickle exist.
        os.makedirs(self.db_dir, exist_ok=True)
        if not os.path.exists(self.db_path):
            with open(self.db_path, "wb") as f:
                pickle.dump([], f)
        with open(self.db_path, "rb") as f:
            representations = pickle.load(f)

        pickle_images = {rep["identity"] for rep in representations}
        storage_images = {p for p in map(self._ensure_jpg, list_images(self.db_dir)) if p is not None}

        old_images = pickle_images - storage_images
        # Still on disk, but the file changed since it was embedded.
        replaced_images = {
            rep["identity"] for rep in representations
            if rep["identity"] not in old_images and rep["hash"] != find_image_hash(rep["identity"])
        }
        new_images = (storage_images - pickle_images) | replaced_images
        old_images |= replaced_images

        must_save_pickle = False
        if old_images:
            representations = [rep for rep in representations if rep["identity"] not in old_images]
            must_save_pickle = True
        if new_images:
            print("find new images")
            for new_image in sorted(new_images):
                record = self._represent(new_image, check)
                if record is not None:
                    representations.append(record)
            must_save_pickle = True

        if must_save_pickle:
            print("must save the pickle")
            with open(self.db_path, "wb") as f:
                pickle.dump(representations, f)
            for i in range(NUM_WORKERS):
                redis_client.set(f"worker_{i}", 0)  # tell every worker the db changed

        if int(redis_client.get(f"worker_{self.pid}")) == 0:
            # Also clears db_identities when empty (previously left stale).
            self._load_representations(representations)
            redis_client.set(f"worker_{self.pid}", 1)  # this worker is now in sync
            if not representations:
                logger.info("no face in the database")
                raise FileError("101")  # no face in db

    def rotadjust(self, img: np.ndarray):
        """Return ``img`` rotated upright according to the rotation classifier.

        Classifier input: grayscale, resized to 256x256, center-cropped to
        224x224, replicated to 3 channels, scaled to [0, 1], NCHW with a batch
        of 1. Per the original comment, the argmax label is the
        counter-clockwise rotation of the input (0, 90, 180, 270 degrees);
        the matching inverse rotation is applied.
        """
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        gray = cv2.resize(gray, (256, 256))
        offset = (256 - 224) // 2  # center crop to 224x224
        gray = gray[offset:offset + 224, offset:offset + 224]
        batch = np.stack((gray,) * 3, axis=-1).astype(np.float32) / 255.0
        batch = np.expand_dims(batch.transpose(2, 0, 1), axis=0)  # HWC -> CHW, add batch dim
        inputs = {self.rotclsifer.get_inputs()[0].name: batch}
        label = np.argmax(self.rotclsifer.run(None, inputs)[0][0])
        if label == 1:
            img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
            logger.info("img turn left, use `cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)` to repair")
            print("img turn left, use `cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)` to repair")
        elif label == 2:
            img = cv2.rotate(img, cv2.ROTATE_180)
            logger.info("img flip the image vertically, use `cv2.rotate(img, cv2.ROTATE_180)` to repair")
            print("img flip the image vertically, use `cv2.rotate(img, cv2.ROTATE_180)` to repair")
        elif label == 3:
            img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
            logger.info("img turn right, use `cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)` to repair")
            print("img turn right, use `cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)` to repair")
        return img

    def get_feature(self, img: np.ndarray):
        """Legacy embedding extractor; not called by the methods of this class.

        NOTE(review): relies on ``self.FR`` (a session exposing
        ``run(outputs, input_feed=...)`` with an input named "_input_123"),
        which this class never assigns; the other methods use
        ``self.fr.inference``. Confirm whether any caller sets ``FR``.
        The unexplained ``time.sleep(0.08)`` that opened this method was removed.
        """
        img = cv2.resize(img, (256, 256))
        # Center 248x248 crop (safe after the resize) and reversed channel order.
        crop_img = img[4:252, 4:252, :][..., ::-1]
        input_data = crop_img.transpose((2, 0, 1))  # HWC -> CHW
        input_feed = {'_input_123': input_data.reshape((1, 3, 248, 248)).astype(np.float32)}
        pred_result = self.FR.run([], input_feed=input_feed)
        # Post-processing: element-wise sqrt, then L2-normalise each row.
        temp_result = np.sqrt(pred_result[0])
        norm = temp_result / np.linalg.norm(temp_result, axis=1, keepdims=True)
        return norm.flatten()

    def check_face(self, img_path, img, facebox, prefix="update"):
        """Run the quality checks on one detected face.

        Args:
            img_path: label used only in log/print messages.
            img: full image the box refers to.
            facebox: (x1, y1, x2, y2), top-left and bottom-right corners.
            prefix: operation name used in messages.
        Returns:
            "300" if every check passes, else the first failing code: "3041" /
            "3042" box touches the left/right or top/bottom border, "303"
            brightness, "306" resolution, "3051"-"3054" pose, "307" clarity.
        """
        H, W = img.shape[:2]
        if facebox[0] < 0 or facebox[2] >= W:
            logger.info(f"{img_path}: when {prefix}, face shifted left/right")
            print(f"{img_path}: when {prefix}, face shifted left/right")
            return "3041"  # face shifted left/right, partially not captured
        if facebox[1] < 0 or facebox[3] >= H:
            logger.info(f"{img_path}: when {prefix}, face shifted top/bottom")
            print(f"{img_path}: when {prefix}, face shifted top/bottom")
            return "3042"  # face shifted top/bottom, partially not captured
        face_img = img[max(0, int(facebox[1])):int(facebox[3]), max(0, int(facebox[0])):int(facebox[2])]
        if not self.qc.check_bright(face_img):
            logger.info(f"{img_path}: when {prefix}, bad bright face in the image")
            print(f"{img_path}: when {prefix}, bad bright face in the image")
            return "303"
        if not self.qc.check_resolution(face_img):
            logger.info(f"{img_path}: when {prefix}, too small resolution of face in the image")
            print(f"{img_path}: when {prefix}, too small resolution of face in the image")
            return "306"
        pose = self.qpose.check(face_img)
        if pose != "frontFace":
            logger.info(f"{img_path}: when {prefix}, {pose} in the image")
            print(f"{img_path}: when {prefix}, {pose} in the image")
            pose_codes = {"rightFace": "3051", "leftFace": "3052", "upFace": "3053", "downFace": "3054"}
            return pose_codes[pose]
        if not self.qclarity.check(face_img):
            logger.info(f"{img_path}: when {prefix}, bad clarity of face in the image")
            print(f"{img_path}: when {prefix}, bad clarity of face in the image")
            return "307"
        return "300"