90 lines
2.6 KiB
Python
90 lines
2.6 KiB
Python
|
|
import json
|
|||
|
|
import uuid
|
|||
|
|
from fastapi import Depends, HTTPException
|
|||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
|
from database.creator import get_db_session
|
|||
|
|
from database.user import User, AuthME
|
|||
|
|
from sqlalchemy import select, desc
|
|||
|
|
from pydantic import BaseModel
|
|||
|
|
from loguru import logger
|
|||
|
|
from typing import Optional
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AuthmeRequest(BaseModel):
    """Request body that carries the authentication token to resolve."""

    # Token value looked up against AuthME.authme_token by fetch_user_by_token.
    token: str
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AuthmeResponse(BaseModel):
    """Response body made of a numeric code and an accompanying message."""

    # Numeric result code; its meaning is defined by the endpoints that
    # return this model (not visible here).
    code: int
    # Human-readable text accompanying the code.
    message: str
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def fetch_user_by_token(
    AuthmeRequest: AuthmeRequest,
    asyncsession: AsyncSession = Depends(get_db_session)
) -> User:
    """Resolve the user that owns the token carried in the request.

    First finds the ``AuthME`` row whose ``authme_token`` equals the request
    token, then loads the ``User`` whose ``userid`` matches that row.

    :param AuthmeRequest: request object carrying the ``token`` to resolve.
        Note that this parameter name shadows the ``AuthmeRequest`` model
        class inside the function body; it is kept as-is for caller
        compatibility.
    :param asyncsession: database session provided by ``get_db_session``.
    :return: the ``User`` row associated with the token.
    :raises HTTPException: status 401 when no ``AuthME`` row matches the
        token, or when the referenced user row does not exist.
    """
    # NOTE(review): if get_db_session already manages the session lifecycle
    # (e.g. a yield-dependency that closes it), this `async with` closes the
    # session before that cleanup runs -- confirm this is intended.
    async with asyncsession as session:
        # Step 1: locate the token record.
        token_query = select(AuthME).where(
            AuthME.authme_token == AuthmeRequest.token
        )
        token_rows = await session.execute(token_query)
        token_record = token_rows.scalars().first()
        if not token_record:
            raise HTTPException(status_code=401, detail="无效的令牌或用户不存在")

        # Step 2: load the user referenced by that token record.
        user_query = select(User).where(User.userid == token_record.userid)
        user_rows = await session.execute(user_query)
        owner = user_rows.scalars().first()
        if not owner:
            raise HTTPException(status_code=401, detail="用户不存在")

        logger.info(f"User {owner.userid} fetched successfully using token.")
        return owner
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def manage_user_tokens(
    userid: str,
    new_token: str,
    device_id: str,
    session: AsyncSession,
    *,
    max_sessions: int = 5,
    evict_count: int = 2,
) -> None:
    """Store a new token for a user, evicting the oldest ones once the cap is hit.

    Each user may hold at most ``max_sessions`` token records. When the user
    already has ``max_sessions`` or more, the oldest records (ordered by
    ``AuthME.create_date``) are deleted so that only the newest
    ``max_sessions - evict_count`` remain. The new token is then added and
    the session is committed. With the defaults (5 / 2) and exactly 5
    existing tokens, the 2 oldest are deleted, leaving 4 after the insert.

    :param userid: ID of the user owning the token.
    :param new_token: token value to store.
    :param device_id: device identifier stored alongside the token.
    :param session: database session; this function commits it.
    :param max_sessions: maximum number of tokens kept per user (default 5).
    :param evict_count: number of slots freed once the cap is reached
        (default 2).
    :raises ValueError: if ``max_sessions < 1`` or ``evict_count`` is not in
        ``[1, max_sessions]``.
    """
    # Validate before touching the database so a bad configuration never
    # leaves a half-applied change in the session.
    if max_sessions < 1:
        raise ValueError("max_sessions must be at least 1")
    if not 1 <= evict_count <= max_sessions:
        raise ValueError("evict_count must be between 1 and max_sessions")

    # Fetch the user's existing tokens, newest first.
    result = await session.execute(
        select(AuthME)
        .where(AuthME.userid == userid)
        .order_by(desc(AuthME.create_date))
    )
    existing_tokens = result.scalars().all()

    if len(existing_tokens) >= max_sessions:
        # Keep only the newest (max_sessions - evict_count) tokens. For exactly
        # max_sessions existing tokens this equals dropping the oldest
        # evict_count. If the user already holds MORE than max_sessions (e.g.
        # concurrent logins each passed this check), the whole surplus is
        # removed so the cap is restored. Previously only 2 rows were ever
        # trimmed, so the surplus was never cleared.
        keep = max_sessions - evict_count
        for token_record in existing_tokens[keep:]:
            await session.delete(token_record)

    # Record the new token for this device.
    new_authme = AuthME(
        userid=userid,
        authme_token=new_token,
        device_id=device_id
    )
    session.add(new_authme)
    await session.commit()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def generate_device_id() -> str:
    """Create a fresh device identifier.

    :return: the canonical hyphenated string form of a random
        (version 4) UUID, e.g. ``"1b4e28ba-2fa1-41d2-883f-0016d3cca427"``.
    """
    random_uuid = uuid.uuid4()
    return str(random_uuid)
|