修改代码与文档

This commit is contained in:
jiang 2024-09-12 17:14:01 +08:00
parent 7c63e11132
commit 3a1c18c29d
5 changed files with 424 additions and 233 deletions

View File

@ -2,7 +2,7 @@ import redis
import os
# Redis config
redis_host = str(os.getenv("REDIS_HOST", 'localhost'))
redis_host = str(os.getenv("REDIS_HOST", '192.168.0.14'))
redis_port = int(os.getenv("REDIS_PORT", 2012))
redis_password = str(os.getenv("REDIS_PASSWORD", 'Xjsfzb@Redis123!'))
num_workers = int(os.getenv("NUM_WORKERS", 10))

View File

@ -2,52 +2,51 @@
# modified on 2024/6/12
# description: tool file (.py)
import cv2
import os
import os
import hashlib
import pickle
import requests
import base64
import logging
import numpy as np
import logging
import numpy as np
from datetime import datetime
import redis
from redisLock import RedisReadWriteLock
import onnxruntime
import time
import onnxruntime
import time
import yaml
from models import FaceRecoger, FaceBoxesV2, Landmark5er, FaceAlign, QualityOfClarity, QualityOfPose, QualityChecker
so = onnxruntime.SessionOptions()
so.log_severity_level = 3 # 0=VERBOSE, 1=INFO, 2=WARNING, 3=ERROR, 4=FATAL
# 获取workers
if "NUM_WORKERS" not in os.environ:
raise RuntimeError("Environment variable NUM_WORKERS is required but not set.")
NUM_WORKERS = int(os.getenv("NUM_WORKERS", 10))
# # 获取workers
# if "NUM_WORKERS" not in os.environ:
# raise RuntimeError("Environment variable NUM_WORKERS is required but not set.")
NUM_WORKERS = int(os.getenv("NUM_WORKERS", 10))
# max readers
max_readers = int(os.getenv("MAX_READERS",60))
max_readers = int(os.getenv("MAX_READERS", 60))
# 连接到 Redis
redis_host = str(os.getenv("REDIS_HOST", 'localhost'))
redis_host = str(os.getenv("REDIS_HOST", '192.168.0.14'))
redis_port = int(os.getenv("REDIS_PORT", 2012))
redis_password = str(os.getenv("REDIS_PASSWORD", 'Xjsfzb@Redis123!'))
redis_password = str(os.getenv("REDIS_PASSWORD", 'Xjsfzb@Redis123!'))
# connected
redis_client = redis.Redis(host=redis_host, port=redis_port, password=redis_password, db=0)
PID_id = None
PID_id = None
NUM_WORKERS = int(os.getenv("NUM_WORKERS", 10))
for i in range(NUM_WORKERS):
if redis_client.setnx(f"worker_{i}", 0): # 设置为dirty
PID_id = i
if redis_client.setnx(f"worker_{i}", 0): # 设置为dirty
PID_id = i
break
# create ReadWriteLock
rw_lock = RedisReadWriteLock(redis_client, max_readers=max_readers)
ErrorMsg = {
"30003": "no face found in the database",
"30004": "invalid file path",
"30005": "invalid file suffix",
@ -69,39 +68,45 @@ ErrorMsg = {
"30015": "face rotated downwards in the image",
"30016": "low resolution of a face in the image",
"30017": "poor clarity of a face in the image",
"30019": "identity already exists; for database protection, operation rejected at this time"
}
class FileError(Exception):
def __init__(self, arg:str):
def __init__(self, arg: str):
self.code = arg
self.args = [f"{str(self.__class__.__name__)} {str(arg)}: {ErrorMsg[arg]}"]
class NotImpltError(Exception):
def __init__(self, arg:str):
def __init__(self, arg: str):
self.code = arg
self.args = [f"{str(self.__class__.__name__)} {str(arg)}: {ErrorMsg[arg]}"]
class FaceError(Exception):
def __init__(self, arg:str):
def __init__(self, arg: str):
self.code = arg
self.args = [f"{str(self.__class__.__name__)} {str(arg)}: {ErrorMsg[arg]}"]
class UpdatedbError(Exception):
def __init__(self, arg:str):
def __init__(self, arg: str):
self.code = arg
self.args = [f"{str(self.__class__.__name__)} {str(arg)}: {ErrorMsg[arg]}"]
# setting Logger
# setting Logger
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_dir = f"{os.path.dirname(os.path.abspath(__file__))}/log"
os.makedirs(log_dir, exist_ok=True)
logging.basicConfig(filename=f'{log_dir}/{current_time}.log', level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logging.basicConfig(filename=f'{log_dir}/{current_time}.log', level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__) # @@@@
print(log_dir)
def list_images(path: str):
"""
List images in a given path
@ -150,27 +155,28 @@ def find_image_hash(file_path: str) -> str:
hasher.update(properties.encode("utf-8"))
return hasher.hexdigest()
# 支持base64 local-path url 等多种检索图片的方式,返回 numpy
def load_img(img_path:str):
image = None
try:
if img_path.startswith(("http","www")): # url
def load_img(img_path: str):
image = None
try:
if img_path.startswith(("http", "www")): # url
response = requests.get(url=img_path, stream=True, timeout=60, proxies={"http": None, "https": None})
response.raise_for_status()
image_array = np.asarray(bytearray(response.raw.read()), dtype=np.uint8)
image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
elif img_path.startswith(("./","/","C:","D:","E:",".\\")) or os.path.isfile(img_path): # local-path
elif img_path.startswith(("./", "/", "C:", "D:", "E:", ".\\")) or os.path.isfile(img_path): # local-path
if not os.path.isfile(img_path):
raise FileError("30004") # push: invalid file path
elif not img_path.lower().endswith((".jpg",'.jpeg','.png')):
raise FileError("30005") # push: invaild file suffix
elif not img_path.lower().endswith((".jpg", '.jpeg', '.png')):
raise FileError("30005") # push: invaild file suffix
else:
image = cv2.imread(img_path)
elif img_path.startswith("data:") and "base64" in img_path: # base64
elif img_path.startswith("data:") and "base64" in img_path: # base64
encoded_data_parts = img_path.split(",")
if len(encoded_data_parts) <= 0:
raise FileError("104") # push: base64 is empty
print( "base64 is empty" )
print("base64 is empty")
encoded_data = encoded_data_parts[-1]
nparr = np.fromstring(base64.b64decode(encoded_data), np.uint8)
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
@ -178,13 +184,13 @@ def load_img(img_path:str):
raise NotImpltError("30006") # push: input file type is not support, only support: base64, url, local-path
except Exception as e:
logger.info(f"{e}")
raise e
raise e
# return e
else:
else:
return image
def encoder_img2base64(img:np.ndarray):
def encoder_img2base64(img: np.ndarray):
success, encoded_img = cv2.imencode('.png', img)
if success:
img_base64 = base64.b64encode(encoded_img).decode("utf-8")
@ -195,41 +201,47 @@ def encoder_img2base64(img:np.ndarray):
# from seetaface.api import *
class FaceHelper:
def __init__(self, db_dir, config_path = './config.yaml'):
def __init__(self, db_dir, config_path='./config.yaml'):
self.db_dir = os.path.abspath(db_dir)
self.pid = PID_id
self.db_embeddings = None
self.db_identities = None
self.db_embeddings = None
self.db_identities = None
# 根据config_path 读取ymal配置文件然后进行初始化
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
self.sim_threshold = config['faceReg']['sim_threshold'] # 0.7
self.rotclsifer = onnxruntime.InferenceSession( config['ck_paths']['rotifer'], so) # "./checkpoints/model_gray_mobilenetv2_rotcls.onnx"
self.db_path = os.path.join( db_dir, "seetaface6.pkl" ).lower()
self.sim_threshold = config['faceReg']['sim_threshold'] # 0.7
self.rotclsifer = onnxruntime.InferenceSession(config['ck_paths']['rotifer'],
so) # "./checkpoints/model_gray_mobilenetv2_rotcls.onnx"
self.db_path = os.path.join(db_dir, "seetaface6.pkl").lower()
self.fd = FaceBoxesV2(config['ck_paths']['FcBx'], config['ck_paths']['num_threads'] ) # r"./checkpoints/faceboxesv2-640x640.onnx" 4
self.ld5er = Landmark5er( onnx_path1 = config['ck_paths']['landmk1'], # "./checkpoints/face_landmarker_pts5_net1.onnx",
onnx_path2 = config['ck_paths']['landmk2'], # "./checkpoints/face_landmarker_pts5_net2.onnx",
num_threads=config['ck_paths']['num_threads'] # 4
)
self.fd = FaceBoxesV2(config['ck_paths']['FcBx'],
config['ck_paths']['num_threads']) # r"./checkpoints/faceboxesv2-640x640.onnx" 4
self.ld5er = Landmark5er(onnx_path1=config['ck_paths']['landmk1'],
# "./checkpoints/face_landmarker_pts5_net1.onnx",
onnx_path2=config['ck_paths']['landmk2'],
# "./checkpoints/face_landmarker_pts5_net2.onnx",
num_threads=config['ck_paths']['num_threads'] # 4
)
self.fa = FaceAlign()
self.fr = FaceRecoger(onnx_path = config['ck_paths']['FcReg'], num_threads= config['ck_paths']['num_threads'] ) # "./checkpoints/face_recognizer.onnx" 4
self.qc = QualityChecker(config['brightness']['v0'], config['brightness']['v1'],
self.fr = FaceRecoger(onnx_path=config['ck_paths']['FcReg'],
num_threads=config['ck_paths']['num_threads']) # "./checkpoints/face_recognizer.onnx" 4
self.qc = QualityChecker(config['brightness']['v0'], config['brightness']['v1'],
config['brightness']['v2'], config['brightness']['v3'],
hw = (config['resolution']['height'], config['resolution']['width'])
) # v0=70.0, v1=100.0, v2=210.0, v3=230.0
self.qpose = QualityOfPose(yaw_thrd=config['pose']['yaw_thrd'], pitch_thrd=config['pose']['pitch_thrd'],
var_onnx_path = config['pose']['var_onnx_path'], # './checkpoints/fsanet-var.onnx',
conv_onnx_path = config['pose']['conv_onnx_path'], # './checkpoints/fsanet-conv.onnx'
hw=(config['resolution']['height'], config['resolution']['width'])
) # v0=70.0, v1=100.0, v2=210.0, v3=230.0
self.qpose = QualityOfPose(yaw_thrd=config['pose']['yaw_thrd'], pitch_thrd=config['pose']['pitch_thrd'],
var_onnx_path=config['pose']['var_onnx_path'], # './checkpoints/fsanet-var.onnx',
conv_onnx_path=config['pose']['conv_onnx_path'], # './checkpoints/fsanet-conv.onnx'
)
self.qclarity = QualityOfClarity(low_thresh=config['clarity']['low_thrd'], high_thresh=config['clarity']['high_thrd'])
self.qclarity = QualityOfClarity(low_thresh=config['clarity']['low_thrd'],
high_thresh=config['clarity']['high_thrd'])
# refresh the db
try:
try:
self.updateDB(None, None, None, Onlyrefresh=True)
except Exception as e:
except Exception as e:
# raise e
pass
@ -237,49 +249,49 @@ class FaceHelper:
logger.info(f"db_dir: {self.db_dir} ; PID: {self.pid}")
# 读操作
def faceRecognition(self, img_path:str):
def faceRecognition(self, img_path: str):
rw_lock.acquire_read()
if int(redis_client.get(f"worker_{self.pid}")) == 0: # 说明self中的db和磁盘中的db不同步
with open(self.db_path, "rb") as f:
representations = pickle.load(f)
if representations == []:
self.db_embeddings, self.db_identities = None, None
if representations == []:
self.db_embeddings, self.db_identities = None, None
else:
self.db_embeddings = np.array([rep["embedding"] for rep in representations], dtype=np.float32)
self.db_embeddings = np.array([rep["embedding"] for rep in representations], dtype=np.float32)
self.db_identities = [os.path.splitext(os.path.basename(rep["identity"]))[0] for rep in representations]
redis_client.set(f"worker_{self.pid}", 1) # 同步完毕
redis_client.set(f"worker_{self.pid}", 1) # 同步完毕
try:
if self.db_embeddings is None:
try:
if self.db_embeddings is None:
raise FileError("30003") # push: no face in the database
image = load_img(img_path) # get bgr numpy image
image = load_img(img_path) # get bgr numpy image
start = time.time()
unknown_embeddings, cropped_images, names = [], [], []
image = self.rotadjust(image) # 调整角度
image = self.rotadjust(image) # 调整角度
detect_result = self.fd.detect(image)
detect_result = [(box.x1, box.y1, box.x2, box.y2) for box in detect_result]
if len(detect_result) == 0:
logger.info(f"{img_path[:200]}: no face in the image")
print(f"{img_path[:200]}: no face in the image")
raise FaceError("30018") # push: no face in the image
raise FaceError("30018") # push: no face in the image
for facebox in detect_result:
landmarks5 = self.ld5er.inference(image, facebox) # return: [(),(),(),(),()] 左眼 右眼 鼻子 左嘴角 右嘴角
landmarks5 = self.ld5er.inference(image, facebox) # return: [(),(),(),(),()] 左眼 右眼 鼻子 左嘴角 右嘴角
# print("5点关键点",landmarks5)
# 输入image 和5点特征点位置(基于原图image的位置) , return all cropped aligned face (裁剪后的对齐后的人脸部分图像, 简写为aligned_faces
landmarks5 = [ [ld5[0],ld5[1]] for ld5 in landmarks5]
landmarks5 = [[ld5[0], ld5[1]] for ld5 in landmarks5]
cropped_face = self.fa.align(image, landmarks5=landmarks5)
# 输入aligned_faces return all features of aligned_faces
feature = self.fr.inference(cropped_face)
cropped_images.append(cropped_face)
unknown_embeddings.append(feature)
unknown_embeddings = np.vstack(unknown_embeddings)
results = np.dot(unknown_embeddings, self.db_embeddings.T)
max_values = np.max(results,axis=1)
max_idxs = np.argmax(results,axis=1)
max_values = np.max(results, axis=1)
max_idxs = np.argmax(results, axis=1)
for i, (idx, value) in enumerate(zip(max_idxs, max_values)):
name = "unknown"
@ -290,19 +302,19 @@ class FaceHelper:
ret_data = []
for i, (facebox, name) in enumerate(zip(detect_result, names)):
if name != 'unknown':
ret_data.append({'code':"30000", 'msg': ErrorMsg["30000"], 'data':[name,facebox]})
ret_data.append({'code': "30000", 'msg': ErrorMsg["30000"], 'data': [name, facebox]})
else:
code = self.check_face("None", image, facebox, prefix='facereg')
if code == "30002":
ret_data.append({'code':"30001", 'msg': ErrorMsg["30001"], 'data':[name,facebox]})
ret_data.append({'code': "30001", 'msg': ErrorMsg["30001"], 'data': [name, facebox]})
else:
ret_data.append({'code':code, 'msg': ErrorMsg[code], 'data':[name,facebox]})
ret_data.append({'code': code, 'msg': ErrorMsg[code], 'data': [name, facebox]})
if len(ret_data) != 1:
ret_data = {'code':"30008", 'msg': ErrorMsg["30008"], 'data': ret_data}
ret_data = {'code': "30008", 'msg': ErrorMsg["30008"], 'data': ret_data}
else:
ret_data = ret_data[0]
print("facereg runtime:", time.time() - start)
print("facereg runtime:", time.time() - start)
except Exception as e:
logger.info(f"{e}")
@ -314,25 +326,23 @@ class FaceHelper:
# return names, [ encoder_img2base64(det) for det in cropped_images]
# 只是提取人脸特征,不要对别的有任何影响
def featureDetect(self, img_path:str):
def featureDetect(self, img_path: str):
rw_lock.acquire_read()
try:
image = load_img(img_path) # get bgr numpy image
try:
image = load_img(img_path) # get bgr numpy image
unknown_embeddings = []
image = self.rotadjust(image) # 调整角度
image = self.rotadjust(image) # 调整角度
detect_result = self.fd.detect(image)
detect_result = [(box.x1, box.y1, box.x2, box.y2) for box in detect_result]
if len(detect_result) == 0:
logger.info(f"{img_path[:200]}: no face in the image")
print(f"{img_path[:200]}: no face in the image")
raise FaceError("30018") # push: no face in the image
raise FaceError("30018") # push: no face in the image
for facebox in detect_result:
landmarks5 = self.ld5er.inference(image, facebox) # return: [(),(),(),(),()] 左眼 右眼 鼻子 左嘴角 右嘴角
landmarks5 = self.ld5er.inference(image, facebox) # return: [(),(),(),(),()] 左眼 右眼 鼻子 左嘴角 右嘴角
# print("5点关键点",landmarks5)
# 输入image 和5点特征点位置(基于原图image的位置) , return all cropped aligned face (裁剪后的对齐后的人脸部分图像, 简写为aligned_faces
landmarks5 = [ [ld5[0],ld5[1]] for ld5 in landmarks5]
landmarks5 = [[ld5[0], ld5[1]] for ld5 in landmarks5]
cropped_face = self.fa.align(image, landmarks5=landmarks5)
# 输入aligned_faces return all features of aligned_faces
feature = self.fr.inference(cropped_face)
@ -341,10 +351,10 @@ class FaceHelper:
ret_data = []
for _, (facebox, feature) in enumerate(zip(detect_result, unknown_embeddings)):
code = self.check_face("None", image, facebox, prefix='featuredetect')
ret_data.append({'code':code, 'msg': ErrorMsg[code], 'data': [f"{feature.tobytes()},{str(feature.dtype)}", facebox]})
ret_data.append({'code': code, 'msg': ErrorMsg[code], 'data': [{str(num) for num in feature}, facebox]})
if len(ret_data) != 1:
ret_data = {'code':"30008", 'msg': ErrorMsg["30008"], 'data': ret_data}
ret_data = {'code': "30008", 'msg': ErrorMsg["30008"], 'data': ret_data}
else:
ret_data = ret_data[0]
@ -358,73 +368,77 @@ class FaceHelper:
# opt in ['add','delete','replace'] identity作为检索的标识符img_path只是提供文件路径
# 写操作
def updateDB(self, img_path :str, opt :str, identity :str, Onlyrefresh=False):
def updateDB(self, img_path: str, opt: str, identity: str, Onlyrefresh=False):
global rw_lock
rw_lock.acquire_write() # 写锁定
rw_lock.acquire_write() # 写锁定
print("come in the updatedb")
try:
try:
if not Onlyrefresh:
if int(redis_client.get(f"worker_{self.pid}")) == 0: # 说明self中的db和磁盘中的db不同步
with open(self.db_path, "rb") as f:
representations = pickle.load(f)
if representations == []:
self.db_embeddings, self.db_identities = None, None
if representations == []:
self.db_embeddings, self.db_identities = None, None
else:
self.db_embeddings = np.array([rep["embedding"] for rep in representations], dtype=np.float32)
self.db_identities = [os.path.splitext(os.path.basename(rep["identity"]))[0] for rep in representations]
redis_client.set(f"worker_{self.pid}", 1) # 同步完毕
self.db_embeddings = np.array([rep["embedding"] for rep in representations], dtype=np.float32)
self.db_identities = [os.path.splitext(os.path.basename(rep["identity"]))[0] for rep in
representations]
redis_client.set(f"worker_{self.pid}", 1) # 同步完毕
img = load_img(img_path)
img = self.rotadjust(img) # 调整角度
if opt in ["add","replace"]:
if opt == "add" and self.db_identities is not None and identity in self.db_identities:
raise UpdatedbError("30019") # push: identity has exist. to pretect the db, reject opt of this time
img = self.rotadjust(img) # 调整角度
if opt in ["add", "replace"]:
if opt == "add" and self.db_identities is not None and identity in self.db_identities:
raise UpdatedbError(
"30019") # push: identity has exist. to pretect the db, reject opt of this time
else:
detect_result = self.fd.detect(img)
if len(detect_result) == 0: # no face
if len(detect_result) == 0: # no face
logger.info(f"{img_path[:200]}: when update, no face in the image")
print(f"{img_path[:200]}: when update, no face in the image")
raise FaceError("30018") # push: no face in the image
else: # 获取最大的face,然后进行check
raise FaceError("30018") # push: no face in the image
else: # 获取最大的face,然后进行check
# H, W = img.shape[:2]
areas = [ box.area() for box in detect_result]
areas = [box.area() for box in detect_result]
max_idx = areas.index(max(areas))
facebox = detect_result[max_idx]
facebox = (facebox.x1, facebox.y1, facebox.x2, facebox.y2) # top_left point, bottom_right point
FaceError_number = self.check_face(img_path=img_path[:200], img=img, facebox=facebox, prefix='update')
facebox = (
facebox.x1, facebox.y1, facebox.x2, facebox.y2) # top_left point, bottom_right point
FaceError_number = self.check_face(img_path=img_path[:200], img=img, facebox=facebox,
prefix='update')
if FaceError_number != "30002":
raise FaceError(FaceError_number)
cv2.imwrite(os.path.join(self.db_dir, identity+'.jpg'),img,[cv2.IMWRITE_JPEG_QUALITY, 100]) # 如果file已经存在则会替换它
cv2.imwrite(os.path.join(self.db_dir, identity + '.jpg'), img,
[cv2.IMWRITE_JPEG_QUALITY, 100]) # 如果file已经存在则会替换它
elif opt == "delete":
try:
os.remove(os.path.join(self.db_dir, identity+'.jpg'))
try:
os.remove(os.path.join(self.db_dir, identity + '.jpg'))
except FileNotFoundError:
pass
pass
else:
raise NotImpltError("30007") # push: this updateDB type is not support
print("end the updateDB")
logger.info(f"end the updateDB")
self.refresh_database(check = Onlyrefresh) # 结束时刷新下db, 并通知别的进程dirty
self.refresh_database(check=Onlyrefresh) # 结束时刷新下db, 并通知别的进程dirty
except Exception as e:
logger.info(f"{e}")
rw_lock.release_write()
raise e
else:
raise e
else:
rw_lock.release_write()
return 0
def refresh_database(self, check = True):
def refresh_database(self, check=True):
# ensure db exist
os.makedirs(self.db_dir, exist_ok=True)
if not os.path.exists(self.db_path):
with open(self.db_path, "wb") as f:
pickle.dump([], f)
representations = [] # representations 最后要储存在db中
representations = [] # representations 最后要储存在db中
# Load the representations from the pickle file
with open(self.db_path, "rb") as f:
representations = pickle.load(f)
@ -436,22 +450,24 @@ class FaceHelper:
storage_images = list_images(self.db_dir)
# transform all images in storage_images to `.jpg`
for idx in range(len(storage_images)):
for idx in range(len(storage_images)):
img_path = storage_images[idx]
base_path, ext = os.path.splitext(img_path)
if ext == '.jpg':
continue
continue
iimg = cv2.imread(img_path)
cv2.imwrite(base_path+'.jpg', iimg, [cv2.IMWRITE_JPEG_QUALITY, 100])
storage_images[idx] = base_path+'.jpg'
cv2.imwrite(base_path + '.jpg', iimg, [cv2.IMWRITE_JPEG_QUALITY, 100])
storage_images[idx] = base_path + '.jpg'
must_save_pickle = False
new_images = []; old_images = []; replaced_images = []
new_images = [];
old_images = [];
replaced_images = []
new_images = list(set(storage_images) - set(pickle_images))
new_images = list(set(storage_images) - set(pickle_images))
old_images = list(set(pickle_images) - set(storage_images))
for current_representation in representations: # 找到被替换的images
for current_representation in representations: # 找到被替换的images
identity = current_representation["identity"]
if identity in old_images:
continue
@ -459,8 +475,8 @@ class FaceHelper:
beta_hash = find_image_hash(identity)
if alpha_hash != beta_hash:
# logger.debug(f"Even though {identity} represented before, it's replaced later.")
replaced_images.append(identity)
replaced_images.append(identity)
new_images = new_images + replaced_images
old_images = old_images + replaced_images
@ -481,37 +497,38 @@ class FaceHelper:
logger.info(f"{new_image}: when refresh, no face in the image, delete")
print(f"{new_image}: when refresh, no face in the image, delete")
else:
if len(detect_result) > 1:
if len(detect_result) > 1:
logger.info(f"{new_image}: find too many face, get and extract the biggest face in them")
else:
else:
logger.info(f"{new_image}: find one face, perfect!")
areas = [ box.area() for box in detect_result]
areas = [box.area() for box in detect_result]
max_idx = areas.index(max(areas))
facebox = detect_result[max_idx]
facebox = (facebox.x1, facebox.y1, facebox.x2, facebox.y2) # top_left point, bottom_right point
if check:
FaceError_number = self.check_face(img_path=new_image[:200], img=image, facebox=facebox, prefix='refreshdb')
if FaceError_number != "30002":
continue
facebox = (facebox.x1, facebox.y1, facebox.x2, facebox.y2) # top_left point, bottom_right point
landmarks5 = self.ld5er.inference(image, facebox) # return: [(),(),(),(),()] 左眼 右眼 鼻子 左嘴角 右嘴角
landmarks5 = [ [ld5[0],ld5[1]] for ld5 in landmarks5]
if check:
FaceError_number = self.check_face(img_path=new_image[:200], img=image, facebox=facebox,
prefix='refreshdb')
if FaceError_number != "30002":
continue
landmarks5 = self.ld5er.inference(image, facebox) # return: [(),(),(),(),()] 左眼 右眼 鼻子 左嘴角 右嘴角
landmarks5 = [[ld5[0], ld5[1]] for ld5 in landmarks5]
cropped_face = self.fa.align(image, landmarks5=landmarks5)
feature = self.fr.inference(cropped_face)
new_representations.append({
"identity": new_image,
"identity": new_image,
"hash": find_image_hash(new_image),
"embedding": feature,
"detected_face_base64": encoder_img2base64(cropped_face),
})
representations += new_representations
must_save_pickle = True
if must_save_pickle:
must_save_pickle = True
if must_save_pickle:
print("must save the pickle")
with open(self.db_path, "wb") as f:
pickle.dump(representations, f)
@ -521,14 +538,14 @@ class FaceHelper:
# 保证db_dir 中的图片和self.db["identity"] 一致
remove_images = list(set(storage_images) - set([rep["identity"] for rep in representations]))
for remove_img in remove_images:
try:
for remove_img in remove_images:
try:
# os.remove(remove_img)
fname = os.path.basename(remove_img)
fname = os.path.basename(remove_img)
# os.rename( remove_img, os.path.join(self.db_dir, "..","images","remove",fname) )
except FileNotFoundError:
pass
pass
if int(redis_client.get(f"worker_{self.pid}")) == 0:
empty = False
if len(representations) <= 0:
@ -539,22 +556,22 @@ class FaceHelper:
self.db_embeddings = np.array([rep["embedding"] for rep in representations], dtype=np.float32)
self.db_identities = [os.path.splitext(os.path.basename(rep["identity"]))[0] for rep in representations]
redis_client.set(f"worker_{self.pid}", 1) # 当前进程已更新
if empty:
if empty:
logger.info("no face in the database")
raise FileError("30003") # push: no face in db
def rotadjust(self, img: np.ndarray):
image = img.copy()
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # 转换为灰度图像
image = cv2.resize(image, (256, 256)) # resize (256,256)
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # 转换为灰度图像
image = cv2.resize(image, (256, 256)) # resize (256,256)
# 中心裁剪到 224x224
start = (256 - 224) // 2
image = image[start:start+224, start:start+224]
image = image[start:start + 224, start:start + 224]
# 将单通道灰度图像转换为三通道
image = np.stack((image,)*3, axis=-1)
image = np.stack((image,) * 3, axis=-1)
# 转换为符合 ONNX 需要的格式
image = image.astype(np.float32) / 255.0 # 归一化
image = image.transpose(2, 0, 1) # 将图像从 HWC 格式转换为 CHW 格式
@ -562,7 +579,7 @@ class FaceHelper:
inputs = {self.rotclsifer.get_inputs()[0].name: image}
probs = self.rotclsifer.run(None, inputs)
label = np.argmax(probs[0][0]) # 推理得到的逆时针旋转角度 [0,90,180,270]
if label == 1:
img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
@ -578,15 +595,15 @@ class FaceHelper:
print("img turn right, use `cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)` to repair")
return img
def get_feature(self, img: np.ndarray):
time.sleep(0.08)
# assert img.shape[0] == img.shape[1] and img.shape[0] == 256 and img.shape[2] == 3
img = cv2.resize( img, (256,256) )
img = cv2.resize(img, (256, 256))
input_feed = {}
# crop_img = cv2.resize(img,(248,248))
# crop_img = crop_img[...,::-1]
crop_img = img[4:252, 4:252, :][...,::-1] # 注意要考虑 长或宽 < 248的情况
crop_img = img[4:252, 4:252, :][..., ::-1] # 注意要考虑 长或宽 < 248的情况
input_data = crop_img.transpose((2, 0, 1))
# resize_img = cv2.resize(img, (248, 248))
# input_data = resize_img.transpose((2, 0, 1))
@ -599,7 +616,7 @@ class FaceHelper:
# 2 normalization feature
norm = temp_result / np.linalg.norm(temp_result, axis=1)
return norm.flatten()
def check_face(self, img_path, img, facebox, prefix="update"):
H, W = img.shape[:2]
if facebox[0] < 0 or facebox[2] >= W:
@ -609,25 +626,92 @@ class FaceHelper:
if facebox[1] < 0 or facebox[3] >= H:
logger.info(f"{img_path}: when {prefix}, face shifted top/bottom")
print(f"{img_path}: when {prefix}, face shifted top/bottom")
return "30011" # face shifted top/bottom, partially not captured.
face_img = img[ max(0,int(facebox[1])):int(facebox[3]), max(0,int(facebox[0])):int(facebox[2]) ]
return "30011" # face shifted top/bottom, partially not captured.
face_img = img[max(0, int(facebox[1])):int(facebox[3]), max(0, int(facebox[0])):int(facebox[2])]
if not self.qc.check_bright(face_img):
logger.info(f"{img_path}: when {prefix}, bad bright face in the image")
print(f"{img_path}: when {prefix}, bad bright face in the image")
return "30009" # bad bright face in the image
return "30009" # bad bright face in the image
if not self.qc.check_resolution(face_img):
logger.info(f"{img_path}: when {prefix}, too small resolution of face in the image")
print(f"{img_path}: when {prefix}, too small resolution of face in the image")
return "30016" # small face in the image
return "30016" # small face in the image
pose = self.qpose.check(face_img)
if pose != "frontFace":
logger.info(f"{img_path}: when {prefix}, {pose} in the image")
print(f"{img_path}: when {prefix}, {pose} in the image")
dictt = {"rightFace": "30012", "leftFace": "30013", "upFace": "30014", "downFace": "30015"}
return dictt[pose] # pose of face in the image
return dictt[pose] # pose of face in the image
if not self.qclarity.check(face_img):
logger.info(f"{img_path}: when {prefix}, bad clarity of face in the image")
print(f"{img_path}: when {prefix}, bad clarity of face in the image")
return "30017" # poor clarity of face in the image
return "30017" # poor clarity of face in the image
return "30002"
return "30002"
def imageComparison(self, firstImage, secondImage):
rw_lock.acquire_read()
try:
# Load and adjust images
imageFirst = self.rotadjust(load_img(firstImage))
imageSecond = self.rotadjust(load_img(secondImage))
print("1")
# Detect faces
detect_result_first = self.fd.detect(imageFirst)
detect_result_second = self.fd.detect(imageSecond)
detect_result_first = [(box.x1, box.y1, box.x2, box.y2) for box in detect_result_first]
detect_result_second = [(box.x1, box.y1, box.x2, box.y2) for box in detect_result_second]
if len(detect_result_first) < 1:
raise FaceError("30018")
if len(detect_result_second) < 1:
raise FaceError("30018")
if len(detect_result_first) > 1:
raise FaceError("30008")
if len(detect_result_second) > 1:
raise FaceError("30008")
# Validate detected faces
self._validate_face_detection(detect_result_first, "firstImage")
self._validate_face_detection(detect_result_second, "secondImage")
# 提取人脸特征
features_first = self._extract_single_feature(imageFirst, detect_result_first)
features_second = self._extract_single_feature(imageSecond, detect_result_second)
# Calculate the Euclidean distance between features
distance = calculate_euclidean_distance(features_first, features_second)
# Compare distance with threshold
return distance >= 0.75
except Exception as e:
logger.error(f"Error occurred: {e}")
raise e
finally:
rw_lock.release_read()
def _validate_face_detection(self, detect_result, image_name):
if not detect_result:
logger.warning(f"No face detected in the image: {image_name}")
raise FaceError("30018") # no face in the image
def extract_features(self, image, detect_result):
return [
self._extract_single_feature(image, facebox)
for facebox in detect_result
]
def _extract_single_feature(self, image, detect_result):
for facebox in detect_result:
landmarks5 = self.ld5er.inference(image, facebox)
landmarks5 = np.array([[ld5[0], ld5[1]] for ld5 in landmarks5])
cropped_face = self.fa.align(image, landmarks5=landmarks5)
return self.fr.inference(cropped_face)
def calculate_euclidean_distance(features1, features2):
if len(features1) != len(features2):
raise ValueError("Feature arrays must have the same length")
# Calculate the Euclidean distance
total = 0
for f1, f2 in zip(features1, features2):
total += f1 * f2
return total

View File

@ -0,0 +1,36 @@
@echo off
setlocal enabledelayedexpansion
rem 从 redis.yaml 文件中读取 Redis 相关配置信息
set "REDIS_CONFIG="
for /f "delims=" %%i in (redis.yaml) do (
set "REDIS_CONFIG=!REDIS_CONFIG! %%i"
)
rem 工作进程数 workers
set "NUM_WORKERS=%3"
rem 提取 Redis 配置信息
for %%i in (!REDIS_CONFIG!) do (
for /f "tokens=1,* delims==" %%j in (%%i) do (
if "%%j"=="REDIS_HOST" set "REDIS_HOST=%%k"
if "%%j"=="REDIS_PORT" set "REDIS_PORT=%%k"
if "%%j"=="REDIS_PASSWORD" set "REDIS_PASSWORD=%%k"
)
)
echo num_workers: %NUM_WORKERS%
echo redis_host: %REDIS_HOST%; redis_port: %REDIS_PORT%
rem 运行 Python 脚本
set REDIS_HOST=%REDIS_HOST%
set REDIS_PORT=%REDIS_PORT%
set REDIS_PASSWORD=%REDIS_PASSWORD%
set NUM_WORKERS=%NUM_WORKERS%
rem 确保 Python 脚本能够正确读取环境变量
python flushredis.py
rem 启动 uvicorn
set MAX_READERS=%4%
uvicorn webmain:app --host %1 --port %2 --workers %NUM_WORKERS%

View File

@ -1,83 +1,154 @@
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from process import FaceHelper, FileError, ErrorMsg
from process import FaceHelper, FileError, ErrorMsg
import threading
import os
import time
import time
# 创建 FastAPI 实例
app = FastAPI()
# 创建线程锁,确保在多线程环境中对共享资源的访问是线程安全的
lock = threading.Lock()
# 全局变量 facehelper稍后将用于管理人脸识别相关操作
facehelper = None
facehelper = FaceHelper(
db_dir="./dbface",
)
class face(BaseModel):
img:str
def get_facehelper():
"""
懒加载模式初始化 facehelper 对象
只有在第一次调用时才会创建 FaceHelper 实例
使用线程锁以确保 facehelper 实例在多线程环境下的安全初始化
"""
global facehelper
if facehelper is None:
with lock:
if facehelper is None:
facehelper = FaceHelper(db_dir="./dbface")
return facehelper
# 定义请求模型,用于接收前端传递的图像数据
class FaceRequest(BaseModel):
img: str
# 定义请求模型,用于接收前端传递的数据库操作数据
class DBFaceRequest(BaseModel):
img: str
optMode: str
uniqueKey: str
# 接两图片对比的参数
class ImageComparison(BaseModel):
firstImage: str
secondImage: str
class dbface(BaseModel):
img:str
optMode:str
uniqueKey:str
@app.post("/refreshdb")
@app.post("/refreshdb/")
def refresh():
global facehelper
try:
"""
刷新人脸数据库的接口
该接口将调用 facehelper updateDB 方法刷新数据库内容
"""
facehelper = get_facehelper()
try:
# 加锁确保数据库更新操作的线程安全性
with lock:
facehelper.updateDB(None, None, None, Onlyrefresh=True)
except FileError as e:
# return {"status":e.code, "detail":f"{e}"}
return {'code': e.code, 'msg': f"{e}", 'data': 'None'}
else:
return {'code': "30002", 'msg': ErrorMsg['30002'], 'data': 'None'}
except FileError as e:
# 处理文件错误,返回相应的错误代码和信息
return {'code': e.code, 'msg': str(e), 'data': None}
else:
# 成功刷新数据库后返回相应的状态码和消息
return {'code': "30002", 'msg': ErrorMsg['30002'], 'data': None}
@app.post("/facerecognition")
@app.post("/facerecognition/")
def faceRecognition(input:face):
start = time.time()
global facehelper
try:
def faceRecognition(input: FaceRequest):
"""
进行人脸识别的接口
接收前端传递的图像数据并使用 facehelper 进行人脸识别
返回识别结果和运行时间
"""
facehelper = get_facehelper()
start = time.time() # 记录开始时间
try:
# 调用 facehelper 的 faceRecognition 方法进行人脸识别
ret_data = facehelper.faceRecognition(input.img)
print("finished recognition...")
end = time.time()
print("runtime: ", end-start)
except Exception as e:
return {'code': e.code, 'msg': f"{e}", 'data': 'None'}
# return {"status":f"{e.code}", "detail":f"{e}"}
else:
return ret_data
return {"status":1, "name":identity, "resImg":res_img_base64}
# 处理异常,返回相应的错误代码和信息
raise HTTPException(status_code=400, detail={'code': e.code, 'msg': str(e), 'data': None})
finally:
end = time.time() # 记录结束时间
# 打印运行时间
print(f"Recognition finished. Runtime: {end - start:.2f} seconds")
# 返回识别结果
return ret_data
@app.post("/featuredetect")
@app.post("/featuredetect/")
def featureDetect(input:face):
start = time.time()
global facehelper
try:
def featureDetect(input: FaceRequest):
"""
特征检测接口
接收前端传递的图像数据并使用 facehelper 进行特征检测
返回检测结果和运行时间
"""
facehelper = get_facehelper()
start = time.time() # 记录开始时间
try:
# 调用 facehelper 的 featureDetect 方法进行特征检测
ret_data = facehelper.featureDetect(input.img)
print("finished featuredetect...")
end = time.time()
print("runtime: ", end-start)
except Exception as e:
return {'code': e.code, 'msg': f"{e}", 'data': 'None'}
# return {"status":f"{e.code}", "detail":f"{e}"}
else:
return ret_data
# 处理异常,返回相应的错误代码和信息
raise HTTPException(status_code=400, detail={'code': e.code, 'msg': str(e), 'data': None})
finally:
end = time.time() # 记录结束时间
# 打印运行时间
print(f"Feature detection finished. Runtime: {end - start:.2f} seconds")
# 返回检测结果
return ret_data
@app.post("/updatedb")
@app.post("/updatedb/")
def updateDB(input:dbface):
global facehelper
# input.uniqueKey = os.path.splitext(os.path.basename(input.uniqueKey))[0]
try:
def updateDB(input: DBFaceRequest):
"""
更新人脸数据库的接口
接收前端传递的数据库操作数据并使用 facehelper 更新数据库
"""
facehelper = get_facehelper()
try:
# 加锁确保数据库更新操作的线程安全性
with lock:
facehelper.updateDB(input.img, input.optMode, input.uniqueKey)
except Exception as e:
return {'code': e.code, 'msg': f"{e}", 'data': 'None'}
# return {"status":f"{e.code}", "detail":f"{e}"}
else:
return {'code': "30002", 'msg': ErrorMsg['30002'], 'data': 'None'}
# 处理异常,返回相应的错误代码和信息
raise HTTPException(status_code=400, detail={'code': e.code, 'msg': str(e), 'data': None})
else:
# 成功更新数据库后返回相应的状态码和消息
return {'code': "30002", 'msg': ErrorMsg['30002'], 'data': None}
@app.post("/imageComparison")
def imageComparison(input: ImageComparison):
"""
更新人脸数据库的接口
接收前端传递的数据库操作数据并使用 facehelper 更新数据库
"""
facehelper = get_facehelper()
try:
# 加锁确保数据库更新操作的线程安全性
with lock:
state = facehelper.imageComparison(input.firstImage, input.secondImage)
except Exception as e:
# 处理异常,返回相应的错误代码和信息
raise HTTPException(status_code=400, detail={'code': e.code, 'msg': str(e), 'data': None})
else:
# 成功更新数据库后返回相应的状态码和消息
return {'code': "30002", 'msg': ErrorMsg['30002'], 'data': bool(state)}