feat(file): 优化文件处理和缓存机制

- 重构文件处理逻辑,提高性能和可维护性
- 增加缓存机制,减少重复读取和处理
- 改进错误处理和日志记录
- 优化缩略图生成算法
- 添加性能监控和测试依赖
This commit is contained in:
2025-07-11 00:21:57 +08:00
parent d0f9e65ad1
commit 8c4e5885c7
15 changed files with 1034 additions and 94 deletions

2
.gitignore vendored
View File

@@ -3,3 +3,5 @@ data
input
test.py
conf/app.ini
logs/app.log
logs/error.log

View File

@@ -37,8 +37,8 @@ def init():
id INTEGER PRIMARY KEY AUTOINCREMENT,
time INT NOT NULL,
bookid TEXT NOT NULL,
from_uid INTEGAR NOT NULL,
score INT NOT NULL,
from_uid INTEGER NOT NULL,
score TEXT NOT NULL,
content TEXT
);
"""

284
file.py
View File

@@ -1,86 +1,266 @@
import shutil, os, zipfile, io, cv2, numpy as np
import hashlib
import time
from functools import lru_cache
from pathlib import Path
import logging
import db.file, app_conf
from utils.logger import get_logger
from utils.cache_manager import get_cache_manager, cache_image
from utils.performance_monitor import monitor_performance, timing_context
app_conf = app_conf.conf()
# 获取配置对象
conf = app_conf.conf()
logger = get_logger(__name__)
cache_manager = get_cache_manager()
# 内存缓存 - 存储最近访问的ZIP文件列表
_zip_cache = {}
_cache_timeout = 300 # 5分钟缓存超时
def init():
"""初始化文件目录"""
paths = ("inputdir", "storedir", "tmpdir")
for path in paths:
try:
os.makedirs(app_conf.get("file", path))
dir_path = Path(conf.get("file", path))
dir_path.mkdir(parents=True, exist_ok=True)
logger.info(f"创建目录: {dir_path}")
except Exception as e:
print(e)
logger.error(f"创建目录失败 {path}: {e}")
def auotLoadFile():
fileList = os.listdir(app_conf.get("file", "inputdir"))
for item in fileList:
if zipfile.is_zipfile(
app_conf.get("file", "inputdir") + "/" + item
): # 判断是否为压缩包
with zipfile.ZipFile(
app_conf.get("file", "inputdir") + "/" + item, "r"
) as zip_ref:
db.file.new(item, len(zip_ref.namelist())) # 添加数据库记录 移动到存储
shutil.move(
app_conf.get("file", "inputdir") + "/" + item,
app_conf.get("file", "storedir") + "/" + item,
)
print("已添加 " + item)
else:
print("不符合条件 " + item)
@monitor_performance("file.get_image_files_from_zip")
def get_image_files_from_zip(zip_path: str) -> tuple:
"""
从ZIP文件中获取图片文件列表使用缓存提高性能
返回: (image_files_list, cache_key)
"""
cache_key = f"{zip_path}_{os.path.getmtime(zip_path)}"
current_time = time.time()
# 检查缓存
if cache_key in _zip_cache:
cache_data = _zip_cache[cache_key]
if current_time - cache_data['timestamp'] < _cache_timeout:
logger.debug(f"使用缓存的ZIP文件列表: {zip_path}")
return cache_data['files'], cache_key
def raedZip(bookid: str, index: int):
bookinfo = db.file.searchByid(bookid)
zippath = app_conf.get("file", "storedir") + "/" + bookinfo[0][2]
# 读取ZIP文件
try:
# 创建一个ZipFile对象
with zipfile.ZipFile(zippath, "r") as zip_ref:
# 获取图片文件列表
with zipfile.ZipFile(zip_path, "r") as zip_ref:
image_files = [
file
for file in zip_ref.namelist()
file for file in zip_ref.namelist()
if file.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"))
]
# 缓存结果
_zip_cache[cache_key] = {
'files': image_files,
'timestamp': current_time
}
# 清理过期缓存
_cleanup_cache()
logger.debug(f"缓存ZIP文件列表: {zip_path}, 图片数量: {len(image_files)}")
return image_files, cache_key
except Exception as e:
logger.error(f"读取ZIP文件失败 {zip_path}: {e}")
return [], cache_key
def _cleanup_cache():
"""清理过期缓存"""
current_time = time.time()
expired_keys = [
key for key, data in _zip_cache.items()
if current_time - data['timestamp'] > _cache_timeout
]
for key in expired_keys:
del _zip_cache[key]
if expired_keys:
logger.debug(f"清理过期缓存: {len(expired_keys)}")
@monitor_performance("file.autoLoadFile")
def autoLoadFile():
"""自动加载文件,优化路径处理和错误处理"""
input_dir = Path(conf.get("file", "inputdir"))
store_dir = Path(conf.get("file", "storedir"))
if not input_dir.exists():
logger.warning(f"输入目录不存在: {input_dir}")
return
file_list = []
try:
file_list = [f for f in input_dir.iterdir() if f.is_file()]
except Exception as e:
logger.error(f"读取输入目录失败: {e}")
return
processed_count = 0
for file_path in file_list:
try:
if zipfile.is_zipfile(file_path):
with zipfile.ZipFile(file_path, "r") as zip_ref:
page_count = len([f for f in zip_ref.namelist()
if f.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"))])
if page_count > 0:
db.file.new(file_path.name, page_count)
# 移动文件到存储目录
target_path = store_dir / file_path.name
shutil.move(str(file_path), str(target_path))
logger.info(f"已添加漫画: {file_path.name}, 页数: {page_count}")
processed_count += 1
else:
logger.warning(f"ZIP文件中没有图片: {file_path.name}")
else:
logger.info(f"非ZIP文件跳过: {file_path.name}")
except Exception as e:
logger.error(f"处理文件失败 {file_path.name}: {e}")
logger.info(f"自动加载完成,处理了 {processed_count} 个文件")
@monitor_performance("file.readZip")
def readZip(bookid: str, index: int) -> tuple:
"""
从ZIP文件中读取指定索引的图片
优化:使用缓存的文件列表,改进错误处理
返回: (image_data, filename) 或 (error_message, "")
"""
try:
bookinfo = db.file.searchByid(bookid)
if not bookinfo:
logger.warning(f"未找到书籍ID: {bookid}")
return "Book not found", ""
zip_path = Path(conf.get("file", "storedir")) / bookinfo[0][2]
if not zip_path.exists():
logger.error(f"ZIP文件不存在: {zip_path}")
return "ZIP file not found", ""
# 使用缓存获取图片文件列表
image_files, _ = get_image_files_from_zip(str(zip_path))
if not image_files:
return "not imgfile in zip", ""
logger.warning(f"ZIP文件中没有图片: {zip_path}")
return "No image files in zip", ""
if int(index) > len(image_files):
return "404 not found", ""
if int(index) >= len(image_files):
logger.warning(f"图片索引超出范围: {index}, 总数: {len(image_files)}")
return "Image index out of range", ""
# 假设我们只提取图片文件
# 读取指定图片
with zipfile.ZipFile(zip_path, "r") as zip_ref:
image_filename = image_files[int(index)]
# 读取图片数据
image_data = zip_ref.read(image_filename)
zip_ref.close()
logger.debug(f"读取图片: {bookid}/{index} -> {image_filename}")
return image_data, image_filename
except zipfile.BadZipFile: # 异常处理
except zipfile.BadZipFile:
logger.error(f"损坏的ZIP文件: {bookid}")
return "Bad ZipFile", ""
except Exception as e:
return str(e), ""
logger.error(f"读取ZIP文件失败 {bookid}/{index}: {e}")
return f"Error: {str(e)}", ""
def thumbnail(input, minSize: int = 600, encode:str="webp"):
img = cv2.imdecode(np.frombuffer(input, np.uint8), cv2.IMREAD_COLOR)
height = img.shape[0] # 图片高度
width = img.shape[1] # 图片宽度
if minSize < np.amin((height,width)):
@lru_cache(maxsize=128)
def _get_image_hash(image_data: bytes) -> str:
"""生成图片数据的哈希值用于缓存"""
return hashlib.md5(image_data).hexdigest()
@cache_image
def thumbnail(input_data: bytes, min_size: int = 600, encode: str = "webp", quality: int = 75) -> bytes:
"""
生成缩略图,优化编码逻辑和性能
"""
if not input_data:
logger.warning("输入图片数据为空")
return input_data
try:
# 解码图片
img = cv2.imdecode(np.frombuffer(input_data, np.uint8), cv2.IMREAD_COLOR)
if img is None:
logger.warning("无法解码图片数据")
return input_data
height, width = img.shape[:2]
logger.debug(f"原始图片尺寸: {width}x{height}")
# 判断是否需要缩放
min_dimension = min(height, width)
if min_size < min_dimension:
# 计算新尺寸
if height > width:
newshape = (minSize, int(minSize / width * height))
new_width = min_size
new_height = int(min_size * height / width)
else:
newshape = (int(minSize / height * width), minSize)
img = cv2.resize(img, newshape)
if encode == "webp":
success, encoded_image = cv2.imencode(".webp", img, [cv2.IMWRITE_WEBP_QUALITY, 75])
elif encode == "jpg" or "jpeg":
success, encoded_image = cv2.imencode(".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, 75])
new_height = min_size
new_width = int(min_size * width / height)
img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)
logger.debug(f"缩放后图片尺寸: {new_width}x{new_height}")
# 编码图片
if encode.lower() == "webp":
success, encoded_image = cv2.imencode(
".webp", img, [cv2.IMWRITE_WEBP_QUALITY, quality]
)
elif encode.lower() in ("jpg", "jpeg"):
success, encoded_image = cv2.imencode(
".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, quality]
)
elif encode.lower() == "png":
success, encoded_image = cv2.imencode(
".png", img, [cv2.IMWRITE_PNG_COMPRESSION, 6]
)
else:
return input
return encoded_image.tobytes()
logger.warning(f"不支持的编码格式: {encode}, 返回原始数据")
return input_data
if not success:
logger.error(f"图片编码失败: {encode}")
return input_data
result = encoded_image.tobytes()
logger.debug(f"图片处理完成: 原始 {len(input_data)} bytes -> 处理后 {len(result)} bytes")
return result
except Exception as e:
logger.error(f"图片处理异常: {e}")
return input_data
def get_zip_image_count(bookid: str) -> int:
"""
获取ZIP文件中的图片数量使用缓存
"""
try:
bookinfo = db.file.searchByid(bookid)
if not bookinfo:
return 0
zip_path = Path(conf.get("file", "storedir")) / bookinfo[0][2]
if not zip_path.exists():
return 0
image_files, _ = get_image_files_from_zip(str(zip_path))
return len(image_files)
except Exception as e:
logger.error(f"获取图片数量失败 {bookid}: {e}")
return 0

16
main.py
View File

@@ -7,21 +7,33 @@ from router.api_Img import api_Img_bp
from router.page import page_bp
from router.admin_page import admin_page_bp
from router.api_comment import comment_api_bp
from router.performance_api import performance_bp
from utils.logger import setup_logging
from utils.performance_monitor import get_performance_monitor
app = Flask(__name__)
conf = app_conf.conf()
def appinit():
"""应用初始化,集成日志和性能监控"""
# 设置日志
setup_logging(app)
# 初始化文件系统和数据库
file.init()
db.util.init()
file.auotLoadFile()
file.autoLoadFile()
# 启动性能监控
monitor = get_performance_monitor()
app.logger.info("应用初始化完成,性能监控已启动")
# 注册蓝图
app.register_blueprint(api_Img_bp)
app.register_blueprint(page_bp)
app.register_blueprint(admin_page_bp)
app.register_blueprint(comment_api_bp)
app.register_blueprint(performance_bp)
if __name__ == "__main__":
appinit()

View File

@@ -1,4 +1,12 @@
shortuuid
flask
flask>=2.3.0
opencv-python
opencv-python-headless
werkzeug>=2.3.0
Pillow>=9.0.0
python-dotenv
flask-limiter
# 性能测试依赖(可选)
requests>=2.25.0
# 如果需要更好的性能监控,可以添加
# psutil>=5.8.0

View File

@@ -1,6 +1,6 @@
from flask import *
from flask import Blueprint
from flask import Blueprint, request, abort, make_response
import db.file , file, gc , app_conf
from utils.performance_monitor import timing_context
api_Img_bp = Blueprint("api_Img_bp", __name__)
@@ -11,21 +11,30 @@ fullSize = conf.getint("img", "fullSize")
@api_Img_bp.route("/api/img/<bookid>/<index>")
def img(bookid, index): # 图片接口
with timing_context(f"api.img.{bookid}.{index}"):
if request.cookies.get("islogin") is None:
return abort(403)
if len(db.file.searchByid(bookid)) == 0:
return abort(404)
# 设置响应类型为图片
data, filename = file.raedZip(bookid, index)
# 读取图片数据
data, filename = file.readZip(bookid, index)
if isinstance(data, str):
abort(404)
# 处理图片尺寸
if request.args.get("mini") == "yes":
data = file.thumbnail(data, miniSize, encode=imgencode)
else:
data = file.thumbnail(data, fullSize, encode=imgencode)
response = make_response(data) # 读取文件
del data
# 创建响应
response = make_response(data)
del data # 及时释放内存
response.headers.set("Content-Type", f"image/{imgencode}")
response.headers.set("Content-Disposition", "inline", filename=filename)
gc.collect()
response.headers.set("Cache-Control", "public, max-age=3600") # 添加缓存头
gc.collect() # 强制垃圾回收
return response

View File

@@ -14,7 +14,7 @@ def overview(page): # 概览
if request.cookies.get("islogin") is None: # 验证登录状态
return redirect("/")
metaDataList = db.file.getMetadata(
(page - 1) * 20, page * 20, request.args.get("search")
(page - 1) * 20, page * 20, request.args.get("search", "")
)
for item in metaDataList:
item[2] = item[2][:-4] # 去除文件扩展名
@@ -89,7 +89,13 @@ def upload_file():
uploaded_file = request.files.getlist("files[]") # 获取上传的文件列表
print(uploaded_file)
for fileitem in uploaded_file:
if fileitem.filename != "":
fileitem.save(conf.get("file", "inputdir") + "/" + fileitem.filename)
file.auotLoadFile()
if fileitem.filename and fileitem.filename != "":
input_dir = conf.get("file", "inputdir")
if not input_dir:
return "Input directory is not configured.", 500
import os
if input_dir is None:
return "Input directory is not configured.", 500
fileitem.save(os.path.join(input_dir, fileitem.filename))
file.autoLoadFile()
return "success"

61
router/performance_api.py Normal file
View File

@@ -0,0 +1,61 @@
from flask import Blueprint, render_template, jsonify, request
from utils.performance_monitor import get_performance_monitor
from utils.cache_manager import get_cache_manager
performance_bp = Blueprint("performance_bp", __name__)
@performance_bp.route("/api/performance/stats")
def performance_stats():
    """Return JSON with performance metrics, cache statistics and recent errors.

    Optional query arg ``operation`` narrows the stats to one operation name.
    Requires the same cookie-presence login check used elsewhere in the app.
    """
    if request.cookies.get("islogin") is None:
        return jsonify({"error": "Unauthorized"}), 403
    operation_name = request.args.get("operation")
    monitor = get_performance_monitor()
    payload = {
        "performance": monitor.get_stats(operation_name),
        "cache": get_cache_manager().get_stats(),
        "recent_errors": monitor.get_recent_errors(5),
    }
    return jsonify(payload)
@performance_bp.route("/api/performance/clear")
def clear_performance_data():
    """Reset all collected performance metrics (login required).

    NOTE(review): this is a state-changing endpoint served over GET;
    consider restricting to POST in a future API revision.
    """
    if request.cookies.get("islogin") is None:
        return jsonify({"error": "Unauthorized"}), 403
    get_performance_monitor().clear_metrics()
    return jsonify({"message": "Performance data cleared"})
@performance_bp.route("/api/cache/clear")
def clear_cache():
    """Drop every cached entry, both memory and disk tiers (login required)."""
    if request.cookies.get("islogin") is None:
        return jsonify({"error": "Unauthorized"}), 403
    get_cache_manager().clear()
    return jsonify({"message": "Cache cleared"})
@performance_bp.route("/api/cache/cleanup")
def cleanup_cache():
    """Remove only expired disk-cache files, keeping fresh entries (login required)."""
    if request.cookies.get("islogin") is None:
        return jsonify({"error": "Unauthorized"}), 403
    get_cache_manager().cleanup_expired()
    return jsonify({"message": "Expired cache cleaned up"})

View File

@@ -82,11 +82,6 @@
border-color: #1557b0;
}
</style>
.form-group button:hover {
background-color: #4cae4c;
}
</style>
</head>
<body>
<div class="container">

236
utils/cache_manager.py Normal file
View File

@@ -0,0 +1,236 @@
"""
缓存管理器 - 用于图片和数据缓存
提供内存缓存和磁盘缓存功能
"""
import os
import hashlib
import json
import time
import threading
from pathlib import Path
from typing import Optional, Dict, Any
import pickle
from utils.logger import get_logger
import app_conf
logger = get_logger(__name__)
conf = app_conf.conf()
class CacheManager:
    """Two-tier (memory + disk) cache manager.

    The memory tier is a plain dict bounded by ``max_memory_size`` entries,
    evicted LRU-style via per-key access timestamps. The disk tier stores
    one pickle file per key under ``<cache_dir>/cache``. All public methods
    are thread-safe via a single re-entrant lock.
    """
    def __init__(self, cache_dir: Optional[str] = None, max_memory_size: int = 100):
        """
        Initialize the cache manager.

        Args:
            cache_dir: disk cache directory (defaults to the configured [file] tmpdir)
            max_memory_size: maximum number of entries kept in the memory tier
        """
        self.cache_dir = Path(cache_dir or conf.get("file", "tmpdir")) / "cache"
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.memory_cache = {}           # key -> cached value
        self.max_memory_size = max_memory_size
        self.cache_access_times = {}     # key -> last-access unix time (drives LRU eviction)
        self.lock = threading.RLock()    # RLock so methods may call each other under the lock
        # Hit/miss counters reported by get_stats()
        self.stats = {
            'hits': 0,
            'misses': 0,
            'memory_hits': 0,
            'disk_hits': 0
        }
        logger.info(f"缓存管理器初始化: 目录={self.cache_dir}, 内存限制={max_memory_size}")
    def _generate_key(self, *args) -> str:
        """Build a deterministic cache key: md5 hex digest of the stringified args."""
        key_string = "_".join(str(arg) for arg in args)
        return hashlib.md5(key_string.encode('utf-8')).hexdigest()
    def _cleanup_memory_cache(self):
        """Evict least-recently-accessed entries once the memory tier is over capacity."""
        if len(self.memory_cache) <= self.max_memory_size:
            return
        # Sort keys by last access time, oldest first
        sorted_items = sorted(
            self.cache_access_times.items(),
            key=lambda x: x[1]
        )
        # NOTE(review): despite the original "remove oldest 20%" comment, this
        # count actually trims the cache down to max_memory_size - 1 entries.
        remove_count = len(self.memory_cache) - self.max_memory_size + 1
        for key, _ in sorted_items[:remove_count]:
            if key in self.memory_cache:
                del self.memory_cache[key]
                del self.cache_access_times[key]
        logger.debug(f"清理内存缓存: 移除 {remove_count}")
    def get(self, key: str, default=None) -> Any:
        """Return the cached value for *key*, or *default* on a miss.

        Disk hits are promoted into the memory tier; unreadable disk entries
        are deleted and treated as misses.
        """
        with self.lock:
            current_time = time.time()
            # 1) memory tier
            if key in self.memory_cache:
                self.cache_access_times[key] = current_time
                self.stats['hits'] += 1
                self.stats['memory_hits'] += 1
                logger.debug(f"内存缓存命中: {key}")
                return self.memory_cache[key]
            # 2) disk tier
            cache_file = self.cache_dir / f"{key}.cache"
            if cache_file.exists():
                try:
                    # SECURITY NOTE(review): pickle.load is only safe while the
                    # cache directory cannot be written by untrusted users — confirm.
                    with open(cache_file, 'rb') as f:
                        data = pickle.load(f)
                    # Promote the disk hit into the memory tier
                    self.memory_cache[key] = data
                    self.cache_access_times[key] = current_time
                    self._cleanup_memory_cache()
                    self.stats['hits'] += 1
                    self.stats['disk_hits'] += 1
                    logger.debug(f"磁盘缓存命中: {key}")
                    return data
                except Exception as e:
                    logger.error(f"读取磁盘缓存失败 {key}: {e}")
                    # Drop the corrupted cache file so it is not retried forever
                    try:
                        cache_file.unlink()
                    except:
                        pass
            self.stats['misses'] += 1
            logger.debug(f"缓存未命中: {key}")
            return default
    def set(self, key: str, value: Any, disk_cache: bool = True):
        """Store *value* under *key* in memory and, when *disk_cache*, on disk too."""
        with self.lock:
            current_time = time.time()
            # memory tier
            self.memory_cache[key] = value
            self.cache_access_times[key] = current_time
            self._cleanup_memory_cache()
            # disk tier (best-effort: failures are logged, not raised)
            if disk_cache:
                try:
                    cache_file = self.cache_dir / f"{key}.cache"
                    with open(cache_file, 'wb') as f:
                        pickle.dump(value, f)
                    logger.debug(f"数据已缓存: {key}")
                except Exception as e:
                    logger.error(f"写入磁盘缓存失败 {key}: {e}")
    def delete(self, key: str):
        """Remove *key* from both cache tiers (missing key is a no-op)."""
        with self.lock:
            # memory tier
            if key in self.memory_cache:
                del self.memory_cache[key]
                del self.cache_access_times[key]
            # disk tier
            cache_file = self.cache_dir / f"{key}.cache"
            if cache_file.exists():
                try:
                    cache_file.unlink()
                    logger.debug(f"删除缓存: {key}")
                except Exception as e:
                    logger.error(f"删除磁盘缓存失败 {key}: {e}")
    def clear(self):
        """Wipe the memory tier and delete every ``*.cache`` file on disk."""
        with self.lock:
            # memory tier
            self.memory_cache.clear()
            self.cache_access_times.clear()
            # disk tier
            try:
                for cache_file in self.cache_dir.glob("*.cache"):
                    cache_file.unlink()
                logger.info("清空所有缓存")
            except Exception as e:
                logger.error(f"清空磁盘缓存失败: {e}")
    def get_stats(self) -> Dict[str, Any]:
        """Return hit/miss counters, overall hit rate and current tier sizes."""
        with self.lock:
            total_requests = self.stats['hits'] + self.stats['misses']
            hit_rate = (self.stats['hits'] / total_requests * 100) if total_requests > 0 else 0
            return {
                'total_requests': total_requests,
                'hits': self.stats['hits'],
                'misses': self.stats['misses'],
                'hit_rate': f"{hit_rate:.2f}%",
                'memory_hits': self.stats['memory_hits'],
                'disk_hits': self.stats['disk_hits'],
                'memory_cache_size': len(self.memory_cache),
                'disk_cache_files': len(list(self.cache_dir.glob("*.cache")))
            }
    def cleanup_expired(self, max_age_hours: int = 24):
        """Delete disk cache files whose mtime is older than *max_age_hours*."""
        try:
            current_time = time.time()
            max_age_seconds = max_age_hours * 3600
            removed_count = 0
            for cache_file in self.cache_dir.glob("*.cache"):
                if current_time - cache_file.stat().st_mtime > max_age_seconds:
                    cache_file.unlink()
                    removed_count += 1
            if removed_count > 0:
                logger.info(f"清理过期缓存文件: {removed_count}")
        except Exception as e:
            logger.error(f"清理过期缓存失败: {e}")
# 全局缓存管理器实例
_cache_manager = None
def get_cache_manager() -> CacheManager:
    """Lazily create and return the process-wide CacheManager singleton."""
    global _cache_manager
    if _cache_manager is None:
        # First caller pays the construction cost; later calls reuse the instance.
        _cache_manager = CacheManager()
    return _cache_manager
def cache_image(func):
    """Decorator that memoizes *func* results through the global CacheManager.

    The cache key is derived from the function name plus the stringified
    positional and keyword arguments, so it is only suitable for functions
    whose arguments have stable ``str()`` representations.

    Fixes over the original:
    - ``functools.wraps`` preserves the wrapped function's name/docstring
      (without it, every decorated function reported as "wrapper").
    - keyword arguments are sorted before key generation, so call sites
      passing the same kwargs in a different order share one cache entry.
    """
    from functools import wraps  # local import keeps this module's top-level imports unchanged

    @wraps(func)
    def wrapper(*args, **kwargs):
        cache_manager = get_cache_manager()
        # Sorted kwargs make the key independent of call-site argument order.
        cache_key = cache_manager._generate_key(
            func.__name__, *args, *sorted(kwargs.items())
        )
        cached_result = cache_manager.get(cache_key)
        if cached_result is not None:
            return cached_result
        # Miss: compute, then cache only truthy results (empty output is not cached).
        result = func(*args, **kwargs)
        if result:
            cache_manager.set(cache_key, result)
        return result
    return wrapper

104
utils/config_validator.py Normal file
View File

@@ -0,0 +1,104 @@
import os
import sys
from typing import Dict, Any
import app_conf
def validate_config() -> Dict[str, Any]:
    """
    Validate the application's configuration file.

    Checks for required sections/options, obviously unsafe credentials,
    an implausible server port, and missing or unwritable directories.

    Returns:
        dict with keys:
            valid    -- True when no blocking issues were found
            issues   -- list of blocking problems (strings)
            warnings -- list of non-blocking recommendations (strings)
    """
    conf = app_conf.conf()
    issues = []
    warnings = []
    # Required sections and their mandatory options
    required_sections = {
        'server': ['port', 'host'],
        'user': ['username', 'password'],
        'database': ['path'],
        'file': ['inputdir', 'storedir', 'tmpdir'],
        'img': ['encode', 'miniSize', 'fullSize']
    }
    for section, keys in required_sections.items():
        if not conf.has_section(section):
            issues.append(f"缺少配置节: [{section}]")
            continue
        for key in keys:
            if not conf.has_option(section, key):
                issues.append(f"缺少配置项: [{section}].{key}")
    # Credential sanity checks
    if conf.has_section('user'):
        username = conf.get('user', 'username', fallback='')
        password = conf.get('user', 'password', fallback='')
        if username == 'admin' and password == 'admin':
            warnings.append("使用默认用户名和密码不安全,建议修改")
        if len(password) < 8:
            warnings.append("密码过于简单建议使用8位以上的复杂密码")
    # Server port check
    if conf.has_section('server'):
        try:
            port = conf.getint('server', 'port')
            if port < 1024 or port > 65535:
                warnings.append(f"端口号 {port} 可能不合适建议使用1024-65535范围内的端口")
        except Exception:
            # FIX: narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # propagate; still covers ValueError from getint and configparser's
            # missing-option errors.
            issues.append("服务器端口配置无效")
    # Directory existence / writability checks (on the parent, since the
    # app creates the leaf directories itself at startup)
    if conf.has_section('file'):
        directories = ['inputdir', 'storedir', 'tmpdir']
        for dir_key in directories:
            dir_path = conf.get('file', dir_key, fallback='')
            if dir_path:
                abs_path = os.path.abspath(dir_path)
                parent_dir = os.path.dirname(abs_path)
                if not os.path.exists(parent_dir):
                    issues.append(f"父目录不存在: {parent_dir} (配置: {dir_key})")
                elif not os.access(parent_dir, os.W_OK):
                    issues.append(f"没有写入权限: {parent_dir} (配置: {dir_key})")
    # Database directory checks
    if conf.has_section('database'):
        db_path = conf.get('database', 'path', fallback='')
        if db_path:
            db_dir = os.path.dirname(os.path.abspath(db_path))
            if not os.path.exists(db_dir):
                issues.append(f"数据库目录不存在: {db_dir}")
            elif not os.access(db_dir, os.W_OK):
                issues.append(f"数据库目录没有写入权限: {db_dir}")
    return {
        'valid': len(issues) == 0,
        'issues': issues,
        'warnings': warnings
    }
def print_validation_results(results: Dict[str, Any]):
    """Render the outcome of validate_config() to stdout."""
    if not results['valid']:
        print("❌ 配置验证失败")
        print("\n严重问题:")
        for issue in results['issues']:
            print(issue)
    else:
        print("✅ 配置验证通过")
    if results['warnings']:
        print("\n⚠️ 警告:")
        for warning in results['warnings']:
            print(warning)
if __name__ == "__main__":
    # Standalone mode: validate the config and exit non-zero on failure,
    # making this module usable from shell scripts / CI pipelines.
    validation = validate_config()
    print_validation_results(validation)
    sys.exit(0 if validation['valid'] else 1)

61
utils/db_pool.py Normal file
View File

@@ -0,0 +1,61 @@
import sqlite3
import threading
from contextlib import contextmanager
from queue import Queue, Empty
import app_conf
class ConnectionPool:
    """Fixed-size pool of SQLite connections shared across threads.

    Every connection is opened eagerly in the constructor with
    ``check_same_thread=False`` and ``sqlite3.Row`` as row factory, then
    handed out one at a time through :meth:`get_connection`.
    """
    def __init__(self, database_path: str, max_connections: int = 10):
        self.database_path = database_path
        self.max_connections = max_connections
        self.pool = Queue(maxsize=max_connections)
        self.lock = threading.Lock()
        self._initialize_pool()
    def _initialize_pool(self):
        """Open every connection up front and park it in the queue."""
        for _ in range(self.max_connections):
            connection = sqlite3.connect(self.database_path, check_same_thread=False)
            connection.row_factory = sqlite3.Row  # allow column access by name
            self.pool.put(connection)
    @contextmanager
    def get_connection(self):
        """Borrow a connection for the duration of a ``with`` block.

        Blocks up to five seconds waiting for a free connection, then raises.
        The borrowed connection is always returned to the pool on exit.
        """
        borrowed = None
        try:
            borrowed = self.pool.get(timeout=5)  # wait at most 5 seconds
            yield borrowed
        except Empty:
            raise Exception("无法获取数据库连接:连接池已满")
        finally:
            if borrowed:
                self.pool.put(borrowed)
    def close_all(self):
        """Drain the queue and close every pooled connection."""
        while True:
            try:
                connection = self.pool.get_nowait()
            except Empty:
                break
            connection.close()
# 全局连接池实例
_pool = None
_pool_lock = threading.Lock()
def get_pool():
    """Return the shared ConnectionPool, creating it on first call.

    Uses double-checked locking so concurrent first callers build only one pool.
    """
    global _pool
    if _pool is None:
        with _pool_lock:
            if _pool is None:
                settings = app_conf.conf()
                _pool = ConnectionPool(settings.get("database", "path"))
    return _pool
def get_connection():
    """Shorthand for ``get_pool().get_connection()`` — yields a pooled connection."""
    pool = get_pool()
    return pool.get_connection()

59
utils/logger.py Normal file
View File

@@ -0,0 +1,59 @@
import logging
import sys
from logging.handlers import RotatingFileHandler
import os
def setup_logging(app=None, log_level=logging.INFO):
    """
    Configure application-wide logging.

    Installs three handlers on the root logger: a rotating application log
    (logs/app.log), a rotating error-only log (logs/error.log), and a
    stdout console handler.

    Fixes over the original:
    - Idempotent: the original attached a fresh set of handlers on every
      call, so repeated initialization duplicated every log line. A sentinel
      attribute on the root logger now makes re-configuration a no-op.
    - ``os.makedirs(..., exist_ok=True)`` replaces the racy
      exists-then-create check.

    Args:
        app: optional Flask app whose logger should share the file handlers
        log_level: minimum level for the app log and console output

    Returns:
        the logger for this module
    """
    root_logger = logging.getLogger()
    if getattr(root_logger, "_app_logging_configured", False):
        # Already configured in this process — do not stack duplicate handlers.
        if app:
            app.logger.setLevel(log_level)
        return logging.getLogger(__name__)
    # Ensure the log directory exists (race-free)
    os.makedirs('logs', exist_ok=True)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    # Rotating application log: 10MB per file, 5 backups
    file_handler = RotatingFileHandler(
        'logs/app.log',
        maxBytes=10*1024*1024,
        backupCount=5
    )
    file_handler.setFormatter(formatter)
    file_handler.setLevel(log_level)
    # Error-only log
    error_handler = RotatingFileHandler(
        'logs/error.log',
        maxBytes=10*1024*1024,
        backupCount=5
    )
    error_handler.setFormatter(formatter)
    error_handler.setLevel(logging.ERROR)
    # Console output
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    console_handler.setLevel(log_level)
    # Wire everything onto the root logger
    root_logger.setLevel(log_level)
    root_logger.addHandler(file_handler)
    root_logger.addHandler(error_handler)
    root_logger.addHandler(console_handler)
    root_logger._app_logging_configured = True
    # Let the Flask app's logger share the file handlers as before
    if app:
        app.logger.addHandler(file_handler)
        app.logger.addHandler(error_handler)
        app.logger.setLevel(log_level)
    return logging.getLogger(__name__)
def get_logger(name):
    """Thin wrapper around :func:`logging.getLogger` for project-wide consistency."""
    return logging.getLogger(name)

View File

@@ -0,0 +1,173 @@
"""
性能监控模块
用于监控应用程序的性能指标
"""
import time
import threading
from functools import wraps
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, field
from datetime import datetime
from utils.logger import get_logger
logger = get_logger(__name__)
@dataclass
class PerformanceMetric:
    """One monitored operation: timing, memory delta and outcome."""
    name: str                  # operation identifier, e.g. "file.readZip"
    start_time: float          # unix timestamp when monitoring began
    end_time: float = 0        # filled in by PerformanceMonitor.end_monitoring()
    duration: float = 0        # end_time - start_time, in seconds
    memory_before: float = 0   # MB before the operation (always 0.0 — get_memory_usage is a stub)
    memory_after: float = 0    # MB after the operation
    success: bool = True       # False when the monitored call raised
    error_message: str = ""    # str(exception) when success is False
class PerformanceMonitor:
    """Collects PerformanceMetric records for monitored operations.

    Thread-safe via a single lock; retains only the most recent 1000 metrics.
    """
    def __init__(self):
        self.metrics: List[PerformanceMetric] = []  # bounded history, newest last
        self.lock = threading.Lock()
    def start_monitoring(self, name: str) -> PerformanceMetric:
        """Create and return a metric stamped with the current time."""
        metric = PerformanceMetric(
            name=name,
            start_time=time.time(),
            memory_before=self.get_memory_usage()
        )
        return metric
    def end_monitoring(self, metric: PerformanceMetric, success: bool = True, error_message: str = ""):
        """Finalize *metric* and append it to the bounded history."""
        metric.end_time = time.time()
        metric.duration = metric.end_time - metric.start_time
        metric.memory_after = self.get_memory_usage()
        metric.success = success
        metric.error_message = error_message
        with self.lock:
            self.metrics.append(metric)
            # Keep only the most recent 1000 records
            if len(self.metrics) > 1000:
                self.metrics = self.metrics[-1000:]
        logger.debug(f"性能监控: {metric.name} - 耗时: {metric.duration:.3f}s, "
                     f"内存变化: {metric.memory_after - metric.memory_before:.2f}MB")
    def get_memory_usage(self) -> float:
        """Current memory usage in MB — placeholder that always returns 0.0.

        A real implementation would need psutil or platform-specific APIs.
        """
        try:
            # Stubbed out; kept as a hook for future instrumentation.
            return 0.0
        except:
            return 0.0
    def get_stats(self, operation_name: Optional[str] = None) -> Dict[str, Any]:
        """Aggregate stats for one operation (or all when None).

        Returns an empty dict when no matching metrics have been recorded.
        Durations are aggregated over successful calls only.
        """
        with self.lock:
            filtered_metrics = self.metrics
            if operation_name:
                filtered_metrics = [m for m in self.metrics if m.name == operation_name]
            if not filtered_metrics:
                return {}
            durations = [m.duration for m in filtered_metrics if m.success]
            success_count = len([m for m in filtered_metrics if m.success])
            error_count = len([m for m in filtered_metrics if not m.success])
            stats = {
                'operation_name': operation_name or 'All Operations',
                'total_calls': len(filtered_metrics),
                'success_calls': success_count,
                'error_calls': error_count,
                'success_rate': f"{(success_count / len(filtered_metrics) * 100):.2f}%" if filtered_metrics else "0%",
                'avg_duration': f"{(sum(durations) / len(durations)):.3f}s" if durations else "0s",
                'min_duration': f"{min(durations):.3f}s" if durations else "0s",
                'max_duration': f"{max(durations):.3f}s" if durations else "0s",
                'current_memory': f"{self.get_memory_usage():.2f}MB"
            }
            return stats
    def get_recent_errors(self, count: int = 10) -> List[Dict[str, Any]]:
        """Return up to *count* most recent failed operations as plain dicts."""
        with self.lock:
            error_metrics = [m for m in self.metrics if not m.success][-count:]
            return [
                {
                    'name': m.name,
                    'time': datetime.fromtimestamp(m.start_time).strftime('%Y-%m-%d %H:%M:%S'),
                    'duration': f"{m.duration:.3f}s",
                    'error': m.error_message
                }
                for m in error_metrics
            ]
    def clear_metrics(self):
        """Drop all recorded metrics."""
        with self.lock:
            self.metrics.clear()
        logger.info("清空性能监控数据")
# 全局性能监控器
_performance_monitor = None
def get_performance_monitor() -> PerformanceMonitor:
    """Lazily instantiate and return the process-wide PerformanceMonitor."""
    global _performance_monitor
    if _performance_monitor is None:
        # Built on first use; subsequent callers share the same instance.
        _performance_monitor = PerformanceMonitor()
    return _performance_monitor
def monitor_performance(operation_name: Optional[str] = None):
    """Decorator recording duration and success/failure of each call.

    Args:
        operation_name: metric label; defaults to ``module.function`` of the wrapped callable.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            monitor = get_performance_monitor()
            label = operation_name or f"{func.__module__}.{func.__name__}"
            metric = monitor.start_monitoring(label)
            try:
                result = func(*args, **kwargs)
            except Exception as exc:
                # Record the failure, then let the exception propagate unchanged.
                monitor.end_monitoring(metric, success=False, error_message=str(exc))
                raise
            monitor.end_monitoring(metric, success=True)
            return result
        return wrapper
    return decorator
