632 lines
19 KiB
Python
632 lines
19 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
人脸特征提取API服务 V2.0
|
|
提供完整的用户管理和识别功能
|
|
"""
|
|
|
|
from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Depends, Query
|
|
from fastapi.responses import JSONResponse, FileResponse, HTMLResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel, Field
|
|
from typing import List, Optional
|
|
from sqlalchemy.orm import Session
|
|
import cv2
|
|
import numpy as np
|
|
import base64
|
|
import logging
|
|
from io import BytesIO
|
|
import uvicorn
|
|
import os
|
|
from datetime import datetime
|
|
import shutil
|
|
from pathlib import Path
|
|
|
|
from face_feature_extractor import FaceFeatureExtractor
|
|
from database import User, RecognitionLog, get_db, init_database
|
|
|
|
# Configure root logging once at import time: INFO level, timestamped records.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger shared by every endpoint and helper in this file.
logger = logging.getLogger(__name__)
|
|
|
|
# The ASGI application; title/description/version feed the generated OpenAPI docs.
app = FastAPI(
    version="2.0.0",
    title="人脸识别管理系统",
    description="提供人脸特征提取、用户管理、识别功能",
)
|
|
|
|
# Accept cross-origin requests from any origin, with any method and header,
# and with credentials.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; consider an explicit origin list before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
|
|
# On-disk storage locations (relative to the process's working directory).
UPLOAD_DIR = Path("uploads")
PHOTOS_DIR = UPLOAD_DIR.joinpath("photos")  # registered users' face photos
LOGS_DIR = UPLOAD_DIR.joinpath("logs")

# Create the folders up front; exist_ok makes this a no-op on later starts.
for dir_path in (UPLOAD_DIR, PHOTOS_DIR, LOGS_DIR):
    dir_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Process-wide FaceFeatureExtractor; stays None until get_extractor() first runs.
extractor = None


def get_extractor():
    """Return the shared FaceFeatureExtractor, constructing it on first use.

    Construction is deferred until the first call, so importing this module
    does not pay the extractor's start-up cost. No locking is performed.
    NOTE(review): this assumes callers do not race on the first call (all
    callers in this file are async endpoints on the event loop) — confirm if
    sync callers are added.
    """
    global extractor
    if extractor is not None:
        return extractor

    logger.info("初始化人脸特征提取器...")
    extractor = FaceFeatureExtractor()
    logger.info("人脸特征提取器初始化完成")
    return extractor
|
|
|
|
|
|
# ==================== 请求和响应模型 ====================
|
|
|
|
class ExtractFeatureResponse(BaseModel):
    """Response model for feature extraction."""

    success: bool  # True only when a qualifying face was found
    message: str  # human-readable status or error text
    feature: Optional[List[float]] = None  # extracted feature vector; set only on success
    feature_dim: Optional[int] = None  # len(feature); set only on success
    quality_passed: Optional[bool] = None  # overall face-quality verdict; set only on success
    processing_time: Optional[float] = None  # as reported by the extractor (units not shown here)
|
|
|
|
|
|
class UserRegisterRequest(BaseModel):
    """User registration request (JSON body).

    NOTE(review): /api/users/register in this file takes multipart form fields
    rather than this model, so it appears unused here — confirm before relying on it.
    """

    name: str = Field(..., description="用户姓名")  # required user name
    age: Optional[int] = Field(None, description="年龄")  # optional age
|
|
|
|
|
|
class UserRegisterResponse(BaseModel):
    """Response model for user registration."""

    success: bool
    message: str
    user_id: Optional[int] = None  # id of the created or reactivated user; None on failure
    user: Optional[dict] = None  # full user record, including feature_vector and other details
|
|
|
|
|
|
class UserListResponse(BaseModel):
    """Response model for the paginated user list."""

    success: bool
    message: str
    total: int  # number of users matching the filter, counted before skip/limit
    users: List[dict]  # current page of User.to_dict() records (presumably without feature vectors)
|
|
|
|
|
|
class RecognizeResponse(BaseModel):
    """Response model for face recognition."""

    success: bool  # False when no usable face was found in the uploaded image
    message: str
    recognized: bool = False  # True when the best match reached the threshold
    user_id: Optional[int] = None  # matched user's id; set only when recognized
    user_name: Optional[str] = None  # matched user's name; set only when recognized
    similarity: Optional[float] = None  # best similarity score found
    threshold: float = 0.7  # threshold that was applied; same default as the endpoint's form field
    processing_time: Optional[float] = None  # as reported by the extractor
|
|
|
|
|
|
class SyncUsersResponse(BaseModel):
    """Response model for user synchronization."""

    success: bool
    message: str
    total: int  # number of entries in `users`
    users: List[dict]  # complete user records, including feature vectors
|
|
|
|
|
|
# ==================== 工具函数 ====================
|
|
|
|
def decode_image_from_upload(file: UploadFile) -> Optional[np.ndarray]:
    """Decode an uploaded image into an OpenCV color array.

    Reads the rest of ``file``'s byte stream and hands it to ``cv2.imdecode``.
    The stream is left at its end, so callers that need the bytes again must
    ``seek(0)``.

    Returns:
        The decoded image, or None when the bytes cannot be decoded (an
        exception during decoding is logged and also yields None).
    """
    try:
        raw_bytes = np.frombuffer(file.file.read(), dtype=np.uint8)
        return cv2.imdecode(raw_bytes, cv2.IMREAD_COLOR)
    except Exception as exc:
        logger.error(f"图像解码失败: {exc}")
        return None
|
|
|
|
|
|
def save_uploaded_photo(file: UploadFile, user_name: str) -> Optional[str]:
    """Save an uploaded photo under PHOTOS_DIR and return its file name.

    The file is named ``<safe_name>_<YYYYmmdd_HHMMSS_ffffff>.jpg``.

    ``user_name`` comes straight from the client, so every character that is
    not alphanumeric (Unicode-aware, so Chinese names are kept), ``-`` or
    ``_`` is replaced with ``_``. This prevents path separators or ``..`` from
    writing outside PHOTOS_DIR. Microseconds are part of the timestamp so two
    uploads whose sanitized names coincide within the same second do not
    overwrite each other.

    The upload is copied from the stream's current position, so callers that
    have already read it must ``seek(0)`` first.

    Args:
        file: The uploaded file whose byte stream is copied to disk.
        user_name: Display name used as the file-name prefix (sanitized here).

    Returns:
        The bare file name (relative to PHOTOS_DIR), or None if saving failed
        (the error is logged).
    """
    try:
        # Keep only characters that cannot form a path component separator.
        safe_name = "".join(
            ch if ch.isalnum() or ch in "-_" else "_" for ch in user_name
        ) or "user"
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"{safe_name}_{timestamp}.jpg"
        filepath = PHOTOS_DIR / filename

        with open(filepath, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        logger.info(f"照片已保存: {filepath}")
        return filename
    except Exception as e:
        logger.error(f"保存照片失败: {e}")
        return None
|
|
|
|
|
|
def calculate_similarity(feature1: np.ndarray, feature2: np.ndarray) -> float:
    """Compute the cosine similarity of two feature vectors.

    Both inputs are flattened and explicitly normalized by their L2 norms, so
    the result is a true cosine similarity even if a stored feature was not
    normalized beforehand. For vectors that are already unit length this
    equals their plain dot product.

    Args:
        feature1: First feature vector (any array-like of numbers).
        feature2: Second feature vector with the same number of elements.

    Returns:
        The cosine similarity as a Python float clipped to [-1.0, 1.0], or
        0.0 when either vector has zero norm (the angle is undefined there).

    Raises:
        ValueError: If the two vectors have different numbers of elements.
    """
    a = np.asarray(feature1, dtype=np.float64).ravel()
    b = np.asarray(feature2, dtype=np.float64).ravel()
    if a.shape != b.shape:
        raise ValueError(f"feature dimension mismatch: {a.size} vs {b.size}")

    denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    if denom == 0.0:
        return 0.0
    # Clip so rounding error cannot push the score slightly outside [-1, 1].
    return float(np.clip(np.dot(a, b) / denom, -1.0, 1.0))
|
|
|
|
|
|
# ==================== API接口 ====================
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Run init_database() once when the application starts, then log completion."""
    init_database()
    logger.info("✅ 数据库初始化完成")
