""" Redis客户端工具模块 提供类型完整的Redis客户端包装器,支持内容验证和序列化 """ import json from typing import Any, Optional, Type, TypeVar, Union import redis.asyncio as aioredis from pydantic import BaseModel, ValidationError from loveace.config.logger import logger from loveace.database.creator import db_manager T = TypeVar("T", bound=BaseModel) class RedisClient: """类型完整的Redis客户端包装器 提供带有数据验证和序列化的Redis操作接口 Example: >>> client = RedisClient(redis_instance) >>> # 存储对象 >>> await client.set_object("user:1", user_data, User) >>> # 获取对象 >>> user = await client.get_object("user:1", User) """ def __init__(self, redis_client: aioredis.Redis): """初始化Redis客户端包装器 Args: redis_client: aioredis.Redis 实例 """ self.client = redis_client async def set_object( self, key: str, value: Union[BaseModel, dict, Any], model_class: Optional[Type[T]] = None, expire: Optional[int] = None, ) -> bool: """设置对象到Redis,支持自动验证和序列化 Args: key: Redis键 value: 要存储的值(BaseModel、dict或其他可序列化对象) model_class: 对象模型类,用于验证。如果提供,会先验证value expire: 过期时间(秒),None表示不设置过期时间 Returns: 是否成功设置 Raises: ValidationError: 当model_class验证失败时 TypeError: 当value无法序列化时 """ try: # 验证数据 if model_class is not None: if isinstance(value, model_class): validated_value = value else: validated_value = model_class( **value if isinstance(value, dict) else value.dict() ) else: validated_value = value # 序列化 if isinstance(validated_value, BaseModel): data = validated_value.model_dump_json() elif isinstance(validated_value, dict): data = json.dumps(validated_value, ensure_ascii=False) else: data = json.dumps(validated_value, ensure_ascii=False) # 存储到Redis if expire: await self.client.setex(key, expire, data) else: await self.client.set(key, data) logger.debug(f"成功存储Redis键: {key}") return True except ValidationError as e: logger.error(f"Redis对象验证失败 {key}: {e}") raise except Exception as e: logger.error(f"Redis存储失败 {key}: {e}") raise async def get_object( self, key: str, model_class: Type[T], ) -> Optional[T]: """从Redis获取对象,并通过指定的模型类进行验证 Args: key: Redis键 model_class: 对象模型类,用于反序列化和验证 Returns: 反序列化并验证后的对象,如果键不存在则返回None Raises: ValidationError: 当数据验证失败时 """ try: data = await self.client.get(key) if data is None: logger.debug(f"Redis键不存在: {key}") return None # 反序列化 if isinstance(data, bytes): data = data.decode("utf-8") parsed_data = json.loads(data) # 验证并创建模型实例 validated_value = model_class(**parsed_data) logger.debug(f"成功获取并验证Redis键: {key}") return validated_value except ValidationError as e: logger.error(f"Redis对象验证失败 {key}: {e}") raise except json.JSONDecodeError as e: logger.error(f"Redis JSON解析失败 {key}: {e}") raise except Exception as e: logger.error(f"Redis获取失败 {key}: {e}") raise async def get_object_safe( self, key: str, model_class: Type[T], default: Optional[T] = None, ) -> Optional[T]: """安全地从Redis获取对象,验证失败时返回默认值 Args: key: Redis键 model_class: 对象模型类,用于反序列化和验证 default: 验证失败时的默认返回值 Returns: 反序列化并验证后的对象,验证失败返回default """ try: return await self.get_object(key, model_class) except (ValidationError, json.JSONDecodeError, Exception) as e: logger.warning(f"Redis安全获取失败,返回默认值 {key}: {e}") return default async def set_raw( self, key: str, value: Union[str, bytes], expire: Optional[int] = None, ) -> bool: """设置原始字符串值到Redis Args: key: Redis键 value: 要存储的值(字符串或字节) expire: 过期时间(秒) Returns: 是否成功设置 """ try: if expire: await self.client.setex(key, expire, value) else: await self.client.set(key, value) logger.debug(f"成功存储原始值到Redis: {key}") return True except Exception as e: logger.error(f"Redis原始值存储失败 {key}: {e}") raise async def get_raw(self, key: str) -> Optional[Union[str, bytes]]: """获取原始字符串值 Args: key: Redis键 Returns: 存储的值,如果键不存在则返回None """ try: data = await self.client.get(key) if data is None: logger.debug(f"Redis键不存在: {key}") return None logger.debug(f"成功获取原始值: {key}") return data except Exception as e: logger.error(f"Redis获取失败 {key}: {e}") raise async def delete(self, key: str) -> int: """删除Redis键 Args: key: 要删除的键 Returns: 删除的键数量 """ try: result = await self.client.delete(key) logger.debug(f"成功删除Redis键: {key}") return result except Exception as e: logger.error(f"Redis删除失败 {key}: {e}") raise async def exists(self, key: str) -> bool: """检查键是否存在 Args: key: 要检查的键 Returns: 键是否存在 """ try: return await self.client.exists(key) > 0 except Exception as e: logger.error(f"Redis检查失败 {key}: {e}") raise async def expire(self, key: str, seconds: int) -> bool: """设置键的过期时间 Args: key: Redis键 seconds: 过期时间(秒) Returns: 是否成功设置 """ try: result = await self.client.expire(key, seconds) logger.debug(f"成功设置Redis键过期时间: {key}, {seconds}秒") return result > 0 except Exception as e: logger.error(f"Redis设置过期失败 {key}: {e}") raise async def ttl(self, key: str) -> int: """获取键的剩余生存时间 Args: key: Redis键 Returns: 剩余生存时间(秒),-1表示永不过期,-2表示键不存在 """ try: return await self.client.ttl(key) except Exception as e: logger.error(f"Redis获取TTL失败 {key}: {e}") raise async def increment( self, key: str, amount: int = 1, ) -> int: """增加键的值 Args: key: Redis键 amount: 增加的数量 Returns: 增加后的值 """ try: result = await self.client.incrby(key, amount) logger.debug(f"成功增加Redis键: {key}") return result except Exception as e: logger.error(f"Redis增加失败 {key}: {e}") raise async def decrement( self, key: str, amount: int = 1, ) -> int: """减少键的值 Args: key: Redis键 amount: 减少的数量 Returns: 减少后的值 """ try: result = await self.client.decrby(key, amount) logger.debug(f"成功减少Redis键: {key}") return result except Exception as e: logger.error(f"Redis减少失败 {key}: {e}") raise async def list_push( self, key: str, values: list[Union[BaseModel, dict, str]], model_class: Optional[Type[T]] = None, ) -> int: """向列表推入元素 Args: key: Redis键 values: 要推入的值列表 model_class: 对象模型类,用于验证每个值 Returns: 推入后列表的长度 """ try: serialized_values = [] for value in values: if model_class is not None: if isinstance(value, model_class): validated_value = value else: if isinstance(value, dict): validated_value = model_class(**value) else: validated_value = value else: validated_value = value if isinstance(validated_value, BaseModel): serialized_values.append(validated_value.model_dump_json()) elif isinstance(validated_value, dict): serialized_values.append( json.dumps(validated_value, ensure_ascii=False) ) else: serialized_values.append(str(validated_value)) result: int = await self.client.rpush(key, *serialized_values) # type: ignore logger.debug(f"成功推入Redis列表: {key}") return result except Exception as e: logger.error(f"Redis列表推入失败 {key}: {e}") raise async def list_range( self, key: str, start: int = 0, end: int = -1, model_class: Optional[Type[T]] = None, ) -> list[Union[T, str]]: """获取列表范围内的元素 Args: key: Redis键 start: 开始索引 end: 结束索引 model_class: 对象模型类,用于反序列化。如果为None则返回原始字符串 Returns: 列表中指定范围的元素 """ try: data: list[Any] = await self.client.lrange(key, start, end) # type: ignore if model_class is None: return data result = [] for item in data: if isinstance(item, bytes): item = item.decode("utf-8") try: parsed = json.loads(item) result.append(model_class(**parsed)) except (json.JSONDecodeError, ValidationError): result.append(item) return result except Exception as e: logger.error(f"Redis列表获取失败 {key}: {e}") raise async def hash_set( self, key: str, mapping: dict[str, Union[BaseModel, dict, str, int]], model_class: Optional[Type[T]] = None, ) -> int: """设置哈希表字段 Args: key: Redis键 mapping: 字段值映射 model_class: 对象模型类,用于验证值 Returns: 新添加的字段数 """ try: serialized_mapping = {} for field, value in mapping.items(): if model_class is not None and not isinstance(value, (str, int, float)): if isinstance(value, dict): validated_value = model_class(**value) else: validated_value = value if isinstance(validated_value, BaseModel): serialized_mapping[field] = validated_value.model_dump_json() else: serialized_mapping[field] = str(value) else: serialized_mapping[field] = str(value) result: int = await self.client.hset(key, mapping=serialized_mapping) # type: ignore logger.debug(f"成功设置Redis哈希表: {key}") return result except Exception as e: logger.error(f"Redis哈希表设置失败 {key}: {e}") raise async def hash_get( self, key: str, field: str, model_class: Optional[Type[T]] = None, ) -> Optional[Union[T, str]]: """获取哈希表字段值 Args: key: Redis键 field: 字段名 model_class: 对象模型类,用于反序列化 Returns: 字段值,如果不存在则返回None """ try: data: Optional[Any] = await self.client.hget(key, field) # type: ignore if data is None: return None if isinstance(data, bytes): data = data.decode("utf-8") if model_class is None: return data try: parsed = json.loads(data) return model_class(**parsed) except (json.JSONDecodeError, ValidationError): return data except Exception as e: logger.error(f"Redis哈希表获取失败 {key}:{field}: {e}") raise async def hash_get_all( self, key: str, model_class: Optional[Type[T]] = None, ) -> dict[str, Union[T, str]]: """获取所有哈希表字段 Args: key: Redis键 model_class: 对象模型类,用于反序列化值 Returns: 哈希表中的所有字段值 """ try: data: dict[Any, Any] = await self.client.hgetall(key) # type: ignore if model_class is None: return data result = {} for field, value in data.items(): if isinstance(value, bytes): value = value.decode("utf-8") try: parsed = json.loads(value) result[field] = model_class(**parsed) except (json.JSONDecodeError, ValidationError): result[field] = value return result except Exception as e: logger.error(f"Redis哈希表全量获取失败 {key}: {e}") raise async def hash_delete( self, key: str, *fields: str, ) -> int: """删除哈希表字段 Args: key: Redis键 fields: 要删除的字段名 Returns: 删除的字段数 """ try: result: int = await self.client.hdel(key, *fields) # type: ignore logger.debug(f"成功删除Redis哈希表字段: {key}") return result except Exception as e: logger.error(f"Redis哈希表删除失败 {key}: {e}") raise async def get_redis_client() -> RedisClient: """获取全局Redis客户端实例 Returns: aioredis.Redis 实例 """ redis_instance = await db_manager.get_redis_client() redis_client = RedisClient(redis_instance) return redis_client