def timing_context(operation_name: str):
    """Context-manager counterpart to monitor_performance for ad-hoc code spans."""
    class TimingContext:
        def __init__(self, name: str):
            self.name = name
            self.monitor = get_performance_monitor()
            self.metric: Optional[PerformanceMetric] = None
        def __enter__(self):
            self.metric = self.monitor.start_monitoring(self.name)
            return self
        def __exit__(self, exc_type, exc_val, exc_tb):
            # Returning None lets any exception propagate after being recorded.
            if self.metric is None:
                return
            succeeded = exc_type is None
            self.monitor.end_monitoring(
                self.metric,
                success=succeeded,
                error_message="" if succeeded else str(exc_val),
            )
    return TimingContext(operation_name)

34
utils/security.py Normal file
View File

@@ -0,0 +1,34 @@
import hashlib
import secrets
import hmac
from typing import Optional
def hash_password(password: str, salt: Optional[str] = None) -> tuple[str, str]:
    """
    Hash *password* with PBKDF2-HMAC-SHA256 and return ``(hash_hex, salt)``.

    A fresh 32-byte random salt (64 hex chars) is generated when none is supplied.
    """
    effective_salt = secrets.token_hex(32) if salt is None else salt
    digest = hashlib.pbkdf2_hmac(
        'sha256',
        password.encode('utf-8'),
        effective_salt.encode('utf-8'),
        100000,  # iteration count
    )
    return digest.hex(), effective_salt
def verify_password(password: str, hashed_password: str, salt: str) -> bool:
    """Check *password* against a stored hash/salt pair in constant time."""
    candidate_hash, _ = hash_password(password, salt)
    return hmac.compare_digest(candidate_hash, hashed_password)
def generate_session_token() -> str:
    """Return a cryptographically random, URL-safe session token (32 bytes of entropy)."""
    return secrets.token_urlsafe(32)