|
|
|
|
|
|
@app.get("/")
async def root():
    """Describe the service: its name, version and a map of the main endpoints."""
    endpoint_map = {
        "特征提取": "/api/extract_feature",
        "用户注册": "/api/users/register",
        "用户列表": "/api/users/list",
        "用户详情": "/api/users/{user_id}",
        "删除用户": "/api/users/{user_id}",
        "人脸识别": "/api/recognize",
        "同步用户": "/api/users/sync",
        "管理后台": "/admin",
        "API文档": "/docs",
    }
    return {
        "message": "人脸识别管理系统 API",
        "version": "2.0.0",
        "endpoints": endpoint_map,
    }
|
|
|
|
|
|
@app.get("/health")
async def health_check(db: Session = Depends(get_db)):
    """Report service health.

    Obtains the feature extractor (constructing it on first use) and counts
    active users. Any exception from either step produces a 503 response
    whose body carries the error text.
    """
    try:
        extractor_ready = get_extractor() is not None
        active_users = db.query(User).filter(User.is_active == True).count()  # noqa: E712
    except Exception as exc:
        return JSONResponse(
            status_code=503,
            content={"status": "unhealthy", "error": str(exc)},
        )

    return {
        "status": "healthy",
        "extractor_loaded": extractor_ready,
        "database_connected": True,
        "total_users": active_users,
    }
