158 lines
5.3 KiB
Python
158 lines
5.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
Database model definitions.

Uses SQLite to store user data.
"""
|
|
|
|
from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime, Boolean, LargeBinary
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.orm import sessionmaker
|
|
from datetime import datetime
|
|
import json
|
|
import numpy as np
|
|
|
|
# Declarative base class that every ORM model in this module (User,
# RecognitionLog) inherits from; its metadata is what init_database() creates.
Base = declarative_base()
|
|
|
|
|
|
class User(Base):
    """User table: one row per registered face.

    The face embedding is kept in ``feature_vector`` as the raw bytes of a
    flat ``float32`` array, and ``feature_dim`` records how many elements
    that array has. Use ``set_feature_array`` / ``get_feature_array`` rather
    than touching ``feature_vector`` directly so the two stay consistent.
    """

    __tablename__ = 'users'

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(100), nullable=False, unique=True, index=True)  # name (unique)
    age = Column(Integer, nullable=True)  # age (optional)
    feature_vector = Column(LargeBinary, nullable=False)  # embedding stored as raw float32 bytes
    feature_dim = Column(Integer, default=1024)  # number of elements in the embedding
    photo_filename = Column(String(255), nullable=True)  # photo file name
    created_at = Column(DateTime, default=datetime.now)  # creation time
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)  # last update time
    is_active = Column(Boolean, default=True)  # whether the user is active

    def __repr__(self):
        return f"<User(id={self.id}, name='{self.name}', age={self.age})>"

    def get_feature_array(self):
        """Decode the stored bytes into a 1-D ``float32`` numpy array.

        Returns:
            The embedding as produced by ``np.frombuffer``. The array is a
            view over the stored buffer, not a copy; when that buffer is an
            immutable ``bytes`` object the array is read-only, so call
            ``.copy()`` before modifying it.
        """
        return np.frombuffer(self.feature_vector, dtype=np.float32)

    def set_feature_array(self, feature_array):
        """Store an embedding as raw ``float32`` bytes and record its size.

        Args:
            feature_array: A numpy array or any sequence of numbers (list,
                tuple, ...). It is converted to ``float32`` and flattened
                before being stored.
        """
        # Always coerce to float32: get_feature_array() decodes the bytes as
        # float32, so storing e.g. a float64 array verbatim would read back
        # as twice as many meaningless values.
        flat = np.asarray(feature_array, dtype=np.float32).reshape(-1)
        self.feature_vector = flat.tobytes()
        # Element count, not len(): len() of a (1, N) array would report 1.
        self.feature_dim = int(flat.size)

    def to_dict(self):
        """Return the user's scalar fields as a plain dict.

        The embedding is left out; timestamps are rendered as ISO-8601
        strings, or ``None`` when unset.
        """
        return {
            'id': self.id,
            'name': self.name,
            'age': self.age,
            'feature_dim': self.feature_dim,
            'photo_filename': self.photo_filename,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'updated_at': self.updated_at.isoformat() if self.updated_at else None,
            'is_active': self.is_active,
        }

    def to_dict_with_feature(self):
        """Return ``to_dict()`` plus the embedding as a list under ``'feature_vector'``."""
        data = self.to_dict()
        data['feature_vector'] = self.get_feature_array().tolist()
        return data
|
|
|
|
|
|
class RecognitionLog(Base):
    """Recognition log table: one row per recognition attempt."""

    __tablename__ = 'recognition_logs'

    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(Integer, nullable=True)  # id of the recognized user (empty when not recognized)
    user_name = Column(String(100), nullable=True)  # name of the recognized user
    device_id = Column(String(100), nullable=True)  # device id
    similarity = Column(Float, nullable=True)  # similarity score
    threshold = Column(Float, default=0.7)  # threshold that was applied
    is_recognized = Column(Boolean, default=False)  # whether recognition succeeded
    recognized_at = Column(DateTime, default=datetime.now)  # time of the recognition
    photo_filename = Column(String(255), nullable=True)  # photo captured at recognition time

    def __repr__(self):
        summary = f"id={self.id}, user_name='{self.user_name}', similarity={self.similarity}"
        return f"<RecognitionLog({summary})>"

    def to_dict(self):
        """Return the log entry as a plain dict.

        ``recognized_at`` is rendered as an ISO-8601 string, or ``None``
        when it is unset.
        """
        timestamp = None
        if self.recognized_at:
            timestamp = self.recognized_at.isoformat()

        return dict(
            id=self.id,
            user_id=self.user_id,
            user_name=self.user_name,
            device_id=self.device_id,
            similarity=self.similarity,
            threshold=self.threshold,
            is_recognized=self.is_recognized,
            recognized_at=timestamp,
            photo_filename=self.photo_filename,
        )
|
|
|
|
|
|
# Database configuration: a SQLite file addressed by a relative path.
DATABASE_URL = "sqlite:///./face_recognition.db"

# Engine shared by the whole module.
engine = create_engine(
    DATABASE_URL,
    echo=False,  # set to True to see the emitted SQL statements
    connect_args={"check_same_thread": False},  # SQLite-specific option
)

# Session factory bound to the engine; sessions neither autocommit nor autoflush.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
|
|
|
|
|
def init_database():
    """Initialize the database by creating all tables declared on ``Base``.

    Runs ``create_all`` against the module-level engine and prints a
    confirmation message afterwards.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)
    print("✅ 数据库初始化完成")
|
|
|
|
|
|
def get_db():
    """Yield a fresh database session and close it afterwards.

    This is a generator: it opens a session from ``SessionLocal``, yields it
    once, and closes it in a ``finally`` clause, so the close runs whether
    the consumer finishes normally or an exception is thrown into the
    generator.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: create the database tables.
    init_database()

    # Smoke test: store one user with a random, unit-length embedding.
    session = SessionLocal()
    embedding = np.random.randn(1024).astype(np.float32)
    embedding = embedding / np.linalg.norm(embedding)  # normalize to unit length

    sample_user = User(name="测试用户", age=25)
    sample_user.set_feature_array(embedding)

    try:
        session.add(sample_user)
        session.commit()
        print(f"✅ 创建测试用户成功: {sample_user}")

        # Read the row back by name and report what was stored.
        fetched = session.query(User).filter(User.name == "测试用户").first()
        if fetched:
            print(f"✅ 读取用户成功: {fetched}")
            print(f" 特征维度: {fetched.feature_dim}")
            restored = fetched.get_feature_array()
            print(f" 特征数组形状: {restored.shape}")
    except Exception as exc:
        # Report the failure, then undo whatever the session had pending.
        print(f"❌ 错误: {exc}")
        session.rollback()
    finally:
        session.close()
|