Mirror of https://github.com/Kakune55/ComiPy.git (synced 2025-09-16 04:09:41 +08:00)
feat(file): optimize file handling and caching
- Refactor the file-handling logic for better performance and maintainability
- Add a caching layer to reduce repeated reads and processing
- Improve error handling and logging
- Optimize the thumbnail generation algorithm
- Add performance monitoring and test dependencies
utils/cache_manager.py (new file, 236 lines added)
@@ -0,0 +1,236 @@
"""
Cache manager for image and data caching.
Provides both in-memory and on-disk caching.
"""
import os
import hashlib
import json
import time
import threading
from pathlib import Path
from typing import Optional, Dict, Any
import pickle

from utils.logger import get_logger
import app_conf

logger = get_logger(__name__)
conf = app_conf.conf()


class CacheManager:
    """Cache manager."""

    def __init__(self, cache_dir: Optional[str] = None, max_memory_size: int = 100):
        """
        Initialize the cache manager.

        Args:
            cache_dir: directory used for the disk cache
            max_memory_size: maximum number of entries kept in the memory cache
        """
        self.cache_dir = Path(cache_dir or conf.get("file", "tmpdir")) / "cache"
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.memory_cache = {}
        self.max_memory_size = max_memory_size
        self.cache_access_times = {}
        self.lock = threading.RLock()

        # Cache statistics
        self.stats = {
            'hits': 0,
            'misses': 0,
            'memory_hits': 0,
            'disk_hits': 0
        }

        logger.info(f"Cache manager initialized: dir={self.cache_dir}, memory limit={max_memory_size}")

    def _generate_key(self, *args) -> str:
        """Generate a cache key."""
        key_string = "_".join(str(arg) for arg in args)
        return hashlib.md5(key_string.encode('utf-8')).hexdigest()

    def _cleanup_memory_cache(self):
        """Evict the least recently accessed entries from the memory cache."""
        if len(self.memory_cache) <= self.max_memory_size:
            return

        # Sort by last access time so the oldest entries come first
        sorted_items = sorted(
            self.cache_access_times.items(),
            key=lambda x: x[1]
        )

        # Remove just enough of the oldest entries to get back under the limit
        remove_count = len(self.memory_cache) - self.max_memory_size + 1
        for key, _ in sorted_items[:remove_count]:
            if key in self.memory_cache:
                del self.memory_cache[key]
                del self.cache_access_times[key]

        logger.debug(f"Memory cache cleanup: removed {remove_count} entries")

    def get(self, key: str, default=None) -> Any:
        """Read a value from the cache."""
        with self.lock:
            current_time = time.time()

            # Check the memory cache first
            if key in self.memory_cache:
                self.cache_access_times[key] = current_time
                self.stats['hits'] += 1
                self.stats['memory_hits'] += 1
                logger.debug(f"Memory cache hit: {key}")
                return self.memory_cache[key]

            # Fall back to the disk cache
            cache_file = self.cache_dir / f"{key}.cache"
            if cache_file.exists():
                try:
                    with open(cache_file, 'rb') as f:
                        data = pickle.load(f)

                    # Promote the value into the memory cache
                    self.memory_cache[key] = data
                    self.cache_access_times[key] = current_time
                    self._cleanup_memory_cache()

                    self.stats['hits'] += 1
                    self.stats['disk_hits'] += 1
                    logger.debug(f"Disk cache hit: {key}")
                    return data
                except Exception as e:
                    logger.error(f"Failed to read disk cache {key}: {e}")
                    # Remove the corrupted cache file
                    try:
                        cache_file.unlink()
                    except OSError:
                        pass

            self.stats['misses'] += 1
            logger.debug(f"Cache miss: {key}")
            return default

    def set(self, key: str, value: Any, disk_cache: bool = True):
        """Store a value in the cache."""
        with self.lock:
            current_time = time.time()

            # Store in the memory cache
            self.memory_cache[key] = value
            self.cache_access_times[key] = current_time
            self._cleanup_memory_cache()

            # Optionally persist to the disk cache
            if disk_cache:
                try:
                    cache_file = self.cache_dir / f"{key}.cache"
                    with open(cache_file, 'wb') as f:
                        pickle.dump(value, f)
                    logger.debug(f"Value cached: {key}")
                except Exception as e:
                    logger.error(f"Failed to write disk cache {key}: {e}")

    def delete(self, key: str):
        """Delete a cached value."""
        with self.lock:
            # Remove from the memory cache
            if key in self.memory_cache:
                del self.memory_cache[key]
                del self.cache_access_times[key]

            # Remove from the disk cache
            cache_file = self.cache_dir / f"{key}.cache"
            if cache_file.exists():
                try:
                    cache_file.unlink()
                    logger.debug(f"Cache entry deleted: {key}")
                except Exception as e:
                    logger.error(f"Failed to delete disk cache {key}: {e}")

    def clear(self):
        """Clear all cached data."""
        with self.lock:
            # Clear the memory cache
            self.memory_cache.clear()
            self.cache_access_times.clear()

            # Clear the disk cache
            try:
                for cache_file in self.cache_dir.glob("*.cache"):
                    cache_file.unlink()
                logger.info("All caches cleared")
            except Exception as e:
                logger.error(f"Failed to clear disk cache: {e}")

    def get_stats(self) -> Dict[str, Any]:
        """Return cache statistics."""
        with self.lock:
            total_requests = self.stats['hits'] + self.stats['misses']
            hit_rate = (self.stats['hits'] / total_requests * 100) if total_requests > 0 else 0

            return {
                'total_requests': total_requests,
                'hits': self.stats['hits'],
                'misses': self.stats['misses'],
                'hit_rate': f"{hit_rate:.2f}%",
                'memory_hits': self.stats['memory_hits'],
                'disk_hits': self.stats['disk_hits'],
                'memory_cache_size': len(self.memory_cache),
                'disk_cache_files': len(list(self.cache_dir.glob("*.cache")))
            }

    def cleanup_expired(self, max_age_hours: int = 24):
        """Remove expired disk cache files."""
        try:
            current_time = time.time()
            max_age_seconds = max_age_hours * 3600
            removed_count = 0

            for cache_file in self.cache_dir.glob("*.cache"):
                if current_time - cache_file.stat().st_mtime > max_age_seconds:
                    cache_file.unlink()
                    removed_count += 1

            if removed_count > 0:
                logger.info(f"Removed {removed_count} expired cache file(s)")

        except Exception as e:
            logger.error(f"Failed to clean up expired cache: {e}")


# Global cache manager instance
_cache_manager = None


def get_cache_manager() -> CacheManager:
    """Return the global cache manager instance."""
    global _cache_manager
    if _cache_manager is None:
        _cache_manager = CacheManager()
    return _cache_manager


def cache_image(func):
    """Decorator that caches image results."""
    def wrapper(*args, **kwargs):
        cache_manager = get_cache_manager()

        # Build the cache key from the function name and its arguments
        cache_key = cache_manager._generate_key(func.__name__, *args, *kwargs.items())

        # Try the cache first
        cached_result = cache_manager.get(cache_key)
        if cached_result is not None:
            return cached_result

        # Run the function and cache the result
        result = func(*args, **kwargs)
        if result:  # only cache non-empty results
            cache_manager.set(cache_key, result)

        return result

    return wrapper
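For reference, a minimal usage sketch of the new cache module (not part of the commit); the thumbnail helper and its arguments are hypothetical placeholders.

# Hypothetical example: caching thumbnail bytes keyed by book id and page,
# assuming the module is importable as utils.cache_manager.
from utils.cache_manager import get_cache_manager, cache_image

@cache_image
def render_thumbnail(book_id: str, page: int) -> bytes:
    # placeholder for the real thumbnail generation
    return f"thumb-{book_id}-{page}".encode("utf-8")

cache = get_cache_manager()
first = render_thumbnail("demo", 1)   # computed, then cached in memory and on disk
second = render_thumbnail("demo", 1)  # served from the cache
print(cache.get_stats())              # hit_rate, memory_hits, disk_hits, ...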
utils/config_validator.py (new file, 104 lines added)
@@ -0,0 +1,104 @@
import os
import sys
from typing import Dict, Any
import app_conf


def validate_config() -> Dict[str, Any]:
    """
    Validate the configuration file and return a result dictionary.
    """
    conf = app_conf.conf()
    issues = []
    warnings = []

    # Required sections and keys
    required_sections = {
        'server': ['port', 'host'],
        'user': ['username', 'password'],
        'database': ['path'],
        'file': ['inputdir', 'storedir', 'tmpdir'],
        'img': ['encode', 'miniSize', 'fullSize']
    }

    for section, keys in required_sections.items():
        if not conf.has_section(section):
            issues.append(f"Missing config section: [{section}]")
            continue

        for key in keys:
            if not conf.has_option(section, key):
                issues.append(f"Missing config option: [{section}].{key}")

    # Security checks
    if conf.has_section('user'):
        username = conf.get('user', 'username', fallback='')
        password = conf.get('user', 'password', fallback='')

        if username == 'admin' and password == 'admin':
            warnings.append("Using the default username and password is insecure; please change them")

        if len(password) < 8:
            warnings.append("Password is too weak; use a complex password of at least 8 characters")

    # Port check
    if conf.has_section('server'):
        try:
            port = conf.getint('server', 'port')
            if port < 1024 or port > 65535:
                warnings.append(f"Port {port} may be unsuitable; prefer a port in the 1024-65535 range")
        except Exception:
            issues.append("Invalid server port configuration")

    # Directory permission checks
    if conf.has_section('file'):
        directories = ['inputdir', 'storedir', 'tmpdir']
        for dir_key in directories:
            dir_path = conf.get('file', dir_key, fallback='')
            if dir_path:
                abs_path = os.path.abspath(dir_path)
                parent_dir = os.path.dirname(abs_path)

                if not os.path.exists(parent_dir):
                    issues.append(f"Parent directory does not exist: {parent_dir} (option: {dir_key})")
                elif not os.access(parent_dir, os.W_OK):
                    issues.append(f"No write permission: {parent_dir} (option: {dir_key})")

    # Database path check
    if conf.has_section('database'):
        db_path = conf.get('database', 'path', fallback='')
        if db_path:
            db_dir = os.path.dirname(os.path.abspath(db_path))
            if not os.path.exists(db_dir):
                issues.append(f"Database directory does not exist: {db_dir}")
            elif not os.access(db_dir, os.W_OK):
                issues.append(f"No write permission for database directory: {db_dir}")

    return {
        'valid': len(issues) == 0,
        'issues': issues,
        'warnings': warnings
    }


def print_validation_results(results: Dict[str, Any]):
    """Print the configuration validation results."""
    if results['valid']:
        print("✅ Configuration is valid")
    else:
        print("❌ Configuration validation failed")
        print("\nCritical issues:")
        for issue in results['issues']:
            print(f"  • {issue}")

    if results['warnings']:
        print("\n⚠️ Warnings:")
        for warning in results['warnings']:
            print(f"  • {warning}")


if __name__ == "__main__":
    # Run the validation when this file is executed directly
    results = validate_config()
    print_validation_results(results)

    if not results['valid']:
        sys.exit(1)
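A possible startup integration, shown only as a sketch (not part of the commit); the ensure_valid_config helper and the idea of aborting on hard errors are assumptions, everything else uses the module's own API.

# Hypothetical startup check: refuse to boot on hard configuration errors,
# tolerate warnings.
import sys
from utils.config_validator import validate_config, print_validation_results

def ensure_valid_config() -> None:
    results = validate_config()
    print_validation_results(results)
    if not results['valid']:
        # hard errors: abort startup
        sys.exit(1)

if __name__ == "__main__":
    ensure_valid_config()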
utils/db_pool.py (new file, 61 lines added)
@@ -0,0 +1,61 @@
import sqlite3
import threading
from contextlib import contextmanager
from queue import Queue, Empty
import app_conf


class ConnectionPool:
    def __init__(self, database_path: str, max_connections: int = 10):
        self.database_path = database_path
        self.max_connections = max_connections
        self.pool = Queue(maxsize=max_connections)
        self.lock = threading.Lock()
        self._initialize_pool()

    def _initialize_pool(self):
        """Initialize the connection pool."""
        for _ in range(self.max_connections):
            conn = sqlite3.connect(self.database_path, check_same_thread=False)
            conn.row_factory = sqlite3.Row  # allow access by column name
            self.pool.put(conn)

    @contextmanager
    def get_connection(self):
        """Context manager that checks out a database connection."""
        conn = None
        try:
            conn = self.pool.get(timeout=5)  # 5-second timeout
            yield conn
        except Empty:
            raise Exception("Could not obtain a database connection: the pool is exhausted")
        finally:
            if conn:
                self.pool.put(conn)

    def close_all(self):
        """Close all connections."""
        while not self.pool.empty():
            try:
                conn = self.pool.get_nowait()
                conn.close()
            except Empty:
                break


# Global connection pool instance
_pool = None
_pool_lock = threading.Lock()


def get_pool():
    """Return the global connection pool instance."""
    global _pool
    if _pool is None:
        with _pool_lock:
            if _pool is None:
                conf = app_conf.conf()
                database_path = conf.get("database", "path")
                _pool = ConnectionPool(database_path)
    return _pool


def get_connection():
    """Return a database connection context manager."""
    return get_pool().get_connection()
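For reference, a usage sketch of the pool (not part of the commit); the books table, its columns, and the fetch_book_title helper are hypothetical.

# Hypothetical read-only query: get_connection() yields a pooled sqlite3
# connection with row_factory = sqlite3.Row, so columns can be read by name.
from utils.db_pool import get_connection

def fetch_book_title(book_id: str):
    with get_connection() as conn:
        cur = conn.execute("SELECT title FROM books WHERE id = ?", (book_id,))
        row = cur.fetchone()
        return row["title"] if row else None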
utils/logger.py (new file, 59 lines added)
@@ -0,0 +1,59 @@
import logging
import sys
from logging.handlers import RotatingFileHandler
import os


def setup_logging(app=None, log_level=logging.INFO):
    """
    Configure application-wide logging.
    """
    # Create the logs directory
    if not os.path.exists('logs'):
        os.makedirs('logs')

    # Log format
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # File handler - application log
    file_handler = RotatingFileHandler(
        'logs/app.log',
        maxBytes=10*1024*1024,  # 10MB
        backupCount=5
    )
    file_handler.setFormatter(formatter)
    file_handler.setLevel(log_level)

    # Error log handler
    error_handler = RotatingFileHandler(
        'logs/error.log',
        maxBytes=10*1024*1024,  # 10MB
        backupCount=5
    )
    error_handler.setFormatter(formatter)
    error_handler.setLevel(logging.ERROR)

    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    console_handler.setLevel(log_level)

    # Configure the root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    root_logger.addHandler(file_handler)
    root_logger.addHandler(error_handler)
    root_logger.addHandler(console_handler)

    # If a Flask app is given, attach the handlers to its logger as well
    if app:
        app.logger.addHandler(file_handler)
        app.logger.addHandler(error_handler)
        app.logger.setLevel(log_level)

    return logging.getLogger(__name__)


def get_logger(name):
    """Return a logger with the given name."""
    return logging.getLogger(name)
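A sketch of how setup_logging might be wired into the entry point (not part of the commit); the Flask app object is an assumption, the module itself only touches app.logger.

# Hypothetical wiring at application startup.
import logging
from flask import Flask  # assumption: ComiPy's web layer exposes a Flask app
from utils.logger import setup_logging, get_logger

app = Flask(__name__)
setup_logging(app, log_level=logging.INFO)  # rotating app.log / error.log plus console

logger = get_logger(__name__)
logger.info("application started")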
utils/performance_monitor.py (new file, 173 lines added)
@@ -0,0 +1,173 @@
"""
Performance monitoring module.
Tracks runtime metrics for the application.
"""
import time
import threading
from functools import wraps
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, field
from datetime import datetime

from utils.logger import get_logger

logger = get_logger(__name__)


@dataclass
class PerformanceMetric:
    """Data class for a single performance metric."""
    name: str
    start_time: float
    end_time: float = 0
    duration: float = 0
    memory_before: float = 0
    memory_after: float = 0
    success: bool = True
    error_message: str = ""


class PerformanceMonitor:
    """Performance monitor."""

    def __init__(self):
        self.metrics: List[PerformanceMetric] = []
        self.lock = threading.Lock()

    def start_monitoring(self, name: str) -> PerformanceMetric:
        """Start monitoring an operation."""
        metric = PerformanceMetric(
            name=name,
            start_time=time.time(),
            memory_before=self.get_memory_usage()
        )
        return metric

    def end_monitoring(self, metric: PerformanceMetric, success: bool = True, error_message: str = ""):
        """Finish monitoring an operation."""
        metric.end_time = time.time()
        metric.duration = metric.end_time - metric.start_time
        metric.memory_after = self.get_memory_usage()
        metric.success = success
        metric.error_message = error_message

        with self.lock:
            self.metrics.append(metric)

            # Keep only the most recent 1000 records
            if len(self.metrics) > 1000:
                self.metrics = self.metrics[-1000:]

        logger.debug(f"Performance: {metric.name} - duration: {metric.duration:.3f}s, "
                     f"memory delta: {metric.memory_after - metric.memory_before:.2f}MB")

    def get_memory_usage(self) -> float:
        """Return the current memory usage in MB."""
        try:
            # Simple placeholder estimate; on Windows a different method could
            # be used, so this currently returns 0
            return 0.0
        except Exception:
            return 0.0

    def get_stats(self, operation_name: Optional[str] = None) -> Dict[str, Any]:
        """Return aggregated performance statistics."""
        with self.lock:
            filtered_metrics = self.metrics
            if operation_name:
                filtered_metrics = [m for m in self.metrics if m.name == operation_name]

            if not filtered_metrics:
                return {}

            durations = [m.duration for m in filtered_metrics if m.success]
            success_count = len([m for m in filtered_metrics if m.success])
            error_count = len([m for m in filtered_metrics if not m.success])

            stats = {
                'operation_name': operation_name or 'All Operations',
                'total_calls': len(filtered_metrics),
                'success_calls': success_count,
                'error_calls': error_count,
                'success_rate': f"{(success_count / len(filtered_metrics) * 100):.2f}%" if filtered_metrics else "0%",
                'avg_duration': f"{(sum(durations) / len(durations)):.3f}s" if durations else "0s",
                'min_duration': f"{min(durations):.3f}s" if durations else "0s",
                'max_duration': f"{max(durations):.3f}s" if durations else "0s",
                'current_memory': f"{self.get_memory_usage():.2f}MB"
            }

            return stats

    def get_recent_errors(self, count: int = 10) -> List[Dict[str, Any]]:
        """Return the most recent errors."""
        with self.lock:
            error_metrics = [m for m in self.metrics if not m.success][-count:]
            return [
                {
                    'name': m.name,
                    'time': datetime.fromtimestamp(m.start_time).strftime('%Y-%m-%d %H:%M:%S'),
                    'duration': f"{m.duration:.3f}s",
                    'error': m.error_message
                }
                for m in error_metrics
            ]

    def clear_metrics(self):
        """Clear all collected metrics."""
        with self.lock:
            self.metrics.clear()
            logger.info("Performance metrics cleared")


# Global performance monitor
_performance_monitor = None


def get_performance_monitor() -> PerformanceMonitor:
    """Return the global performance monitor instance."""
    global _performance_monitor
    if _performance_monitor is None:
        _performance_monitor = PerformanceMonitor()
    return _performance_monitor


def monitor_performance(operation_name: Optional[str] = None):
    """Decorator that records performance metrics for a function."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            monitor = get_performance_monitor()
            name = operation_name or f"{func.__module__}.{func.__name__}"

            metric = monitor.start_monitoring(name)
            try:
                result = func(*args, **kwargs)
                monitor.end_monitoring(metric, success=True)
                return result
            except Exception as e:
                monitor.end_monitoring(metric, success=False, error_message=str(e))
                raise
        return wrapper
    return decorator


def timing_context(operation_name: str):
    """Context manager for performance monitoring."""
    class TimingContext:
        def __init__(self, name: str):
            self.name = name
            self.monitor = get_performance_monitor()
            self.metric: Optional[PerformanceMetric] = None

        def __enter__(self):
            self.metric = self.monitor.start_monitoring(self.name)
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            if self.metric:
                if exc_type is None:
                    self.monitor.end_monitoring(self.metric, success=True)
                else:
                    self.monitor.end_monitoring(self.metric, success=False, error_message=str(exc_val))

    return TimingContext(operation_name)
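For reference, a usage sketch of the monitor (not part of the commit); the scan_library function, its body, and the operation names are placeholders for whatever file-processing code adopts the monitor.

# Hypothetical usage of the decorator, the context manager, and the stats API.
from utils.performance_monitor import (
    monitor_performance, timing_context, get_performance_monitor,
)

@monitor_performance("scan_library")
def scan_library(path: str) -> int:
    # placeholder work
    return len(path)

scan_library("/tmp/comics")

with timing_context("build_thumbnails"):
    pass  # placeholder block being timed

print(get_performance_monitor().get_stats("scan_library"))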
utils/security.py (new file, 34 lines added)
@@ -0,0 +1,34 @@
import hashlib
import secrets
import hmac

from typing import Optional


def hash_password(password: str, salt: Optional[str] = None) -> tuple[str, str]:
    """
    Hash a password and return the hash together with the salt.
    """
    if salt is None:
        salt = secrets.token_hex(32)

    password_hash = hashlib.pbkdf2_hmac(
        'sha256',
        password.encode('utf-8'),
        salt.encode('utf-8'),
        100000  # iteration count
    )

    return password_hash.hex(), salt


def verify_password(password: str, hashed_password: str, salt: str) -> bool:
    """
    Check whether a password matches the stored hash.
    """
    test_hash, _ = hash_password(password, salt)
    return hmac.compare_digest(test_hash, hashed_password)


def generate_session_token() -> str:
    """
    Generate a secure session token.
    """
    return secrets.token_urlsafe(32)
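For reference, a usage sketch of the security helpers (not part of the commit); the sample password and the idea of replacing a plaintext config password are assumptions.

# Hypothetical credential handling with the new helpers.
from utils.security import hash_password, verify_password, generate_session_token

pw_hash, salt = hash_password("s3cret-passw0rd")   # PBKDF2-HMAC-SHA256, 100k iterations
assert verify_password("s3cret-passw0rd", pw_hash, salt)
assert not verify_password("wrong", pw_hash, salt)

token = generate_session_token()  # 32 random bytes, URL-safe base64 encoded
print(len(token))                 # roughly 43 characters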