|
|
|
|
|
|
# ==================== 原有接口(保持兼容) ====================
|
|
|
|
@app.post("/api/extract_feature", response_model=ExtractFeatureResponse)
async def extract_feature(image: UploadFile = File(...)):
    """
    Extract a face feature vector from one uploaded image.

    Only the first qualifying face is used (quality filtering enabled).
    Responds 400 when the upload cannot be decoded and 500 on unexpected
    errors; "no qualifying face" is reported in-band with success=False.
    """
    try:
        decoded = decode_image_from_upload(image)
        if decoded is None:
            raise HTTPException(status_code=400, detail="无法解析图像文件")

        outcome = get_extractor().extract_features(
            decoded, return_all_faces=False, quality_filter=True
        )

        if outcome.success and outcome.faces:
            first_face = outcome.faces[0]
            vector = first_face.feature.tolist()
            return ExtractFeatureResponse(
                success=True,
                message="特征提取成功",
                feature=vector,
                feature_dim=len(vector),
                quality_passed=first_face.quality_scores['overall']['passed'],
                processing_time=outcome.processing_time,
            )

        return ExtractFeatureResponse(
            success=False,
            message=outcome.error_message or "未检测到合格的人脸",
            processing_time=outcome.processing_time,
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"特征提取失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}")
|
|
|
|
|
|
# ==================== 用户管理接口 ====================
|
|
|
|
@app.post("/api/users/register", response_model=UserRegisterResponse)
async def register_user(
    name: str = Form(..., description="用户姓名"),
    age: Optional[int] = Form(None, description="年龄"),
    image: UploadFile = File(..., description="人脸照片"),
    db: Session = Depends(get_db)
):
    """
    Register a user from a name, an optional age and a face photo.

    Steps:
        1. Validate the name and reject it if an active user already uses it.
        2. Extract the face feature from the photo.
        3. Save the photo under PHOTOS_DIR.
        4. Reactivate a soft-deleted user with the same name, or create a new row.

    Responses:
        - 400 if the name is blank, already taken by an active user, or the
          image cannot be decoded.
        - 200 with success=False if no qualifying face is found.
        - 500 on any other error. The DB transaction is rolled back, and the
          photo saved by this request (if not yet committed) is deleted so no
          orphan file is left behind.
    """
    # Photo written to disk by this request that no committed row references
    # yet; removed again on failure.
    uncommitted_photo = None
    try:
        # Validate input
        if not name or not name.strip():
            raise HTTPException(status_code=400, detail="用户姓名不能为空")

        name = name.strip()

        # Only an *active* user with this name blocks registration.
        existing_user = db.query(User).filter(
            User.name == name,
            User.is_active == True  # noqa: E712 (SQLAlchemy column comparison)
        ).first()

        if existing_user:
            raise HTTPException(status_code=400, detail=f"用户 '{name}' 已存在")

        # A soft-deleted user with the same name is reactivated instead of
        # inserting a duplicate row.
        deleted_user = db.query(User).filter(
            User.name == name,
            User.is_active == False  # noqa: E712
        ).first()

        # Extract the face feature
        img = decode_image_from_upload(image)
        if img is None:
            raise HTTPException(status_code=400, detail="无法解析图像文件")

        ext = get_extractor()
        result = ext.extract_features(img, return_all_faces=False, quality_filter=True)

        if not result.success or not result.faces:
            return UserRegisterResponse(
                success=False,
                message=result.error_message or "未检测到合格的人脸"
            )

        feature_vector = result.faces[0].feature

        # decode_image_from_upload read the stream to its end; rewind so the
        # whole image is copied to disk.
        image.file.seek(0)
        photo_filename = save_uploaded_photo(image, name)
        uncommitted_photo = photo_filename

        if deleted_user:
            deleted_user.age = age
            deleted_user.photo_filename = photo_filename
            deleted_user.set_feature_array(feature_vector)
            deleted_user.is_active = True
            deleted_user.updated_at = datetime.now()

            db.commit()
            uncommitted_photo = None  # now referenced by a committed row; keep it
            db.refresh(deleted_user)

            logger.info(f"✅ 用户重新激活: {name} (ID: {deleted_user.id})")

            return UserRegisterResponse(
                success=True,
                message=f"用户 '{name}' 重新注册成功",
                user_id=deleted_user.id,
                user=deleted_user.to_dict_with_feature()
            )

        # Create a new user record
        new_user = User(
            name=name,
            age=age,
            photo_filename=photo_filename
        )
        new_user.set_feature_array(feature_vector)

        db.add(new_user)
        db.commit()
        uncommitted_photo = None  # now referenced by a committed row; keep it
        db.refresh(new_user)

        logger.info(f"✅ 用户注册成功: {name} (ID: {new_user.id})")

        return UserRegisterResponse(
            success=True,
            message=f"用户 '{name}' 注册成功",
            user_id=new_user.id,
            user=new_user.to_dict_with_feature()
        )

    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        if uncommitted_photo:
            # Best-effort cleanup: a failed delete must not mask the original error.
            try:
                (PHOTOS_DIR / uncommitted_photo).unlink()
            except OSError as cleanup_error:
                logger.warning(f"删除未使用的照片失败: {cleanup_error}")
        logger.error(f"用户注册失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"注册失败: {str(e)}")
|
|
|
|
|
|
@app.get("/api/users/list", response_model=UserListResponse)
async def list_users(
    skip: int = Query(0, ge=0, description="跳过记录数"),
    limit: int = Query(100, ge=0, description="返回记录数"),
    active_only: bool = Query(True, description="仅返回激活用户"),
    db: Session = Depends(get_db)
):
    """
    Return a page of users, optionally restricted to active ones.

    Args:
        skip: Number of rows to skip (must be >= 0).
        limit: Maximum number of rows to return (must be >= 0). A negative
            value is rejected because some databases, e.g. SQLite, treat a
            negative LIMIT as "no limit".
        active_only: When True, only users with is_active set are listed.

    Returns:
        UserListResponse where ``total`` is the count of all matching users
        (before paging) and ``users`` holds the requested page, ordered by id.
    """
    try:
        query = db.query(User)
        if active_only:
            query = query.filter(User.is_active == True)  # noqa: E712

        total = query.count()
        # OFFSET/LIMIT without ORDER BY has no guaranteed row order, so
        # consecutive pages could overlap or skip rows; page by primary key.
        users = query.order_by(User.id).offset(skip).limit(limit).all()

        user_list = [user.to_dict() for user in users]

        return UserListResponse(
            success=True,
            message="查询成功",
            total=total,
            users=user_list
        )

    except Exception as e:
        logger.error(f"查询用户列表失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}")
|
|
|
|
|
|
# ==================== 同步接口(供Android下载用户数据) ====================
|
|
|
|
@app.get("/api/users/sync", response_model=SyncUsersResponse)
async def sync_users(db: Session = Depends(get_db)):
    """
    Download every active user, feature vectors included, for the Android client.

    Any failure is logged and reported as HTTP 500.
    """
    try:
        active_rows = db.query(User).filter(User.is_active == True).all()  # noqa: E712
        payload = [row.to_dict_with_feature() for row in active_rows]
        count = len(payload)

        logger.info(f"📥 用户数据同步请求: {count} 个用户")

        return SyncUsersResponse(
            success=True,
            message="同步成功",
            total=count,
            users=payload,
        )
    except Exception as e:
        logger.error(f"同步用户数据失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"同步失败: {str(e)}")
|
|
|
|
|
|
@app.get("/api/users/{user_id}")
async def get_user(user_id: int, db: Session = Depends(get_db)):
    """Return one user's details by id; soft-deleted users are included.

    Responds 404 when no row has this id and 500 on unexpected errors.
    """
    try:
        found = db.query(User).filter(User.id == user_id).first()
        if not found:
            raise HTTPException(status_code=404, detail="用户不存在")
        return {"success": True, "user": found.to_dict()}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"查询用户失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}")
|
|
|
|
|
|
@app.delete("/api/users/{user_id}")
async def delete_user(user_id: int, db: Session = Depends(get_db)):
    """
    Soft-delete a user by id.

    The row is kept and only marked inactive (``is_active = False``), so
    re-registering the same name can later reactivate it. ``updated_at`` is
    refreshed here, matching what registration does when it reactivates a
    user, so the timestamp reflects the latest state change.

    Responds 404 if no user has this id; on other errors the transaction is
    rolled back and 500 is returned.
    """
    try:
        user = db.query(User).filter(User.id == user_id).first()
        if not user:
            raise HTTPException(status_code=404, detail="用户不存在")

        # Soft delete
        user.is_active = False
        user.updated_at = datetime.now()
        db.commit()

        logger.info(f"✅ 用户已删除: {user.name} (ID: {user_id})")

        return {
            "success": True,
            "message": f"用户 '{user.name}' 已删除"
        }

    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"删除用户失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}")
