mirror of
https://github.com/Kakune55/ComiPy.git
synced 2025-09-15 19:59:39 +08:00
feat(file): 优化文件处理和缓存机制
- 重构文件处理逻辑,提高性能和可维护性 - 增加缓存机制,减少重复读取和处理 - 改进错误处理和日志记录 - 优化缩略图生成算法 - 添加性能监控和测试依赖
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -3,3 +3,5 @@ data
|
||||
input
|
||||
test.py
|
||||
conf/app.ini
|
||||
logs/app.log
|
||||
logs/error.log
|
||||
|
@@ -37,8 +37,8 @@ def init():
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
time INT NOT NULL,
|
||||
bookid TEXT NOT NULL,
|
||||
from_uid INTEGAR NOT NULL,
|
||||
score INT NOT NULL,
|
||||
from_uid INTEGER NOT NULL,
|
||||
score TEXT NOT NULL,
|
||||
content TEXT
|
||||
);
|
||||
"""
|
||||
|
296
file.py
296
file.py
@@ -1,86 +1,266 @@
|
||||
import shutil, os, zipfile, io, cv2, numpy as np
|
||||
import hashlib
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
import db.file, app_conf
|
||||
from utils.logger import get_logger
|
||||
from utils.cache_manager import get_cache_manager, cache_image
|
||||
from utils.performance_monitor import monitor_performance, timing_context
|
||||
|
||||
app_conf = app_conf.conf()
|
||||
# 获取配置对象
|
||||
conf = app_conf.conf()
|
||||
logger = get_logger(__name__)
|
||||
cache_manager = get_cache_manager()
|
||||
|
||||
# 内存缓存 - 存储最近访问的ZIP文件列表
|
||||
_zip_cache = {}
|
||||
_cache_timeout = 300 # 5分钟缓存超时
|
||||
|
||||
|
||||
def init():
|
||||
"""初始化文件目录"""
|
||||
paths = ("inputdir", "storedir", "tmpdir")
|
||||
for path in paths:
|
||||
try:
|
||||
os.makedirs(app_conf.get("file", path))
|
||||
dir_path = Path(conf.get("file", path))
|
||||
dir_path.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"创建目录: {dir_path}")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
logger.error(f"创建目录失败 {path}: {e}")
|
||||
|
||||
|
||||
def auotLoadFile():
|
||||
fileList = os.listdir(app_conf.get("file", "inputdir"))
|
||||
for item in fileList:
|
||||
if zipfile.is_zipfile(
|
||||
app_conf.get("file", "inputdir") + "/" + item
|
||||
): # 判断是否为压缩包
|
||||
with zipfile.ZipFile(
|
||||
app_conf.get("file", "inputdir") + "/" + item, "r"
|
||||
) as zip_ref:
|
||||
db.file.new(item, len(zip_ref.namelist())) # 添加数据库记录 移动到存储
|
||||
shutil.move(
|
||||
app_conf.get("file", "inputdir") + "/" + item,
|
||||
app_conf.get("file", "storedir") + "/" + item,
|
||||
)
|
||||
print("已添加 " + item)
|
||||
else:
|
||||
print("不符合条件 " + item)
|
||||
|
||||
|
||||
def raedZip(bookid: str, index: int):
|
||||
bookinfo = db.file.searchByid(bookid)
|
||||
zippath = app_conf.get("file", "storedir") + "/" + bookinfo[0][2]
|
||||
|
||||
@monitor_performance("file.get_image_files_from_zip")
|
||||
def get_image_files_from_zip(zip_path: str) -> tuple:
|
||||
"""
|
||||
从ZIP文件中获取图片文件列表,使用缓存提高性能
|
||||
返回: (image_files_list, cache_key)
|
||||
"""
|
||||
cache_key = f"{zip_path}_{os.path.getmtime(zip_path)}"
|
||||
current_time = time.time()
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in _zip_cache:
|
||||
cache_data = _zip_cache[cache_key]
|
||||
if current_time - cache_data['timestamp'] < _cache_timeout:
|
||||
logger.debug(f"使用缓存的ZIP文件列表: {zip_path}")
|
||||
return cache_data['files'], cache_key
|
||||
|
||||
# 读取ZIP文件
|
||||
try:
|
||||
# 创建一个ZipFile对象
|
||||
with zipfile.ZipFile(zippath, "r") as zip_ref:
|
||||
# 获取图片文件列表
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
image_files = [
|
||||
file
|
||||
for file in zip_ref.namelist()
|
||||
file for file in zip_ref.namelist()
|
||||
if file.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"))
|
||||
]
|
||||
|
||||
# 缓存结果
|
||||
_zip_cache[cache_key] = {
|
||||
'files': image_files,
|
||||
'timestamp': current_time
|
||||
}
|
||||
|
||||
# 清理过期缓存
|
||||
_cleanup_cache()
|
||||
|
||||
logger.debug(f"缓存ZIP文件列表: {zip_path}, 图片数量: {len(image_files)}")
|
||||
return image_files, cache_key
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"读取ZIP文件失败 {zip_path}: {e}")
|
||||
return [], cache_key
|
||||
|
||||
if not image_files:
|
||||
return "not imgfile in zip", ""
|
||||
|
||||
if int(index) > len(image_files):
|
||||
return "404 not found", ""
|
||||
def _cleanup_cache():
|
||||
"""清理过期缓存"""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
key for key, data in _zip_cache.items()
|
||||
if current_time - data['timestamp'] > _cache_timeout
|
||||
]
|
||||
for key in expired_keys:
|
||||
del _zip_cache[key]
|
||||
|
||||
if expired_keys:
|
||||
logger.debug(f"清理过期缓存: {len(expired_keys)} 项")
|
||||
|
||||
# 假设我们只提取图片文件
|
||||
|
||||
@monitor_performance("file.autoLoadFile")
|
||||
def autoLoadFile():
|
||||
"""自动加载文件,优化路径处理和错误处理"""
|
||||
input_dir = Path(conf.get("file", "inputdir"))
|
||||
store_dir = Path(conf.get("file", "storedir"))
|
||||
|
||||
if not input_dir.exists():
|
||||
logger.warning(f"输入目录不存在: {input_dir}")
|
||||
return
|
||||
|
||||
file_list = []
|
||||
try:
|
||||
file_list = [f for f in input_dir.iterdir() if f.is_file()]
|
||||
except Exception as e:
|
||||
logger.error(f"读取输入目录失败: {e}")
|
||||
return
|
||||
|
||||
processed_count = 0
|
||||
for file_path in file_list:
|
||||
try:
|
||||
if zipfile.is_zipfile(file_path):
|
||||
with zipfile.ZipFile(file_path, "r") as zip_ref:
|
||||
page_count = len([f for f in zip_ref.namelist()
|
||||
if f.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"))])
|
||||
if page_count > 0:
|
||||
db.file.new(file_path.name, page_count)
|
||||
|
||||
# 移动文件到存储目录
|
||||
target_path = store_dir / file_path.name
|
||||
shutil.move(str(file_path), str(target_path))
|
||||
|
||||
logger.info(f"已添加漫画: {file_path.name}, 页数: {page_count}")
|
||||
processed_count += 1
|
||||
else:
|
||||
logger.warning(f"ZIP文件中没有图片: {file_path.name}")
|
||||
else:
|
||||
logger.info(f"非ZIP文件,跳过: {file_path.name}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理文件失败 {file_path.name}: {e}")
|
||||
|
||||
logger.info(f"自动加载完成,处理了 {processed_count} 个文件")
|
||||
|
||||
|
||||
@monitor_performance("file.readZip")
|
||||
def readZip(bookid: str, index: int) -> tuple:
|
||||
"""
|
||||
从ZIP文件中读取指定索引的图片
|
||||
优化:使用缓存的文件列表,改进错误处理
|
||||
返回: (image_data, filename) 或 (error_message, "")
|
||||
"""
|
||||
try:
|
||||
bookinfo = db.file.searchByid(bookid)
|
||||
if not bookinfo:
|
||||
logger.warning(f"未找到书籍ID: {bookid}")
|
||||
return "Book not found", ""
|
||||
|
||||
zip_path = Path(conf.get("file", "storedir")) / bookinfo[0][2]
|
||||
|
||||
if not zip_path.exists():
|
||||
logger.error(f"ZIP文件不存在: {zip_path}")
|
||||
return "ZIP file not found", ""
|
||||
|
||||
# 使用缓存获取图片文件列表
|
||||
image_files, _ = get_image_files_from_zip(str(zip_path))
|
||||
|
||||
if not image_files:
|
||||
logger.warning(f"ZIP文件中没有图片: {zip_path}")
|
||||
return "No image files in zip", ""
|
||||
|
||||
if int(index) >= len(image_files):
|
||||
logger.warning(f"图片索引超出范围: {index}, 总数: {len(image_files)}")
|
||||
return "Image index out of range", ""
|
||||
|
||||
# 读取指定图片
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
image_filename = image_files[int(index)]
|
||||
|
||||
# 读取图片数据
|
||||
image_data = zip_ref.read(image_filename)
|
||||
zip_ref.close()
|
||||
|
||||
logger.debug(f"读取图片: {bookid}/{index} -> {image_filename}")
|
||||
return image_data, image_filename
|
||||
|
||||
except zipfile.BadZipFile: # 异常处理
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
logger.error(f"损坏的ZIP文件: {bookid}")
|
||||
return "Bad ZipFile", ""
|
||||
except Exception as e:
|
||||
return str(e), ""
|
||||
logger.error(f"读取ZIP文件失败 {bookid}/{index}: {e}")
|
||||
return f"Error: {str(e)}", ""
|
||||
|
||||
|
||||
def thumbnail(input, minSize: int = 600, encode:str="webp"):
|
||||
img = cv2.imdecode(np.frombuffer(input, np.uint8), cv2.IMREAD_COLOR)
|
||||
height = img.shape[0] # 图片高度
|
||||
width = img.shape[1] # 图片宽度
|
||||
if minSize < np.amin((height,width)):
|
||||
if height > width:
|
||||
newshape = (minSize, int(minSize / width * height))
|
||||
@lru_cache(maxsize=128)
|
||||
def _get_image_hash(image_data: bytes) -> str:
|
||||
"""生成图片数据的哈希值用于缓存"""
|
||||
return hashlib.md5(image_data).hexdigest()
|
||||
|
||||
|
||||
@cache_image
|
||||
def thumbnail(input_data: bytes, min_size: int = 600, encode: str = "webp", quality: int = 75) -> bytes:
|
||||
"""
|
||||
生成缩略图,优化编码逻辑和性能
|
||||
"""
|
||||
if not input_data:
|
||||
logger.warning("输入图片数据为空")
|
||||
return input_data
|
||||
|
||||
try:
|
||||
# 解码图片
|
||||
img = cv2.imdecode(np.frombuffer(input_data, np.uint8), cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
logger.warning("无法解码图片数据")
|
||||
return input_data
|
||||
|
||||
height, width = img.shape[:2]
|
||||
logger.debug(f"原始图片尺寸: {width}x{height}")
|
||||
|
||||
# 判断是否需要缩放
|
||||
min_dimension = min(height, width)
|
||||
if min_size < min_dimension:
|
||||
# 计算新尺寸
|
||||
if height > width:
|
||||
new_width = min_size
|
||||
new_height = int(min_size * height / width)
|
||||
else:
|
||||
new_height = min_size
|
||||
new_width = int(min_size * width / height)
|
||||
|
||||
img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
||||
logger.debug(f"缩放后图片尺寸: {new_width}x{new_height}")
|
||||
|
||||
# 编码图片
|
||||
if encode.lower() == "webp":
|
||||
success, encoded_image = cv2.imencode(
|
||||
".webp", img, [cv2.IMWRITE_WEBP_QUALITY, quality]
|
||||
)
|
||||
elif encode.lower() in ("jpg", "jpeg"):
|
||||
success, encoded_image = cv2.imencode(
|
||||
".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, quality]
|
||||
)
|
||||
elif encode.lower() == "png":
|
||||
success, encoded_image = cv2.imencode(
|
||||
".png", img, [cv2.IMWRITE_PNG_COMPRESSION, 6]
|
||||
)
|
||||
else:
|
||||
newshape = (int(minSize / height * width), minSize)
|
||||
img = cv2.resize(img, newshape)
|
||||
if encode == "webp":
|
||||
success, encoded_image = cv2.imencode(".webp", img, [cv2.IMWRITE_WEBP_QUALITY, 75])
|
||||
elif encode == "jpg" or "jpeg":
|
||||
success, encoded_image = cv2.imencode(".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, 75])
|
||||
else:
|
||||
return input
|
||||
return encoded_image.tobytes()
|
||||
logger.warning(f"不支持的编码格式: {encode}, 返回原始数据")
|
||||
return input_data
|
||||
|
||||
if not success:
|
||||
logger.error(f"图片编码失败: {encode}")
|
||||
return input_data
|
||||
|
||||
result = encoded_image.tobytes()
|
||||
logger.debug(f"图片处理完成: 原始 {len(input_data)} bytes -> 处理后 {len(result)} bytes")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"图片处理异常: {e}")
|
||||
return input_data
|
||||
|
||||
|
||||
def get_zip_image_count(bookid: str) -> int:
|
||||
"""
|
||||
获取ZIP文件中的图片数量,使用缓存
|
||||
"""
|
||||
try:
|
||||
bookinfo = db.file.searchByid(bookid)
|
||||
if not bookinfo:
|
||||
return 0
|
||||
|
||||
zip_path = Path(conf.get("file", "storedir")) / bookinfo[0][2]
|
||||
if not zip_path.exists():
|
||||
return 0
|
||||
|
||||
image_files, _ = get_image_files_from_zip(str(zip_path))
|
||||
return len(image_files)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取图片数量失败 {bookid}: {e}")
|
||||
return 0
|
||||
|
18
main.py
18
main.py
@@ -7,21 +7,33 @@ from router.api_Img import api_Img_bp
|
||||
from router.page import page_bp
|
||||
from router.admin_page import admin_page_bp
|
||||
from router.api_comment import comment_api_bp
|
||||
from router.performance_api import performance_bp
|
||||
from utils.logger import setup_logging
|
||||
from utils.performance_monitor import get_performance_monitor
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
conf = app_conf.conf()
|
||||
|
||||
def appinit():
|
||||
"""应用初始化,集成日志和性能监控"""
|
||||
# 设置日志
|
||||
setup_logging(app)
|
||||
|
||||
# 初始化文件系统和数据库
|
||||
file.init()
|
||||
db.util.init()
|
||||
file.auotLoadFile()
|
||||
|
||||
file.autoLoadFile()
|
||||
|
||||
# 启动性能监控
|
||||
monitor = get_performance_monitor()
|
||||
app.logger.info("应用初始化完成,性能监控已启动")
|
||||
|
||||
# 注册蓝图
|
||||
app.register_blueprint(api_Img_bp)
|
||||
app.register_blueprint(page_bp)
|
||||
app.register_blueprint(admin_page_bp)
|
||||
app.register_blueprint(comment_api_bp)
|
||||
app.register_blueprint(performance_bp)
|
||||
|
||||
if __name__ == "__main__":
|
||||
appinit()
|
||||
|
@@ -1,4 +1,12 @@
|
||||
shortuuid
|
||||
flask
|
||||
flask>=2.3.0
|
||||
opencv-python
|
||||
opencv-python-headless
|
||||
opencv-python-headless
|
||||
werkzeug>=2.3.0
|
||||
Pillow>=9.0.0
|
||||
python-dotenv
|
||||
flask-limiter
|
||||
# 性能测试依赖(可选)
|
||||
requests>=2.25.0
|
||||
# 如果需要更好的性能监控,可以添加
|
||||
# psutil>=5.8.0
|
@@ -1,6 +1,6 @@
|
||||
from flask import *
|
||||
from flask import Blueprint
|
||||
from flask import Blueprint, request, abort, make_response
|
||||
import db.file , file, gc , app_conf
|
||||
from utils.performance_monitor import timing_context
|
||||
|
||||
api_Img_bp = Blueprint("api_Img_bp", __name__)
|
||||
|
||||
@@ -11,21 +11,30 @@ fullSize = conf.getint("img", "fullSize")
|
||||
|
||||
@api_Img_bp.route("/api/img/<bookid>/<index>")
|
||||
def img(bookid, index): # 图片接口
|
||||
if request.cookies.get("islogin") is None:
|
||||
return abort(403)
|
||||
if len(db.file.searchByid(bookid)) == 0:
|
||||
return abort(404)
|
||||
# 设置响应类型为图片
|
||||
data, filename = file.raedZip(bookid, index)
|
||||
if isinstance(data, str):
|
||||
abort(404)
|
||||
if request.args.get("mini") == "yes":
|
||||
data = file.thumbnail(data,miniSize,encode=imgencode)
|
||||
else:
|
||||
data = file.thumbnail(data,fullSize,encode=imgencode)
|
||||
response = make_response(data) # 读取文件
|
||||
del data
|
||||
response.headers.set("Content-Type",f"image/{imgencode}")
|
||||
response.headers.set("Content-Disposition", "inline", filename=filename)
|
||||
gc.collect()
|
||||
return response
|
||||
with timing_context(f"api.img.{bookid}.{index}"):
|
||||
if request.cookies.get("islogin") is None:
|
||||
return abort(403)
|
||||
if len(db.file.searchByid(bookid)) == 0:
|
||||
return abort(404)
|
||||
|
||||
# 读取图片数据
|
||||
data, filename = file.readZip(bookid, index)
|
||||
if isinstance(data, str):
|
||||
abort(404)
|
||||
|
||||
# 处理图片尺寸
|
||||
if request.args.get("mini") == "yes":
|
||||
data = file.thumbnail(data, miniSize, encode=imgencode)
|
||||
else:
|
||||
data = file.thumbnail(data, fullSize, encode=imgencode)
|
||||
|
||||
# 创建响应
|
||||
response = make_response(data)
|
||||
del data # 及时释放内存
|
||||
|
||||
response.headers.set("Content-Type", f"image/{imgencode}")
|
||||
response.headers.set("Content-Disposition", "inline", filename=filename)
|
||||
response.headers.set("Cache-Control", "public, max-age=3600") # 添加缓存头
|
||||
|
||||
gc.collect() # 强制垃圾回收
|
||||
return response
|
||||
|
@@ -14,7 +14,7 @@ def overview(page): # 概览
|
||||
if request.cookies.get("islogin") is None: # 验证登录状态
|
||||
return redirect("/")
|
||||
metaDataList = db.file.getMetadata(
|
||||
(page - 1) * 20, page * 20, request.args.get("search")
|
||||
(page - 1) * 20, page * 20, request.args.get("search", "")
|
||||
)
|
||||
for item in metaDataList:
|
||||
item[2] = item[2][:-4] # 去除文件扩展名
|
||||
@@ -89,7 +89,13 @@ def upload_file():
|
||||
uploaded_file = request.files.getlist("files[]") # 获取上传的文件列表
|
||||
print(uploaded_file)
|
||||
for fileitem in uploaded_file:
|
||||
if fileitem.filename != "":
|
||||
fileitem.save(conf.get("file", "inputdir") + "/" + fileitem.filename)
|
||||
file.auotLoadFile()
|
||||
if fileitem.filename and fileitem.filename != "":
|
||||
input_dir = conf.get("file", "inputdir")
|
||||
if not input_dir:
|
||||
return "Input directory is not configured.", 500
|
||||
import os
|
||||
if input_dir is None:
|
||||
return "Input directory is not configured.", 500
|
||||
fileitem.save(os.path.join(input_dir, fileitem.filename))
|
||||
file.autoLoadFile()
|
||||
return "success"
|
||||
|
61
router/performance_api.py
Normal file
61
router/performance_api.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from flask import Blueprint, render_template, jsonify, request
|
||||
from utils.performance_monitor import get_performance_monitor
|
||||
from utils.cache_manager import get_cache_manager
|
||||
|
||||
performance_bp = Blueprint("performance_bp", __name__)
|
||||
|
||||
|
||||
@performance_bp.route("/api/performance/stats")
|
||||
def performance_stats():
|
||||
"""获取性能统计信息"""
|
||||
if request.cookies.get("islogin") is None:
|
||||
return jsonify({"error": "Unauthorized"}), 403
|
||||
|
||||
monitor = get_performance_monitor()
|
||||
cache_manager = get_cache_manager()
|
||||
|
||||
operation_name = request.args.get("operation")
|
||||
|
||||
stats = {
|
||||
"performance": monitor.get_stats(operation_name),
|
||||
"cache": cache_manager.get_stats(),
|
||||
"recent_errors": monitor.get_recent_errors(5)
|
||||
}
|
||||
|
||||
return jsonify(stats)
|
||||
|
||||
|
||||
@performance_bp.route("/api/performance/clear")
|
||||
def clear_performance_data():
|
||||
"""清空性能监控数据"""
|
||||
if request.cookies.get("islogin") is None:
|
||||
return jsonify({"error": "Unauthorized"}), 403
|
||||
|
||||
monitor = get_performance_monitor()
|
||||
monitor.clear_metrics()
|
||||
|
||||
return jsonify({"message": "Performance data cleared"})
|
||||
|
||||
|
||||
@performance_bp.route("/api/cache/clear")
|
||||
def clear_cache():
|
||||
"""清空缓存"""
|
||||
if request.cookies.get("islogin") is None:
|
||||
return jsonify({"error": "Unauthorized"}), 403
|
||||
|
||||
cache_manager = get_cache_manager()
|
||||
cache_manager.clear()
|
||||
|
||||
return jsonify({"message": "Cache cleared"})
|
||||
|
||||
|
||||
@performance_bp.route("/api/cache/cleanup")
|
||||
def cleanup_cache():
|
||||
"""清理过期缓存"""
|
||||
if request.cookies.get("islogin") is None:
|
||||
return jsonify({"error": "Unauthorized"}), 403
|
||||
|
||||
cache_manager = get_cache_manager()
|
||||
cache_manager.cleanup_expired()
|
||||
|
||||
return jsonify({"message": "Expired cache cleaned up"})
|
@@ -82,11 +82,6 @@
|
||||
border-color: #1557b0;
|
||||
}
|
||||
</style>
|
||||
|
||||
.form-group button:hover {
|
||||
background-color: #4cae4c;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
|
236
utils/cache_manager.py
Normal file
236
utils/cache_manager.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""
|
||||
缓存管理器 - 用于图片和数据缓存
|
||||
提供内存缓存和磁盘缓存功能
|
||||
"""
|
||||
import os
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
import pickle
|
||||
|
||||
from utils.logger import get_logger
|
||||
import app_conf
|
||||
|
||||
logger = get_logger(__name__)
|
||||
conf = app_conf.conf()
|
||||
|
||||
|
||||
class CacheManager:
|
||||
"""缓存管理器"""
|
||||
|
||||
def __init__(self, cache_dir: Optional[str] = None, max_memory_size: int = 100):
|
||||
"""
|
||||
初始化缓存管理器
|
||||
|
||||
Args:
|
||||
cache_dir: 磁盘缓存目录
|
||||
max_memory_size: 内存缓存最大条目数
|
||||
"""
|
||||
self.cache_dir = Path(cache_dir or conf.get("file", "tmpdir")) / "cache"
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.memory_cache = {}
|
||||
self.max_memory_size = max_memory_size
|
||||
self.cache_access_times = {}
|
||||
self.lock = threading.RLock()
|
||||
|
||||
# 缓存统计
|
||||
self.stats = {
|
||||
'hits': 0,
|
||||
'misses': 0,
|
||||
'memory_hits': 0,
|
||||
'disk_hits': 0
|
||||
}
|
||||
|
||||
logger.info(f"缓存管理器初始化: 目录={self.cache_dir}, 内存限制={max_memory_size}")
|
||||
|
||||
def _generate_key(self, *args) -> str:
|
||||
"""生成缓存键"""
|
||||
key_string = "_".join(str(arg) for arg in args)
|
||||
return hashlib.md5(key_string.encode('utf-8')).hexdigest()
|
||||
|
||||
def _cleanup_memory_cache(self):
|
||||
"""清理内存缓存,移除最久未访问的项目"""
|
||||
if len(self.memory_cache) <= self.max_memory_size:
|
||||
return
|
||||
|
||||
# 按访问时间排序,移除最旧的项目
|
||||
sorted_items = sorted(
|
||||
self.cache_access_times.items(),
|
||||
key=lambda x: x[1]
|
||||
)
|
||||
|
||||
# 移除最旧的20%
|
||||
remove_count = len(self.memory_cache) - self.max_memory_size + 1
|
||||
for key, _ in sorted_items[:remove_count]:
|
||||
if key in self.memory_cache:
|
||||
del self.memory_cache[key]
|
||||
del self.cache_access_times[key]
|
||||
|
||||
logger.debug(f"清理内存缓存: 移除 {remove_count} 项")
|
||||
|
||||
def get(self, key: str, default=None) -> Any:
|
||||
"""获取缓存数据"""
|
||||
with self.lock:
|
||||
current_time = time.time()
|
||||
|
||||
# 检查内存缓存
|
||||
if key in self.memory_cache:
|
||||
self.cache_access_times[key] = current_time
|
||||
self.stats['hits'] += 1
|
||||
self.stats['memory_hits'] += 1
|
||||
logger.debug(f"内存缓存命中: {key}")
|
||||
return self.memory_cache[key]
|
||||
|
||||
# 检查磁盘缓存
|
||||
cache_file = self.cache_dir / f"{key}.cache"
|
||||
if cache_file.exists():
|
||||
try:
|
||||
with open(cache_file, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
# 将数据加载到内存缓存
|
||||
self.memory_cache[key] = data
|
||||
self.cache_access_times[key] = current_time
|
||||
self._cleanup_memory_cache()
|
||||
|
||||
self.stats['hits'] += 1
|
||||
self.stats['disk_hits'] += 1
|
||||
logger.debug(f"磁盘缓存命中: {key}")
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.error(f"读取磁盘缓存失败 {key}: {e}")
|
||||
# 删除损坏的缓存文件
|
||||
try:
|
||||
cache_file.unlink()
|
||||
except:
|
||||
pass
|
||||
|
||||
self.stats['misses'] += 1
|
||||
logger.debug(f"缓存未命中: {key}")
|
||||
return default
|
||||
|
||||
def set(self, key: str, value: Any, disk_cache: bool = True):
|
||||
"""设置缓存数据"""
|
||||
with self.lock:
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到内存缓存
|
||||
self.memory_cache[key] = value
|
||||
self.cache_access_times[key] = current_time
|
||||
self._cleanup_memory_cache()
|
||||
|
||||
# 存储到磁盘缓存
|
||||
if disk_cache:
|
||||
try:
|
||||
cache_file = self.cache_dir / f"{key}.cache"
|
||||
with open(cache_file, 'wb') as f:
|
||||
pickle.dump(value, f)
|
||||
logger.debug(f"数据已缓存: {key}")
|
||||
except Exception as e:
|
||||
logger.error(f"写入磁盘缓存失败 {key}: {e}")
|
||||
|
||||
def delete(self, key: str):
|
||||
"""删除缓存数据"""
|
||||
with self.lock:
|
||||
# 删除内存缓存
|
||||
if key in self.memory_cache:
|
||||
del self.memory_cache[key]
|
||||
del self.cache_access_times[key]
|
||||
|
||||
# 删除磁盘缓存
|
||||
cache_file = self.cache_dir / f"{key}.cache"
|
||||
if cache_file.exists():
|
||||
try:
|
||||
cache_file.unlink()
|
||||
logger.debug(f"删除缓存: {key}")
|
||||
except Exception as e:
|
||||
logger.error(f"删除磁盘缓存失败 {key}: {e}")
|
||||
|
||||
def clear(self):
|
||||
"""清空所有缓存"""
|
||||
with self.lock:
|
||||
# 清空内存缓存
|
||||
self.memory_cache.clear()
|
||||
self.cache_access_times.clear()
|
||||
|
||||
# 清空磁盘缓存
|
||||
try:
|
||||
for cache_file in self.cache_dir.glob("*.cache"):
|
||||
cache_file.unlink()
|
||||
logger.info("清空所有缓存")
|
||||
except Exception as e:
|
||||
logger.error(f"清空磁盘缓存失败: {e}")
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓存统计信息"""
|
||||
with self.lock:
|
||||
total_requests = self.stats['hits'] + self.stats['misses']
|
||||
hit_rate = (self.stats['hits'] / total_requests * 100) if total_requests > 0 else 0
|
||||
|
||||
return {
|
||||
'total_requests': total_requests,
|
||||
'hits': self.stats['hits'],
|
||||
'misses': self.stats['misses'],
|
||||
'hit_rate': f"{hit_rate:.2f}%",
|
||||
'memory_hits': self.stats['memory_hits'],
|
||||
'disk_hits': self.stats['disk_hits'],
|
||||
'memory_cache_size': len(self.memory_cache),
|
||||
'disk_cache_files': len(list(self.cache_dir.glob("*.cache")))
|
||||
}
|
||||
|
||||
def cleanup_expired(self, max_age_hours: int = 24):
|
||||
"""清理过期的磁盘缓存文件"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
max_age_seconds = max_age_hours * 3600
|
||||
removed_count = 0
|
||||
|
||||
for cache_file in self.cache_dir.glob("*.cache"):
|
||||
if current_time - cache_file.stat().st_mtime > max_age_seconds:
|
||||
cache_file.unlink()
|
||||
removed_count += 1
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(f"清理过期缓存文件: {removed_count} 个")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理过期缓存失败: {e}")
|
||||
|
||||
|
||||
# 全局缓存管理器实例
|
||||
_cache_manager = None
|
||||
|
||||
|
||||
def get_cache_manager() -> CacheManager:
|
||||
"""获取全局缓存管理器实例"""
|
||||
global _cache_manager
|
||||
if _cache_manager is None:
|
||||
_cache_manager = CacheManager()
|
||||
return _cache_manager
|
||||
|
||||
|
||||
def cache_image(func):
|
||||
"""图片缓存装饰器"""
|
||||
def wrapper(*args, **kwargs):
|
||||
cache_manager = get_cache_manager()
|
||||
|
||||
# 生成缓存键
|
||||
cache_key = cache_manager._generate_key(func.__name__, *args, *kwargs.items())
|
||||
|
||||
# 尝试从缓存获取
|
||||
cached_result = cache_manager.get(cache_key)
|
||||
if cached_result is not None:
|
||||
return cached_result
|
||||
|
||||
# 执行函数并缓存结果
|
||||
result = func(*args, **kwargs)
|
||||
if result: # 只缓存有效结果
|
||||
cache_manager.set(cache_key, result)
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
104
utils/config_validator.py
Normal file
104
utils/config_validator.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, Any
|
||||
import app_conf
|
||||
|
||||
def validate_config() -> Dict[str, Any]:
|
||||
"""
|
||||
验证配置文件的有效性
|
||||
返回验证结果字典
|
||||
"""
|
||||
conf = app_conf.conf()
|
||||
issues = []
|
||||
warnings = []
|
||||
|
||||
# 检查必需的配置项
|
||||
required_sections = {
|
||||
'server': ['port', 'host'],
|
||||
'user': ['username', 'password'],
|
||||
'database': ['path'],
|
||||
'file': ['inputdir', 'storedir', 'tmpdir'],
|
||||
'img': ['encode', 'miniSize', 'fullSize']
|
||||
}
|
||||
|
||||
for section, keys in required_sections.items():
|
||||
if not conf.has_section(section):
|
||||
issues.append(f"缺少配置节: [{section}]")
|
||||
continue
|
||||
|
||||
for key in keys:
|
||||
if not conf.has_option(section, key):
|
||||
issues.append(f"缺少配置项: [{section}].{key}")
|
||||
|
||||
# 检查安全性问题
|
||||
if conf.has_section('user'):
|
||||
username = conf.get('user', 'username', fallback='')
|
||||
password = conf.get('user', 'password', fallback='')
|
||||
|
||||
if username == 'admin' and password == 'admin':
|
||||
warnings.append("使用默认用户名和密码不安全,建议修改")
|
||||
|
||||
if len(password) < 8:
|
||||
warnings.append("密码过于简单,建议使用8位以上的复杂密码")
|
||||
|
||||
# 检查端口配置
|
||||
if conf.has_section('server'):
|
||||
try:
|
||||
port = conf.getint('server', 'port')
|
||||
if port < 1024 or port > 65535:
|
||||
warnings.append(f"端口号 {port} 可能不合适,建议使用1024-65535范围内的端口")
|
||||
except:
|
||||
issues.append("服务器端口配置无效")
|
||||
|
||||
# 检查目录权限
|
||||
if conf.has_section('file'):
|
||||
directories = ['inputdir', 'storedir', 'tmpdir']
|
||||
for dir_key in directories:
|
||||
dir_path = conf.get('file', dir_key, fallback='')
|
||||
if dir_path:
|
||||
abs_path = os.path.abspath(dir_path)
|
||||
parent_dir = os.path.dirname(abs_path)
|
||||
|
||||
if not os.path.exists(parent_dir):
|
||||
issues.append(f"父目录不存在: {parent_dir} (配置: {dir_key})")
|
||||
elif not os.access(parent_dir, os.W_OK):
|
||||
issues.append(f"没有写入权限: {parent_dir} (配置: {dir_key})")
|
||||
|
||||
# 检查数据库路径
|
||||
if conf.has_section('database'):
|
||||
db_path = conf.get('database', 'path', fallback='')
|
||||
if db_path:
|
||||
db_dir = os.path.dirname(os.path.abspath(db_path))
|
||||
if not os.path.exists(db_dir):
|
||||
issues.append(f"数据库目录不存在: {db_dir}")
|
||||
elif not os.access(db_dir, os.W_OK):
|
||||
issues.append(f"数据库目录没有写入权限: {db_dir}")
|
||||
|
||||
return {
|
||||
'valid': len(issues) == 0,
|
||||
'issues': issues,
|
||||
'warnings': warnings
|
||||
}
|
||||
|
||||
def print_validation_results(results: Dict[str, Any]):
|
||||
"""打印配置验证结果"""
|
||||
if results['valid']:
|
||||
print("✅ 配置验证通过")
|
||||
else:
|
||||
print("❌ 配置验证失败")
|
||||
print("\n严重问题:")
|
||||
for issue in results['issues']:
|
||||
print(f" • {issue}")
|
||||
|
||||
if results['warnings']:
|
||||
print("\n⚠️ 警告:")
|
||||
for warning in results['warnings']:
|
||||
print(f" • {warning}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 当直接运行此文件时,执行配置验证
|
||||
results = validate_config()
|
||||
print_validation_results(results)
|
||||
|
||||
if not results['valid']:
|
||||
sys.exit(1)
|
61
utils/db_pool.py
Normal file
61
utils/db_pool.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from queue import Queue, Empty
|
||||
import app_conf
|
||||
|
||||
class ConnectionPool:
|
||||
def __init__(self, database_path: str, max_connections: int = 10):
|
||||
self.database_path = database_path
|
||||
self.max_connections = max_connections
|
||||
self.pool = Queue(maxsize=max_connections)
|
||||
self.lock = threading.Lock()
|
||||
self._initialize_pool()
|
||||
|
||||
def _initialize_pool(self):
|
||||
"""初始化连接池"""
|
||||
for _ in range(self.max_connections):
|
||||
conn = sqlite3.connect(self.database_path, check_same_thread=False)
|
||||
conn.row_factory = sqlite3.Row # 允许按列名访问
|
||||
self.pool.put(conn)
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self):
|
||||
"""获取数据库连接的上下文管理器"""
|
||||
conn = None
|
||||
try:
|
||||
conn = self.pool.get(timeout=5) # 5秒超时
|
||||
yield conn
|
||||
except Empty:
|
||||
raise Exception("无法获取数据库连接:连接池已满")
|
||||
finally:
|
||||
if conn:
|
||||
self.pool.put(conn)
|
||||
|
||||
def close_all(self):
|
||||
"""关闭所有连接"""
|
||||
while not self.pool.empty():
|
||||
try:
|
||||
conn = self.pool.get_nowait()
|
||||
conn.close()
|
||||
except Empty:
|
||||
break
|
||||
|
||||
# 全局连接池实例
|
||||
_pool = None
|
||||
_pool_lock = threading.Lock()
|
||||
|
||||
def get_pool():
|
||||
"""获取全局连接池实例"""
|
||||
global _pool
|
||||
if _pool is None:
|
||||
with _pool_lock:
|
||||
if _pool is None:
|
||||
conf = app_conf.conf()
|
||||
database_path = conf.get("database", "path")
|
||||
_pool = ConnectionPool(database_path)
|
||||
return _pool
|
||||
|
||||
def get_connection():
|
||||
"""获取数据库连接"""
|
||||
return get_pool().get_connection()
|
59
utils/logger.py
Normal file
59
utils/logger.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import logging
|
||||
import sys
|
||||
from logging.handlers import RotatingFileHandler
|
||||
import os
|
||||
|
||||
def setup_logging(app=None, log_level=logging.INFO):
|
||||
"""
|
||||
设置应用程序的日志记录
|
||||
"""
|
||||
# 创建logs目录
|
||||
if not os.path.exists('logs'):
|
||||
os.makedirs('logs')
|
||||
|
||||
# 设置日志格式
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
# 文件处理器 - 应用日志
|
||||
file_handler = RotatingFileHandler(
|
||||
'logs/app.log',
|
||||
maxBytes=10*1024*1024, # 10MB
|
||||
backupCount=5
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
file_handler.setLevel(log_level)
|
||||
|
||||
# 错误日志处理器
|
||||
error_handler = RotatingFileHandler(
|
||||
'logs/error.log',
|
||||
maxBytes=10*1024*1024, # 10MB
|
||||
backupCount=5
|
||||
)
|
||||
error_handler.setFormatter(formatter)
|
||||
error_handler.setLevel(logging.ERROR)
|
||||
|
||||
# 控制台处理器
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setFormatter(formatter)
|
||||
console_handler.setLevel(log_level)
|
||||
|
||||
# 配置根日志记录器
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(log_level)
|
||||
root_logger.addHandler(file_handler)
|
||||
root_logger.addHandler(error_handler)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# 如果是Flask应用,也配置Flask的日志
|
||||
if app:
|
||||
app.logger.addHandler(file_handler)
|
||||
app.logger.addHandler(error_handler)
|
||||
app.logger.setLevel(log_level)
|
||||
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
def get_logger(name):
|
||||
"""获取指定名称的日志记录器"""
|
||||
return logging.getLogger(name)
|
173
utils/performance_monitor.py
Normal file
173
utils/performance_monitor.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
性能监控模块
|
||||
用于监控应用程序的性能指标
|
||||
"""
|
||||
import time
|
||||
import threading
|
||||
from functools import wraps
|
||||
from typing import Dict, List, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
from utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerformanceMetric:
|
||||
"""性能指标数据类"""
|
||||
name: str
|
||||
start_time: float
|
||||
end_time: float = 0
|
||||
duration: float = 0
|
||||
memory_before: float = 0
|
||||
memory_after: float = 0
|
||||
success: bool = True
|
||||
error_message: str = ""
|
||||
|
||||
|
||||
class PerformanceMonitor:
|
||||
"""性能监控器"""
|
||||
|
||||
def __init__(self):
|
||||
self.metrics: List[PerformanceMetric] = []
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def start_monitoring(self, name: str) -> PerformanceMetric:
|
||||
"""开始监控一个操作"""
|
||||
metric = PerformanceMetric(
|
||||
name=name,
|
||||
start_time=time.time(),
|
||||
memory_before=self.get_memory_usage()
|
||||
)
|
||||
return metric
|
||||
|
||||
def end_monitoring(self, metric: PerformanceMetric, success: bool = True, error_message: str = ""):
|
||||
"""结束监控操作"""
|
||||
metric.end_time = time.time()
|
||||
metric.duration = metric.end_time - metric.start_time
|
||||
metric.memory_after = self.get_memory_usage()
|
||||
metric.success = success
|
||||
metric.error_message = error_message
|
||||
|
||||
with self.lock:
|
||||
self.metrics.append(metric)
|
||||
|
||||
# 保持最近1000条记录
|
||||
if len(self.metrics) > 1000:
|
||||
self.metrics = self.metrics[-1000:]
|
||||
|
||||
logger.debug(f"性能监控: {metric.name} - 耗时: {metric.duration:.3f}s, "
|
||||
f"内存变化: {metric.memory_after - metric.memory_before:.2f}MB")
|
||||
|
||||
def get_memory_usage(self) -> float:
|
||||
"""获取当前内存使用量(MB)"""
|
||||
try:
|
||||
# 简单的内存使用量估算
|
||||
# 在Windows上,可以使用其他方法,这里先返回0
|
||||
return 0.0
|
||||
except:
|
||||
return 0.0
|
||||
|
||||
def get_stats(self, operation_name: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""获取性能统计信息"""
|
||||
with self.lock:
|
||||
filtered_metrics = self.metrics
|
||||
if operation_name:
|
||||
filtered_metrics = [m for m in self.metrics if m.name == operation_name]
|
||||
|
||||
if not filtered_metrics:
|
||||
return {}
|
||||
|
||||
durations = [m.duration for m in filtered_metrics if m.success]
|
||||
success_count = len([m for m in filtered_metrics if m.success])
|
||||
error_count = len([m for m in filtered_metrics if not m.success])
|
||||
|
||||
stats = {
|
||||
'operation_name': operation_name or 'All Operations',
|
||||
'total_calls': len(filtered_metrics),
|
||||
'success_calls': success_count,
|
||||
'error_calls': error_count,
|
||||
'success_rate': f"{(success_count / len(filtered_metrics) * 100):.2f}%" if filtered_metrics else "0%",
|
||||
'avg_duration': f"{(sum(durations) / len(durations)):.3f}s" if durations else "0s",
|
||||
'min_duration': f"{min(durations):.3f}s" if durations else "0s",
|
||||
'max_duration': f"{max(durations):.3f}s" if durations else "0s",
|
||||
'current_memory': f"{self.get_memory_usage():.2f}MB"
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
def get_recent_errors(self, count: int = 10) -> List[Dict[str, Any]]:
|
||||
"""获取最近的错误"""
|
||||
with self.lock:
|
||||
error_metrics = [m for m in self.metrics if not m.success][-count:]
|
||||
return [
|
||||
{
|
||||
'name': m.name,
|
||||
'time': datetime.fromtimestamp(m.start_time).strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'duration': f"{m.duration:.3f}s",
|
||||
'error': m.error_message
|
||||
}
|
||||
for m in error_metrics
|
||||
]
|
||||
|
||||
def clear_metrics(self):
|
||||
"""清空监控数据"""
|
||||
with self.lock:
|
||||
self.metrics.clear()
|
||||
logger.info("清空性能监控数据")
|
||||
|
||||
|
||||
# 全局性能监控器
|
||||
_performance_monitor = None
|
||||
|
||||
|
||||
def get_performance_monitor() -> PerformanceMonitor:
|
||||
"""获取全局性能监控器实例"""
|
||||
global _performance_monitor
|
||||
if _performance_monitor is None:
|
||||
_performance_monitor = PerformanceMonitor()
|
||||
return _performance_monitor
|
||||
|
||||
|
||||
def monitor_performance(operation_name: Optional[str] = None):
|
||||
"""性能监控装饰器"""
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
monitor = get_performance_monitor()
|
||||
name = operation_name or f"{func.__module__}.{func.__name__}"
|
||||
|
||||
metric = monitor.start_monitoring(name)
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
monitor.end_monitoring(metric, success=True)
|
||||
return result
|
||||
except Exception as e:
|
||||
monitor.end_monitoring(metric, success=False, error_message=str(e))
|
||||
raise
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def timing_context(operation_name: str):
|
||||
"""性能监控上下文管理器"""
|
||||
class TimingContext:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.monitor = get_performance_monitor()
|
||||
self.metric: Optional[PerformanceMetric] = None
|
||||
|
||||
def __enter__(self):
|
||||
self.metric = self.monitor.start_monitoring(self.name)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.metric:
|
||||
if exc_type is None:
|
||||
self.monitor.end_monitoring(self.metric, success=True)
|
||||
else:
|
||||
self.monitor.end_monitoring(self.metric, success=False, error_message=str(exc_val))
|
||||
|
||||
return TimingContext(operation_name)
|
34
utils/security.py
Normal file
34
utils/security.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
import hmac
|
||||
|
||||
from typing import Optional
|
||||
|
||||
def hash_password(password: str, salt: Optional[str] = None) -> tuple[str, str]:
|
||||
"""
|
||||
哈希密码并返回哈希值和盐值
|
||||
"""
|
||||
if salt is None:
|
||||
salt = secrets.token_hex(32)
|
||||
|
||||
password_hash = hashlib.pbkdf2_hmac(
|
||||
'sha256',
|
||||
password.encode('utf-8'),
|
||||
salt.encode('utf-8'),
|
||||
100000 # 迭代次数
|
||||
)
|
||||
|
||||
return password_hash.hex(), salt
|
||||
|
||||
def verify_password(password: str, hashed_password: str, salt: str) -> bool:
|
||||
"""
|
||||
验证密码是否正确
|
||||
"""
|
||||
test_hash, _ = hash_password(password, salt)
|
||||
return hmac.compare_digest(test_hash, hashed_password)
|
||||
|
||||
def generate_session_token() -> str:
|
||||
"""
|
||||
生成安全的会话令牌
|
||||
"""
|
||||
return secrets.token_urlsafe(32)
|
Reference in New Issue
Block a user