Face_reg_app/FaceFeatureExtractorAPI/database.py

158 lines
5.3 KiB
Python

# -*- coding: utf-8 -*-
"""
数据库模型定义
使用SQLite进行用户数据存储
"""
import json
import os
from datetime import datetime

import numpy as np
from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime, Boolean, LargeBinary
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
# Declarative base class shared by the ORM models in this module; their tables
# are collected on Base.metadata (used by init_database to create them).
# NOTE(review): sqlalchemy.ext.declarative.declarative_base is the legacy import
# location (sqlalchemy.orm.declarative_base since 1.4) — confirm against the
# pinned SQLAlchemy version before upgrading.
Base = declarative_base()
class User(Base):
    """User table: one row per registered face (name, optional age, feature vector)."""

    __tablename__ = 'users'

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(100), nullable=False, unique=True, index=True)  # person's name (unique)
    age = Column(Integer, nullable=True)  # age (optional)
    feature_vector = Column(LargeBinary, nullable=False)  # feature vector as raw float32 bytes (nominally 1024-dim)
    feature_dim = Column(Integer, default=1024)  # number of float32 elements in feature_vector
    photo_filename = Column(String(255), nullable=True)  # photo file name
    created_at = Column(DateTime, default=datetime.now)  # creation time (naive local time)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)  # last update time
    is_active = Column(Boolean, default=True)  # whether the user is active

    def __repr__(self):
        return f"<User(id={self.id}, name='{self.name}', age={self.age})>"

    def get_feature_array(self):
        """Decode ``feature_vector`` back into a 1-D float32 numpy array.

        The result is a read-only view over the stored bytes (``np.frombuffer``);
        call ``.copy()`` before modifying it in place.
        """
        return np.frombuffer(self.feature_vector, dtype=np.float32)

    def set_feature_array(self, feature_array):
        """Store *feature_array* as raw float32 bytes and record its length.

        Args:
            feature_array: list, tuple or numpy array of numbers (any shape).

        The input is always converted to float32 and flattened. ``get_feature_array``
        decodes the bytes as float32, so storing e.g. float64 bytes (numpy's
        default dtype) would read back as a garbage vector of twice the length.
        Flattening keeps ``feature_dim`` equal to the number of stored values
        even for inputs shaped like ``(1, N)``.
        """
        vector = np.asarray(feature_array, dtype=np.float32).ravel()
        self.feature_vector = vector.tobytes()
        self.feature_dim = vector.size

    def to_dict(self):
        """Return the user as a JSON-friendly dict (without the feature vector).

        Timestamps are rendered as ISO-8601 strings, or None when unset.
        """
        return {
            'id': self.id,
            'name': self.name,
            'age': self.age,
            'feature_dim': self.feature_dim,
            'photo_filename': self.photo_filename,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'updated_at': self.updated_at.isoformat() if self.updated_at else None,
            'is_active': self.is_active
        }

    def to_dict_with_feature(self):
        """Return ``to_dict()`` plus the decoded feature vector as a list of floats."""
        data = self.to_dict()
        data['feature_vector'] = self.get_feature_array().tolist()
        return data
class RecognitionLog(Base):
    """Recognition log table: one row per recognition attempt."""

    __tablename__ = 'recognition_logs'

    # Keep the column definitions in this order: declarative mapping uses the
    # definition order for the table's columns.
    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(Integer, nullable=True)  # matched user's id (NULL when nobody was recognised)
    user_name = Column(String(100), nullable=True)  # matched user's name
    device_id = Column(String(100), nullable=True)  # device id
    similarity = Column(Float, nullable=True)  # similarity score
    threshold = Column(Float, default=0.7)  # threshold used for this attempt
    is_recognized = Column(Boolean, default=False)  # whether recognition succeeded
    recognized_at = Column(DateTime, default=datetime.now)  # time of the attempt
    photo_filename = Column(String(255), nullable=True)  # photo captured during recognition

    def __repr__(self):
        return "<RecognitionLog(id={}, user_name='{}', similarity={})>".format(
            self.id, self.user_name, self.similarity
        )

    def to_dict(self):
        """Return the log entry as a JSON-friendly dict.

        ``recognized_at`` becomes an ISO-8601 string, or None when unset.
        """
        stamp = self.recognized_at
        return dict(
            id=self.id,
            user_id=self.user_id,
            user_name=self.user_name,
            device_id=self.device_id,
            similarity=self.similarity,
            threshold=self.threshold,
            is_recognized=self.is_recognized,
            recognized_at=stamp.isoformat() if stamp else None,
            photo_filename=self.photo_filename,
        )
# Database configuration.
# The default is a SQLite file resolved against the process's current working
# directory (the path is relative). Set FACE_RECOGNITION_DATABASE_URL to use a
# different location or database without editing this file.
DATABASE_URL = os.environ.get("FACE_RECOGNITION_DATABASE_URL") or "sqlite:///./face_recognition.db"

# SQLite connections refuse use from a thread other than their creator by
# default. Disable that check so sessions can be used from other threads. The
# option is specific to the sqlite3 driver, so only pass it for SQLite URLs.
_connect_args = {"check_same_thread": False} if DATABASE_URL.startswith("sqlite") else {}

# Create the engine.
engine = create_engine(
    DATABASE_URL,
    connect_args=_connect_args,
    echo=False,  # set to True to log every SQL statement
)

# Session factory: explicit commits, no automatic flush before queries.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def init_database():
    """Create every table registered on ``Base.metadata`` in the configured database.

    ``create_all`` only emits CREATE TABLE for tables that do not exist yet, so
    calling this repeatedly is harmless. It does not migrate existing tables.
    """
    schema = Base.metadata
    schema.create_all(bind=engine)
    print("✅ 数据库初始化完成")
def get_db():
    """Yield a new database session and close it once the consumer is finished.

    Written as a generator (presumably for use as a request-scoped dependency —
    verify against callers). The ``finally`` clause closes the session even if
    the consumer raises or abandons the generator.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
if __name__ == "__main__":
    # Smoke test: create the schema, store a random normalised feature vector
    # for a demo user, then read it back and print its dimensions.
    init_database()

    test_feature = np.random.randn(1024).astype(np.float32)
    test_feature = test_feature / np.linalg.norm(test_feature)  # L2-normalise

    db = SessionLocal()
    try:
        # `name` is UNIQUE, so on a re-run reuse the existing demo row instead
        # of failing with an IntegrityError on insert.
        test_user = db.query(User).filter(User.name == "测试用户").first()
        is_new = test_user is None
        if is_new:
            test_user = User(
                name="测试用户",
                age=25,
            )
            db.add(test_user)
        test_user.set_feature_array(test_feature)
        db.commit()
        if is_new:
            print(f"✅ 创建测试用户成功: {test_user}")
        else:
            print(f"✅ 测试用户已存在, 已更新特征: {test_user}")

        # Read the row back to check the float32 byte round trip.
        user = db.query(User).filter(User.name == "测试用户").first()
        if user:
            print(f"✅ 读取用户成功: {user}")
            print(f" 特征维度: {user.feature_dim}")
            feature_array = user.get_feature_array()
            print(f" 特征数组形状: {feature_array.shape}")
    except Exception as e:
        # Top-level boundary of a demo script: report the error and undo the transaction.
        print(f"❌ 错误: {e}")
        db.rollback()
    finally:
        db.close()