|
|
|
|
|
|
# ==================== 识别接口 ====================
|
|
|
|
def _find_best_match(feature, users):
    """Find the user whose stored feature is most similar to ``feature``.

    Args:
        feature: Feature vector extracted from the query image.
        users: Candidate ``User`` rows.

    Returns:
        ``(best_user, best_similarity)``; on ties the earliest user wins.
        Users whose stored feature cannot be compared (the similarity call
        raises ValueError/TypeError, e.g. on a dimension mismatch) are skipped
        with a warning instead of failing the whole request. If no user could
        be compared, returns ``(None, 0.0)``.
    """
    best_user = None
    # Seed with -inf (not 0.0) so the true best score is reported even when negative.
    best_similarity = float("-inf")
    for user in users:
        try:
            similarity = calculate_similarity(feature, user.get_feature_array())
        except (ValueError, TypeError) as e:
            logger.warning(f"跳过用户特征比对 (ID: {user.id}): {e}")
            continue
        if similarity > best_similarity:
            best_user = user
            best_similarity = similarity

    if best_user is None:
        return None, 0.0
    return best_user, best_similarity


@app.post("/api/recognize", response_model=RecognizeResponse)
async def recognize_face(
    image: UploadFile = File(..., description="待识别的人脸照片"),
    device_id: Optional[str] = Form(None, description="设备ID"),
    threshold: float = Form(0.7, description="识别阈值"),
    db: Session = Depends(get_db)
):
    """
    Identify the person in an uploaded photo against all active users.

    The first qualifying face in the image is compared with every active
    user's stored feature. The best-scoring user counts as recognized when
    that score is >= ``threshold``. Every attempt that reaches the matching
    stage is recorded in RecognitionLog.

    Responses:
        - 400 if the image cannot be decoded.
        - success=False if no face is found.
        - success=True, recognized=False if there are no active users or the
          best score is below the threshold.
        - 500 on unexpected errors (the DB transaction is rolled back).
    """
    try:
        # Extract the face feature
        img = decode_image_from_upload(image)
        if img is None:
            raise HTTPException(status_code=400, detail="无法解析图像文件")

        ext = get_extractor()
        result = ext.extract_features(img, return_all_faces=False, quality_filter=True)

        if not result.success or not result.faces:
            return RecognizeResponse(
                success=False,
                message=result.error_message or "未检测到人脸",
                recognized=False,
                processing_time=result.processing_time
            )

        extracted_feature = result.faces[0].feature

        # Query all active users
        users = db.query(User).filter(User.is_active == True).all()  # noqa: E712

        if not users:
            return RecognizeResponse(
                success=True,
                message="数据库中没有注册用户",
                recognized=False,
                processing_time=result.processing_time
            )

        matched_user, max_similarity = _find_best_match(extracted_feature, users)

        # Require an actual candidate: with threshold <= 0 the score check alone
        # could pass while matched_user is None, and matched_user.id would fail.
        recognized = matched_user is not None and max_similarity >= threshold

        # Record the recognition attempt
        log = RecognitionLog(
            user_id=matched_user.id if recognized else None,
            user_name=matched_user.name if recognized else None,
            device_id=device_id,
            similarity=float(max_similarity),
            threshold=threshold,
            is_recognized=recognized
        )
        db.add(log)
        db.commit()

        if recognized:
            logger.info(f"✅ 识别成功: {matched_user.name} (相似度: {max_similarity:.4f})")
            return RecognizeResponse(
                success=True,
                message=f"识别成功: {matched_user.name}",
                recognized=True,
                user_id=matched_user.id,
                user_name=matched_user.name,
                similarity=float(max_similarity),
                threshold=threshold,
                processing_time=result.processing_time
            )

        logger.info(f"⚠️ 未识别: 最高相似度 {max_similarity:.4f} < 阈值 {threshold}")
        return RecognizeResponse(
            success=True,
            message=f"未识别到注册用户 (最高相似度: {max_similarity:.4f})",
            recognized=False,
            similarity=float(max_similarity),
            threshold=threshold,
            processing_time=result.processing_time
        )

    except HTTPException:
        raise
    except Exception as e:
        # The log insert may have failed mid-transaction; leave the session clean.
        db.rollback()
        logger.error(f"识别失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"识别失败: {str(e)}")
|
|
|
|
|
|
# ==================== 统计接口 ====================
|
|
|
|
@app.get("/api/stats")
async def get_statistics(db: Session = Depends(get_db)):
    """Summarize active users and recognition-log outcomes.

    ``recognition_rate`` is the percentage of logged attempts that were
    recognized, rounded to two decimals, or 0 when nothing has been logged.
    """
    try:
        active_users = db.query(User).filter(User.is_active == True).count()  # noqa: E712
        log_total = db.query(RecognitionLog).count()
        log_hits = db.query(RecognitionLog).filter(
            RecognitionLog.is_recognized == True  # noqa: E712
        ).count()

        if log_total > 0:
            rate = round(log_hits / log_total * 100, 2)
        else:
            rate = 0

        summary = {
            "total_users": active_users,
            "total_recognitions": log_total,
            "successful_recognitions": log_hits,
            "recognition_rate": rate,
        }
        return {"success": True, "stats": summary}
    except Exception as e:
        logger.error(f"获取统计信息失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}")
|
|
|
|
|
|
# ==================== 管理后台页面 ====================
|
|
|
|
@app.get("/admin", response_class=HTMLResponse)
async def admin_page():
    """Serve admin.html from this module's directory, or a placeholder page when it is absent."""
    html_path = Path(__file__).parent / "admin.html"
    if not html_path.exists():
        return HTMLResponse(content="<h1>管理后台页面建设中...</h1><p>请访问 <a href='/docs'>/docs</a> 使用API文档</p>")
    return FileResponse(html_path)
|
|
|
|
|
|
# ==================== 启动配置 ====================
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line options for running the service directly with uvicorn.
    parser = argparse.ArgumentParser(description='人脸识别管理系统API服务')
    parser.add_argument('--host', type=str, default='0.0.0.0', help='服务器地址')
    parser.add_argument('--port', type=int, default=8000, help='服务器端口')
    parser.add_argument('--reload', action='store_true', help='开启热重载')

    args = parser.parse_args()

    logger.info(f"🚀 启动服务: http://{args.host}:{args.port}")
    logger.info(f"📚 API文档: http://{args.host}:{args.port}/docs")
    logger.info(f"🎛️ 管理后台: http://{args.host}:{args.port}/admin")

    # uvicorn needs a "module:attribute" import string (mandatory for reload).
    # Derive the module name from this file instead of hard-coding "app:app",
    # which only works when the file is named app.py.
    app_import_string = f"{Path(__file__).stem}:app"

    uvicorn.run(
        app_import_string,
        host=args.host,
        port=args.port,
        reload=args.reload
    )
|