⚒️ 重大重构 LoveACE V2
引入了 mongodb 对数据库进行了一定程度的数据加密 性能改善 代码简化 统一错误模型和响应 使用 apifox 作为文档
This commit is contained in:
547
loveace/utils/redis_client.py
Normal file
547
loveace/utils/redis_client.py
Normal file
@@ -0,0 +1,547 @@
|
||||
"""
|
||||
Redis客户端工具模块
|
||||
|
||||
提供类型完整的Redis客户端包装器,支持内容验证和序列化
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Optional, Type, TypeVar, Union
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from loveace.config.logger import logger
|
||||
from loveace.database.creator import db_manager
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class RedisClient:
|
||||
"""类型完整的Redis客户端包装器
|
||||
|
||||
提供带有数据验证和序列化的Redis操作接口
|
||||
|
||||
Example:
|
||||
>>> client = RedisClient(redis_instance)
|
||||
>>> # 存储对象
|
||||
>>> await client.set_object("user:1", user_data, User)
|
||||
>>> # 获取对象
|
||||
>>> user = await client.get_object("user:1", User)
|
||||
"""
|
||||
|
||||
def __init__(self, redis_client: aioredis.Redis):
|
||||
"""初始化Redis客户端包装器
|
||||
|
||||
Args:
|
||||
redis_client: aioredis.Redis 实例
|
||||
"""
|
||||
self.client = redis_client
|
||||
|
||||
async def set_object(
|
||||
self,
|
||||
key: str,
|
||||
value: Union[BaseModel, dict, Any],
|
||||
model_class: Optional[Type[T]] = None,
|
||||
expire: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""设置对象到Redis,支持自动验证和序列化
|
||||
|
||||
Args:
|
||||
key: Redis键
|
||||
value: 要存储的值(BaseModel、dict或其他可序列化对象)
|
||||
model_class: 对象模型类,用于验证。如果提供,会先验证value
|
||||
expire: 过期时间(秒),None表示不设置过期时间
|
||||
|
||||
Returns:
|
||||
是否成功设置
|
||||
|
||||
Raises:
|
||||
ValidationError: 当model_class验证失败时
|
||||
TypeError: 当value无法序列化时
|
||||
"""
|
||||
try:
|
||||
# 验证数据
|
||||
if model_class is not None:
|
||||
if isinstance(value, model_class):
|
||||
validated_value = value
|
||||
else:
|
||||
validated_value = model_class(
|
||||
**value if isinstance(value, dict) else value.dict()
|
||||
)
|
||||
else:
|
||||
validated_value = value
|
||||
|
||||
# 序列化
|
||||
if isinstance(validated_value, BaseModel):
|
||||
data = validated_value.model_dump_json()
|
||||
elif isinstance(validated_value, dict):
|
||||
data = json.dumps(validated_value, ensure_ascii=False)
|
||||
else:
|
||||
data = json.dumps(validated_value, ensure_ascii=False)
|
||||
|
||||
# 存储到Redis
|
||||
if expire:
|
||||
await self.client.setex(key, expire, data)
|
||||
else:
|
||||
await self.client.set(key, data)
|
||||
|
||||
logger.debug(f"成功存储Redis键: {key}")
|
||||
return True
|
||||
|
||||
except ValidationError as e:
|
||||
logger.error(f"Redis对象验证失败 {key}: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Redis存储失败 {key}: {e}")
|
||||
raise
|
||||
|
||||
async def get_object(
|
||||
self,
|
||||
key: str,
|
||||
model_class: Type[T],
|
||||
) -> Optional[T]:
|
||||
"""从Redis获取对象,并通过指定的模型类进行验证
|
||||
|
||||
Args:
|
||||
key: Redis键
|
||||
model_class: 对象模型类,用于反序列化和验证
|
||||
|
||||
Returns:
|
||||
反序列化并验证后的对象,如果键不存在则返回None
|
||||
|
||||
Raises:
|
||||
ValidationError: 当数据验证失败时
|
||||
"""
|
||||
try:
|
||||
data = await self.client.get(key)
|
||||
|
||||
if data is None:
|
||||
logger.debug(f"Redis键不存在: {key}")
|
||||
return None
|
||||
|
||||
# 反序列化
|
||||
if isinstance(data, bytes):
|
||||
data = data.decode("utf-8")
|
||||
|
||||
parsed_data = json.loads(data)
|
||||
|
||||
# 验证并创建模型实例
|
||||
validated_value = model_class(**parsed_data)
|
||||
logger.debug(f"成功获取并验证Redis键: {key}")
|
||||
return validated_value
|
||||
|
||||
except ValidationError as e:
|
||||
logger.error(f"Redis对象验证失败 {key}: {e}")
|
||||
raise
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Redis JSON解析失败 {key}: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Redis获取失败 {key}: {e}")
|
||||
raise
|
||||
|
||||
async def get_object_safe(
|
||||
self,
|
||||
key: str,
|
||||
model_class: Type[T],
|
||||
default: Optional[T] = None,
|
||||
) -> Optional[T]:
|
||||
"""安全地从Redis获取对象,验证失败时返回默认值
|
||||
|
||||
Args:
|
||||
key: Redis键
|
||||
model_class: 对象模型类,用于反序列化和验证
|
||||
default: 验证失败时的默认返回值
|
||||
|
||||
Returns:
|
||||
反序列化并验证后的对象,验证失败返回default
|
||||
"""
|
||||
try:
|
||||
return await self.get_object(key, model_class)
|
||||
except (ValidationError, json.JSONDecodeError, Exception) as e:
|
||||
logger.warning(f"Redis安全获取失败,返回默认值 {key}: {e}")
|
||||
return default
|
||||
|
||||
async def set_raw(
|
||||
self,
|
||||
key: str,
|
||||
value: Union[str, bytes],
|
||||
expire: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""设置原始字符串值到Redis
|
||||
|
||||
Args:
|
||||
key: Redis键
|
||||
value: 要存储的值(字符串或字节)
|
||||
expire: 过期时间(秒)
|
||||
|
||||
Returns:
|
||||
是否成功设置
|
||||
"""
|
||||
try:
|
||||
if expire:
|
||||
await self.client.setex(key, expire, value)
|
||||
else:
|
||||
await self.client.set(key, value)
|
||||
logger.debug(f"成功存储原始值到Redis: {key}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Redis原始值存储失败 {key}: {e}")
|
||||
raise
|
||||
|
||||
async def get_raw(self, key: str) -> Optional[Union[str, bytes]]:
|
||||
"""获取原始字符串值
|
||||
|
||||
Args:
|
||||
key: Redis键
|
||||
|
||||
Returns:
|
||||
存储的值,如果键不存在则返回None
|
||||
"""
|
||||
try:
|
||||
data = await self.client.get(key)
|
||||
if data is None:
|
||||
logger.debug(f"Redis键不存在: {key}")
|
||||
return None
|
||||
logger.debug(f"成功获取原始值: {key}")
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.error(f"Redis获取失败 {key}: {e}")
|
||||
raise
|
||||
|
||||
async def delete(self, key: str) -> int:
|
||||
"""删除Redis键
|
||||
|
||||
Args:
|
||||
key: 要删除的键
|
||||
|
||||
Returns:
|
||||
删除的键数量
|
||||
"""
|
||||
try:
|
||||
result = await self.client.delete(key)
|
||||
logger.debug(f"成功删除Redis键: {key}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Redis删除失败 {key}: {e}")
|
||||
raise
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""检查键是否存在
|
||||
|
||||
Args:
|
||||
key: 要检查的键
|
||||
|
||||
Returns:
|
||||
键是否存在
|
||||
"""
|
||||
try:
|
||||
return await self.client.exists(key) > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Redis检查失败 {key}: {e}")
|
||||
raise
|
||||
|
||||
async def expire(self, key: str, seconds: int) -> bool:
|
||||
"""设置键的过期时间
|
||||
|
||||
Args:
|
||||
key: Redis键
|
||||
seconds: 过期时间(秒)
|
||||
|
||||
Returns:
|
||||
是否成功设置
|
||||
"""
|
||||
try:
|
||||
result = await self.client.expire(key, seconds)
|
||||
logger.debug(f"成功设置Redis键过期时间: {key}, {seconds}秒")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Redis设置过期失败 {key}: {e}")
|
||||
raise
|
||||
|
||||
async def ttl(self, key: str) -> int:
|
||||
"""获取键的剩余生存时间
|
||||
|
||||
Args:
|
||||
key: Redis键
|
||||
|
||||
Returns:
|
||||
剩余生存时间(秒),-1表示永不过期,-2表示键不存在
|
||||
"""
|
||||
try:
|
||||
return await self.client.ttl(key)
|
||||
except Exception as e:
|
||||
logger.error(f"Redis获取TTL失败 {key}: {e}")
|
||||
raise
|
||||
|
||||
async def increment(
|
||||
self,
|
||||
key: str,
|
||||
amount: int = 1,
|
||||
) -> int:
|
||||
"""增加键的值
|
||||
|
||||
Args:
|
||||
key: Redis键
|
||||
amount: 增加的数量
|
||||
|
||||
Returns:
|
||||
增加后的值
|
||||
"""
|
||||
try:
|
||||
result = await self.client.incrby(key, amount)
|
||||
logger.debug(f"成功增加Redis键: {key}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Redis增加失败 {key}: {e}")
|
||||
raise
|
||||
|
||||
async def decrement(
|
||||
self,
|
||||
key: str,
|
||||
amount: int = 1,
|
||||
) -> int:
|
||||
"""减少键的值
|
||||
|
||||
Args:
|
||||
key: Redis键
|
||||
amount: 减少的数量
|
||||
|
||||
Returns:
|
||||
减少后的值
|
||||
"""
|
||||
try:
|
||||
result = await self.client.decrby(key, amount)
|
||||
logger.debug(f"成功减少Redis键: {key}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Redis减少失败 {key}: {e}")
|
||||
raise
|
||||
|
||||
async def list_push(
|
||||
self,
|
||||
key: str,
|
||||
values: list[Union[BaseModel, dict, str]],
|
||||
model_class: Optional[Type[T]] = None,
|
||||
) -> int:
|
||||
"""向列表推入元素
|
||||
|
||||
Args:
|
||||
key: Redis键
|
||||
values: 要推入的值列表
|
||||
model_class: 对象模型类,用于验证每个值
|
||||
|
||||
Returns:
|
||||
推入后列表的长度
|
||||
"""
|
||||
try:
|
||||
serialized_values = []
|
||||
for value in values:
|
||||
if model_class is not None:
|
||||
if isinstance(value, model_class):
|
||||
validated_value = value
|
||||
else:
|
||||
if isinstance(value, dict):
|
||||
validated_value = model_class(**value)
|
||||
else:
|
||||
validated_value = value
|
||||
else:
|
||||
validated_value = value
|
||||
|
||||
if isinstance(validated_value, BaseModel):
|
||||
serialized_values.append(validated_value.model_dump_json())
|
||||
elif isinstance(validated_value, dict):
|
||||
serialized_values.append(
|
||||
json.dumps(validated_value, ensure_ascii=False)
|
||||
)
|
||||
else:
|
||||
serialized_values.append(str(validated_value))
|
||||
|
||||
result: int = await self.client.rpush(key, *serialized_values) # type: ignore
|
||||
logger.debug(f"成功推入Redis列表: {key}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Redis列表推入失败 {key}: {e}")
|
||||
raise
|
||||
|
||||
async def list_range(
|
||||
self,
|
||||
key: str,
|
||||
start: int = 0,
|
||||
end: int = -1,
|
||||
model_class: Optional[Type[T]] = None,
|
||||
) -> list[Union[T, str]]:
|
||||
"""获取列表范围内的元素
|
||||
|
||||
Args:
|
||||
key: Redis键
|
||||
start: 开始索引
|
||||
end: 结束索引
|
||||
model_class: 对象模型类,用于反序列化。如果为None则返回原始字符串
|
||||
|
||||
Returns:
|
||||
列表中指定范围的元素
|
||||
"""
|
||||
try:
|
||||
data: list[Any] = await self.client.lrange(key, start, end) # type: ignore
|
||||
|
||||
if model_class is None:
|
||||
return data
|
||||
|
||||
result = []
|
||||
for item in data:
|
||||
if isinstance(item, bytes):
|
||||
item = item.decode("utf-8")
|
||||
try:
|
||||
parsed = json.loads(item)
|
||||
result.append(model_class(**parsed))
|
||||
except (json.JSONDecodeError, ValidationError):
|
||||
result.append(item)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Redis列表获取失败 {key}: {e}")
|
||||
raise
|
||||
|
||||
async def hash_set(
|
||||
self,
|
||||
key: str,
|
||||
mapping: dict[str, Union[BaseModel, dict, str, int]],
|
||||
model_class: Optional[Type[T]] = None,
|
||||
) -> int:
|
||||
"""设置哈希表字段
|
||||
|
||||
Args:
|
||||
key: Redis键
|
||||
mapping: 字段值映射
|
||||
model_class: 对象模型类,用于验证值
|
||||
|
||||
Returns:
|
||||
新添加的字段数
|
||||
"""
|
||||
try:
|
||||
serialized_mapping = {}
|
||||
for field, value in mapping.items():
|
||||
if model_class is not None and not isinstance(value, (str, int, float)):
|
||||
if isinstance(value, dict):
|
||||
validated_value = model_class(**value)
|
||||
else:
|
||||
validated_value = value
|
||||
if isinstance(validated_value, BaseModel):
|
||||
serialized_mapping[field] = validated_value.model_dump_json()
|
||||
else:
|
||||
serialized_mapping[field] = str(value)
|
||||
else:
|
||||
serialized_mapping[field] = str(value)
|
||||
|
||||
result: int = await self.client.hset(key, mapping=serialized_mapping) # type: ignore
|
||||
logger.debug(f"成功设置Redis哈希表: {key}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Redis哈希表设置失败 {key}: {e}")
|
||||
raise
|
||||
|
||||
async def hash_get(
|
||||
self,
|
||||
key: str,
|
||||
field: str,
|
||||
model_class: Optional[Type[T]] = None,
|
||||
) -> Optional[Union[T, str]]:
|
||||
"""获取哈希表字段值
|
||||
|
||||
Args:
|
||||
key: Redis键
|
||||
field: 字段名
|
||||
model_class: 对象模型类,用于反序列化
|
||||
|
||||
Returns:
|
||||
字段值,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
data: Optional[Any] = await self.client.hget(key, field) # type: ignore
|
||||
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
if isinstance(data, bytes):
|
||||
data = data.decode("utf-8")
|
||||
|
||||
if model_class is None:
|
||||
return data
|
||||
|
||||
try:
|
||||
parsed = json.loads(data)
|
||||
return model_class(**parsed)
|
||||
except (json.JSONDecodeError, ValidationError):
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.error(f"Redis哈希表获取失败 {key}:{field}: {e}")
|
||||
raise
|
||||
|
||||
async def hash_get_all(
|
||||
self,
|
||||
key: str,
|
||||
model_class: Optional[Type[T]] = None,
|
||||
) -> dict[str, Union[T, str]]:
|
||||
"""获取所有哈希表字段
|
||||
|
||||
Args:
|
||||
key: Redis键
|
||||
model_class: 对象模型类,用于反序列化值
|
||||
|
||||
Returns:
|
||||
哈希表中的所有字段值
|
||||
"""
|
||||
try:
|
||||
data: dict[Any, Any] = await self.client.hgetall(key) # type: ignore
|
||||
|
||||
if model_class is None:
|
||||
return data
|
||||
|
||||
result = {}
|
||||
for field, value in data.items():
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("utf-8")
|
||||
try:
|
||||
parsed = json.loads(value)
|
||||
result[field] = model_class(**parsed)
|
||||
except (json.JSONDecodeError, ValidationError):
|
||||
result[field] = value
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Redis哈希表全量获取失败 {key}: {e}")
|
||||
raise
|
||||
|
||||
async def hash_delete(
|
||||
self,
|
||||
key: str,
|
||||
*fields: str,
|
||||
) -> int:
|
||||
"""删除哈希表字段
|
||||
|
||||
Args:
|
||||
key: Redis键
|
||||
fields: 要删除的字段名
|
||||
|
||||
Returns:
|
||||
删除的字段数
|
||||
"""
|
||||
try:
|
||||
result: int = await self.client.hdel(key, *fields) # type: ignore
|
||||
logger.debug(f"成功删除Redis哈希表字段: {key}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Redis哈希表删除失败 {key}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_redis_client() -> RedisClient:
|
||||
"""获取全局Redis客户端实例
|
||||
|
||||
Returns:
|
||||
aioredis.Redis 实例
|
||||
"""
|
||||
redis_instance = await db_manager.get_redis_client()
|
||||
redis_client = RedisClient(redis_instance)
|
||||
return redis_client
|
||||
107
loveace/utils/richuru_hook.py
Normal file
107
loveace/utils/richuru_hook.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
from datetime import datetime
|
||||
from logging import LogRecord
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterable, List, Optional, Union
|
||||
|
||||
from loguru import logger
|
||||
from rich.console import Console, ConsoleRenderable
|
||||
from rich.text import Text
|
||||
from rich.theme import Theme
|
||||
from rich.traceback import Traceback
|
||||
from richuru import ExceptionHook, LoguruHandler, LoguruRichHandler, _loguru_exc_hook
|
||||
|
||||
|
||||
class HookedLoguruRichHandler(LoguruRichHandler):
|
||||
"""
|
||||
A hooked version of LoguruRichHandler to fix some issues.
|
||||
"""
|
||||
|
||||
def render(
|
||||
self,
|
||||
*,
|
||||
record: LogRecord,
|
||||
traceback: Optional[Traceback],
|
||||
message_renderable: "ConsoleRenderable",
|
||||
) -> "ConsoleRenderable":
|
||||
"""Render log for display.
|
||||
|
||||
Args:
|
||||
record (LogRecord): logging Record.
|
||||
traceback (Optional[Traceback]): Traceback instance or None for no Traceback.
|
||||
message_renderable (ConsoleRenderable): Renderable (typically Text) containing log message contents.
|
||||
|
||||
Returns:
|
||||
ConsoleRenderable: Renderable to display log.
|
||||
"""
|
||||
current_path = Path(os.getcwd())
|
||||
path = Path(record.pathname)
|
||||
try:
|
||||
path = path.relative_to(current_path)
|
||||
if sys.platform == "win32":
|
||||
path = str(path).replace("\\", "/")
|
||||
except ValueError:
|
||||
path = Path(record.pathname).name
|
||||
path = str(path)
|
||||
level = self.get_level_text(record)
|
||||
time_format = None if self.formatter is None else self.formatter.datefmt
|
||||
log_time = datetime.fromtimestamp(record.created)
|
||||
|
||||
log_renderable = self._log_render(
|
||||
self.console,
|
||||
[message_renderable] if not traceback else [message_renderable, traceback],
|
||||
log_time=log_time,
|
||||
time_format=time_format,
|
||||
level=level,
|
||||
path=path,
|
||||
line_no=record.lineno,
|
||||
link_path=record.pathname if self.enable_link_path else None,
|
||||
)
|
||||
return log_renderable
|
||||
|
||||
|
||||
def install(
|
||||
rich_console: Optional[Console] = None,
|
||||
exc_hook: Optional[ExceptionHook] = _loguru_exc_hook,
|
||||
rich_traceback: bool = True,
|
||||
tb_ctx_lines: int = 3,
|
||||
tb_theme: Optional[str] = None,
|
||||
tb_suppress: Iterable[Union[str, types.ModuleType]] = (),
|
||||
time_format: Union[str, Callable[[datetime], Text]] = "[%x %X]",
|
||||
keywords: Optional[List[str]] = None,
|
||||
level: Union[int, str] = 20,
|
||||
) -> None:
|
||||
"""Install Rich logging and Loguru exception hook"""
|
||||
logging.basicConfig(handlers=[LoguruHandler()], level=0)
|
||||
logger.configure(
|
||||
handlers=[
|
||||
{
|
||||
"sink": HookedLoguruRichHandler(
|
||||
console=rich_console
|
||||
or Console(
|
||||
theme=Theme(
|
||||
{
|
||||
"logging.level.success": "green",
|
||||
"logging.level.trace": "bright_black",
|
||||
}
|
||||
)
|
||||
),
|
||||
rich_tracebacks=rich_traceback,
|
||||
tracebacks_show_locals=True,
|
||||
tracebacks_suppress=tb_suppress,
|
||||
tracebacks_extra_lines=tb_ctx_lines,
|
||||
tracebacks_theme=tb_theme,
|
||||
show_time=False,
|
||||
log_time_format=time_format,
|
||||
keywords=keywords,
|
||||
),
|
||||
"format": (lambda _: "{message}") if rich_traceback else "{message}",
|
||||
"level": level,
|
||||
}
|
||||
]
|
||||
)
|
||||
if exc_hook is not None:
|
||||
sys.excepthook = exc_hook
|
||||
332
loveace/utils/rsa.py
Normal file
332
loveace/utils/rsa.py
Normal file
@@ -0,0 +1,332 @@
|
||||
import base64
|
||||
import os
|
||||
from contextvars import ContextVar
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding, rsa
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCMSIV
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Prompt
|
||||
|
||||
from loveace.config.manager import config_manager
|
||||
|
||||
console = Console()
|
||||
|
||||
rsa_context: ContextVar[Dict[str, "RSAUtils"]] = ContextVar("rsa_context")
|
||||
|
||||
|
||||
class RSAUtils:
|
||||
"""RSA 工具类,支持 AES-GCM-SIV 加密的密钥保护"""
|
||||
|
||||
private_key_path: str
|
||||
private_key: RSAPrivateKey
|
||||
public_key: RSAPublicKey
|
||||
|
||||
def __init__(self, private_key_path: str | None = None):
|
||||
"""初始化 RSAUtils 类
|
||||
|
||||
Args:
|
||||
private_key_path (str): 私钥文件路径
|
||||
"""
|
||||
settings = config_manager.get_settings()
|
||||
self.private_key_path = str(
|
||||
Path(settings.app.rsa_protect_key_path).joinpath(
|
||||
Path(
|
||||
private_key_path
|
||||
or config_manager.get_settings().app.rsa_private_key_path
|
||||
).name
|
||||
)
|
||||
)
|
||||
# 转换路径扩展名为 .hex
|
||||
self.private_key_path = str(self.private_key_path).replace(".pem", ".hex")
|
||||
self.load_keys()
|
||||
|
||||
def _derive_key_from_password(
|
||||
self, password: str, salt: bytes | None = None
|
||||
) -> tuple[bytes, bytes]:
|
||||
"""从密码派生 AES 密钥
|
||||
|
||||
Args:
|
||||
password (str): 用户输入的密码
|
||||
salt (bytes): 盐值,如果为 None 则生成新的
|
||||
|
||||
Returns:
|
||||
tuple[bytes, bytes]: (派生密钥, 盐值)
|
||||
"""
|
||||
if salt is None:
|
||||
salt = os.urandom(16)
|
||||
|
||||
kdf_obj = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=16, # AES-128 需要 16 字节密钥
|
||||
salt=salt,
|
||||
iterations=100000,
|
||||
)
|
||||
key = kdf_obj.derive(password.encode("utf-8"))
|
||||
return key, salt
|
||||
|
||||
def load_keys(self):
|
||||
"""加载密钥对(从加密的 AES 文件中)"""
|
||||
path = Path(self.private_key_path)
|
||||
console.print(
|
||||
Panel(
|
||||
f"[bold cyan]正在操作密钥文件[/bold cyan]\n"
|
||||
f"[cyan]文件路径:{self.private_key_path}[/cyan]",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
if not path.exists():
|
||||
console.print(
|
||||
Panel(
|
||||
"[bold yellow]RSA 密钥对不存在,将为您生成新的密钥对[/bold yellow]",
|
||||
title="[bold blue]密钥生成[/bold blue]",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
self.generate_keys()
|
||||
else:
|
||||
self._load_encrypted_key()
|
||||
|
||||
def _load_encrypted_key(self):
|
||||
"""从加密的 .hex 文件加载密钥"""
|
||||
console.print(
|
||||
Panel(
|
||||
f"[bold cyan]检测到本地 RSA 私钥文件[/bold cyan]\n"
|
||||
f"[cyan]文件路径:{self.private_key_path}[/cyan]",
|
||||
title="[bold blue]密钥加载[/bold blue]",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
|
||||
console.print(
|
||||
"[bold yellow]该密钥文件受密码保护,需要您输入密码来解密[/bold yellow]"
|
||||
)
|
||||
password = Prompt.ask(
|
||||
"[bold]请输入 RSA 私钥密码[/bold]", password=True, console=console
|
||||
)
|
||||
|
||||
with open(self.private_key_path, "rb") as key_file:
|
||||
encrypted_data = key_file.read()
|
||||
|
||||
# 解析加密数据:salt(16) + nonce(12) + ciphertext
|
||||
salt = encrypted_data[:16]
|
||||
nonce = encrypted_data[16:28]
|
||||
ciphertext = encrypted_data[28:]
|
||||
|
||||
# 派生密钥
|
||||
key, _ = self._derive_key_from_password(password, salt)
|
||||
|
||||
# 使用 AES-GCM-SIV 解密
|
||||
try:
|
||||
aesgcmsiv = AESGCMSIV(key)
|
||||
plaintext = aesgcmsiv.decrypt(nonce, ciphertext, None)
|
||||
console.print("[bold green]✓ 私钥密码验证成功[/bold green]")
|
||||
except Exception as e:
|
||||
console.print(
|
||||
Panel(
|
||||
"[bold red]✗ 私钥密码错误或密钥文件已损坏[/bold red]",
|
||||
title="[bold red]错误[/bold red]",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
raise ValueError("Invalid password or corrupted key file") from e
|
||||
|
||||
# 加载 PEM 格式的私钥
|
||||
try:
|
||||
pk = serialization.load_pem_private_key(
|
||||
plaintext, password=None, backend=default_backend()
|
||||
)
|
||||
if isinstance(pk, RSAPrivateKey):
|
||||
self.private_key = pk
|
||||
else:
|
||||
raise ValueError("Loaded key is not an RSA private key")
|
||||
except Exception:
|
||||
console.print(
|
||||
Panel(
|
||||
"[bold red]✗ 密钥格式错误[/bold red]",
|
||||
title="[bold red]错误[/bold red]",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
raise
|
||||
|
||||
self.public_key = self.private_key.public_key()
|
||||
|
||||
def generate_keys(self, key_size: int = 2048):
|
||||
"""生成 RSA 密钥对并使用 AES 加密保存到文件
|
||||
|
||||
Args:
|
||||
key_size (int): 密钥大小,默认2048位
|
||||
"""
|
||||
path = Path(self.private_key_path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 提示用户设置密码
|
||||
console.print(
|
||||
Panel(
|
||||
"[bold cyan]请设置 RSA 私钥密码(用于保护密钥文件)[/bold cyan]",
|
||||
title="[bold blue]密钥保护[/bold blue]",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
password = Prompt.ask("[bold]请输入密码[/bold]", password=True, console=console)
|
||||
password_confirm = Prompt.ask(
|
||||
"[bold]请确认密码[/bold]", password=True, console=console
|
||||
)
|
||||
|
||||
if password != password_confirm:
|
||||
console.print(
|
||||
Panel(
|
||||
"[bold red]✗ 两次输入的密码不一致[/bold red]",
|
||||
title="[bold red]错误[/bold red]",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
raise ValueError("Passwords do not match")
|
||||
|
||||
# 生成 RSA 密钥对
|
||||
console.print("[bold cyan]正在生成 RSA 密钥对...[/bold cyan]")
|
||||
private_key = rsa.generate_private_key(
|
||||
public_exponent=65537, key_size=key_size, backend=default_backend()
|
||||
)
|
||||
public_key = private_key.public_key()
|
||||
|
||||
# 将私钥序列化为 PEM 格式
|
||||
pem_private = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
|
||||
pem_public = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
# 使用 AES-GCM-SIV 加密私钥
|
||||
key, salt = self._derive_key_from_password(password)
|
||||
aesgcmsiv = AESGCMSIV(key)
|
||||
nonce = os.urandom(12)
|
||||
ciphertext = aesgcmsiv.encrypt(nonce, pem_private, None)
|
||||
|
||||
# 保存加密的私钥:salt + nonce + ciphertext
|
||||
with open(self.private_key_path, "wb") as private_file:
|
||||
private_file.write(salt + nonce + ciphertext)
|
||||
|
||||
# 保存公钥(不加密)
|
||||
public_key_path = self.private_key_path.replace(".hex", "_public.pem")
|
||||
with open(public_key_path, "wb") as public_file:
|
||||
public_file.write(pem_public)
|
||||
|
||||
self.private_key = private_key
|
||||
self.public_key = public_key
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"[bold green]✓ RSA 密钥对生成成功[/bold green]\n"
|
||||
f"[cyan]私钥路径:[/cyan]{self.private_key_path}\n"
|
||||
f"[cyan]公钥路径:[/cyan]{public_key_path}",
|
||||
title="[bold blue]完成[/bold blue]",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
|
||||
def encrypt(self, plaintext: str) -> str:
|
||||
"""使用公钥加密数据
|
||||
|
||||
Args:
|
||||
plaintext (str): 明文字符串
|
||||
|
||||
Returns:
|
||||
str: Base64 编码的密文字符串
|
||||
"""
|
||||
ciphertext = self.public_key.encrypt(
|
||||
plaintext.encode("utf-8"),
|
||||
padding.PKCS1v15(),
|
||||
)
|
||||
return base64.b64encode(ciphertext).decode("utf-8")
|
||||
|
||||
def decrypt(self, b64_ciphertext: str) -> str:
|
||||
"""使用私钥解密数据
|
||||
|
||||
Args:
|
||||
b64_ciphertext (str): Base64 编码的密文字符串
|
||||
|
||||
Returns:
|
||||
str: 解密后的明文字符串
|
||||
"""
|
||||
ciphertext = base64.b64decode(b64_ciphertext)
|
||||
plaintext = self.private_key.decrypt(
|
||||
ciphertext,
|
||||
padding.PKCS1v15(),
|
||||
)
|
||||
return plaintext.decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def encrypt_file_with_aes(
|
||||
plaintext: bytes, password: str | None = None
|
||||
) -> tuple[bytes, str]:
|
||||
"""使用 AES-GCM-SIV 和密码加密数据
|
||||
|
||||
Args:
|
||||
plaintext (bytes): 明文数据
|
||||
password (str): 密码,如果为 None 则生成随机密钥
|
||||
|
||||
Returns:
|
||||
tuple[bytes, str]: (加密数据, 密钥的十六进制字符串)
|
||||
"""
|
||||
if password is None:
|
||||
# 生成随机密钥
|
||||
key = AESGCMSIV.generate_key(bit_length=128)
|
||||
aesgcmsiv = AESGCMSIV(key)
|
||||
nonce = os.urandom(12)
|
||||
ciphertext = aesgcmsiv.encrypt(nonce, plaintext, None)
|
||||
encrypted_data = key + nonce + ciphertext
|
||||
else:
|
||||
# 从密码派生密钥
|
||||
salt = os.urandom(16)
|
||||
kdf_obj = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=16,
|
||||
salt=salt,
|
||||
iterations=100000,
|
||||
)
|
||||
key = kdf_obj.derive(password.encode("utf-8"))
|
||||
aesgcmsiv = AESGCMSIV(key)
|
||||
nonce = os.urandom(12)
|
||||
ciphertext = aesgcmsiv.encrypt(nonce, plaintext, None)
|
||||
encrypted_data = salt + nonce + ciphertext
|
||||
|
||||
key_hex = key.hex()
|
||||
return encrypted_data, key_hex
|
||||
|
||||
@staticmethod
|
||||
def get_or_create_rsa_utils(private_key_path: str | None = None) -> "RSAUtils":
|
||||
"""
|
||||
获取或创建 RSAUtils 实例
|
||||
Args:
|
||||
private_key_path (str | None): 私钥文件路径,如果为 None 则使用配置中的默认路径
|
||||
"""
|
||||
private_key_path = (
|
||||
private_key_path or config_manager.get_settings().app.rsa_private_key_path
|
||||
)
|
||||
try:
|
||||
rsa_utils_dict = rsa_context.get()
|
||||
if private_key_path in rsa_utils_dict:
|
||||
return rsa_utils_dict[private_key_path]
|
||||
else:
|
||||
rsa_utils = RSAUtils(private_key_path)
|
||||
rsa_utils_dict[private_key_path] = rsa_utils
|
||||
rsa_context.set(rsa_utils_dict)
|
||||
return rsa_utils
|
||||
except LookupError:
|
||||
rsa_utils = RSAUtils(private_key_path)
|
||||
rsa_utils_dict = {private_key_path: rsa_utils}
|
||||
rsa_context.set(rsa_utils_dict)
|
||||
return rsa_utils
|
||||
Reference in New Issue
Block a user