129 lines
3.6 KiB
Python
129 lines
3.6 KiB
Python
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
"""
|
|||
|
|
人脸特征提取微服务
|
|||
|
|
仅提供特征提取 API,供 Java 后端调用
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import uvicorn
|
|||
|
|
from fastapi import FastAPI, File, UploadFile, HTTPException
|
|||
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|||
|
|
from pydantic import BaseModel
|
|||
|
|
from typing import List, Optional
|
|||
|
|
import numpy as np
|
|||
|
|
import cv2
|
|||
|
|
import logging
|
|||
|
|
from face_feature_extractor import FaceFeatureExtractor
|
|||
|
|
|
|||
|
|
# Logging: INFO level, one line per record with timestamp, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Shared logger for this service; every handler below logs through it.
logger = logging.getLogger("FeatureServer")
|
|||
|
|
|
|||
|
|
# The FastAPI application; all routes and middleware below attach to it.
# title/description/version are what the generated OpenAPI docs display.
app = FastAPI(
    version="1.0.0",
    title="Face Feature Extraction Microservice",
    description="Dedicated service for extracting face features",
)
|
|||
|
|
|
|||
|
|
# Cross-origin (CORS) settings: any origin, method and header is accepted.
#
# Credentials are deliberately NOT allowed. Starlette handles
# allow_origins=["*"] together with allow_credentials=True by echoing the
# caller's Origin back with "Access-Control-Allow-Credentials: true". That
# would let any website issue credentialed requests and read the responses,
# and the FastAPI/Starlette docs list it as an invalid combination.
# No cookie/session handling is visible in this service, which is documented
# as being called by the Java backend, so credentials are not needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|||
|
|
|
|||
|
|
# Process-wide FaceFeatureExtractor instance.
# It is None until get_extractor() builds it on first use.
extractor = None


def get_extractor():
    """Return the shared FaceFeatureExtractor, constructing it on first call.

    The instance is cached in the module-level ``extractor`` global, so later
    calls return the same object. No locking is performed around construction.
    """
    global extractor
    if extractor is not None:
        return extractor

    logger.info("Initializing FaceFeatureExtractor...")
    extractor = FaceFeatureExtractor()
    logger.info("FaceFeatureExtractor initialized.")
    return extractor
|
|||
|
|
|
|||
|
|
# Response model
class ExtractFeatureResponse(BaseModel):
    """JSON body returned by POST /api/extract_feature.

    ``feature`` and ``feature_dim`` are populated only when extraction succeeds.
    """

    # True when a feature vector was produced.
    success: bool
    # Human-readable outcome: "Success", the extractor's error, or a server error.
    message: str
    # The extracted feature vector as plain floats; None on failure.
    feature: Optional[List[float]] = None
    # Length of ``feature``; None on failure.
    feature_dim: Optional[int] = None
    # Processing time as reported by the extractor; units are not visible here.
    processing_time: Optional[float] = None
|
|||
|
|
|
|||
|
|
def decode_image(file: UploadFile) -> Optional[np.ndarray]:
    """Decode an uploaded image file into an OpenCV colour image.

    Reads the whole upload into memory and decodes it with ``cv2.IMREAD_COLOR``.

    Returns:
        The decoded array, or None when the data cannot be decoded.
        ``cv2.imdecode`` itself returns None for unrecognised data. Any
        exception raised while reading or decoding is logged and also yields
        None.
    """
    try:
        raw_bytes = np.frombuffer(file.file.read(), dtype=np.uint8)
        return cv2.imdecode(raw_bytes, cv2.IMREAD_COLOR)
    except Exception as err:
        logger.error(f"Image decode failed: {err}")
        return None
|
|||
|
|
|
|||
|
|
@app.on_event("startup")
|
|||
|
|
async def startup_event():
|
|||
|
|
"""启动时预加载模型"""
|
|||
|
|
get_extractor()
|
|||
|
|
|
|||
|
|
@app.get("/health")
|
|||
|
|
async def health_check():
|
|||
|
|
return {"status": "healthy", "service": "Face Feature Extractor"}
|
|||
|
|
|
|||
|
|
@app.post("/api/extract_feature", response_model=ExtractFeatureResponse)
|
|||
|
|
async def extract_feature(image: UploadFile = File(...)):
|
|||
|
|
"""
|
|||
|
|
特征提取接口
|
|||
|
|
输入: 图片文件
|
|||
|
|
输出: 1024维特征向量
|
|||
|
|
"""
|
|||
|
|
try:
|
|||
|
|
img = decode_image(image)
|
|||
|
|
if img is None:
|
|||
|
|
raise HTTPException(status_code=400, detail="Invalid image file")
|
|||
|
|
|
|||
|
|
ext = get_extractor()
|
|||
|
|
# 调用提取器
|
|||
|
|
result = ext.extract_features(img, return_all_faces=False, quality_filter=True)
|
|||
|
|
|
|||
|
|
if not result.success or not result.faces:
|
|||
|
|
return ExtractFeatureResponse(
|
|||
|
|
success=False,
|
|||
|
|
message=result.error_message or "No face detected",
|
|||
|
|
processing_time=result.processing_time
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 获取特征
|
|||
|
|
feature_vector = result.faces[0].feature.tolist()
|
|||
|
|
|
|||
|
|
return ExtractFeatureResponse(
|
|||
|
|
success=True,
|
|||
|
|
message="Success",
|
|||
|
|
feature=feature_vector,
|
|||
|
|
feature_dim=len(feature_vector),
|
|||
|
|
processing_time=result.processing_time
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"Extraction failed: {e}", exc_info=True)
|
|||
|
|
return ExtractFeatureResponse(
|
|||
|
|
success=False,
|
|||
|
|
message=f"Server error: {str(e)}"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
import argparse
|
|||
|
|
parser = argparse.ArgumentParser(description='Face Feature Extraction Microservice')
|
|||
|
|
parser.add_argument('--port', type=int, default=8000, help='Service port')
|
|||
|
|
args = parser.parse_args()
|
|||
|
|
|
|||
|
|
logger.info(f"Starting Feature Server on port {args.port}...")
|
|||
|
|
uvicorn.run(app, host="0.0.0.0", port=args.port)
|