Files
WPSBot/core/database.py
2025-10-29 12:36:20 +08:00

525 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""SQLite数据库操作模块 - 使用标准库sqlite3"""
import sqlite3
import json
import time
import logging
from typing import Optional, Dict, Any, List
from pathlib import Path
from config import DATABASE_PATH
logger = logging.getLogger(__name__)


class Database:
    """SQLite-backed storage for users, game sessions, game statistics and points.

    Holds one lazily opened ``sqlite3.Connection`` in autocommit mode
    (``isolation_level=None``), so each statement is committed as soon as it
    runs. Updates that must not lose writes (points, check-ins, user upserts)
    are expressed as single SQL statements rather than read-then-write
    sequences, so concurrent callers cannot interleave between the read and
    the write.
    """

    def __init__(self, db_path: Optional[str] = None):
        """Create the manager and make sure the schema exists.

        Args:
            db_path: Path of the SQLite database file. ``None`` (the default)
                means ``DATABASE_PATH`` from ``config``; it is looked up here,
                at construction time, instead of being frozen into the
                signature when the class is defined.
        """
        self.db_path = DATABASE_PATH if db_path is None else db_path
        self._conn: Optional[sqlite3.Connection] = None
        self._ensure_db_exists()
        self.init_tables()

    def _ensure_db_exists(self):
        """Create the parent directory of ``db_path`` if it does not exist yet."""
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

    @property
    def conn(self) -> sqlite3.Connection:
        """Return the shared connection, opening and configuring it on first use.

        The connection is cached only after every PRAGMA has been applied, so
        a failure part-way through never leaves a half-configured connection
        behind for later calls to reuse.

        Raises:
            Exception: Whatever ``sqlite3.connect`` or a PRAGMA raised; it is
                logged and re-raised.
        """
        if self._conn is None:
            connection: Optional[sqlite3.Connection] = None
            try:
                connection = sqlite3.connect(
                    self.db_path,
                    check_same_thread=False,  # allow use from threads other than the creator
                    isolation_level=None,  # autocommit: every statement commits immediately
                    timeout=30.0,  # seconds to wait on a locked database
                )
                connection.row_factory = sqlite3.Row  # rows support row["col"] and dict(row)
                # WAL journal for better read/write concurrency
                connection.execute("PRAGMA journal_mode=WAL")
                connection.execute("PRAGMA synchronous=NORMAL")
                connection.execute("PRAGMA cache_size=1000")
                connection.execute("PRAGMA temp_store=MEMORY")
            except Exception as e:
                if connection is not None:
                    connection.close()
                logger.error(f"数据库连接失败: {e}", exc_info=True)
                raise
            self._conn = connection
            logger.info(f"数据库连接成功: {self.db_path}")
        return self._conn

    def init_tables(self):
        """Create all tables and indexes if they do not already exist."""
        cursor = self.conn.cursor()

        # Users
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS users (
                user_id INTEGER PRIMARY KEY,
                username TEXT,
                created_at INTEGER NOT NULL,
                last_active INTEGER NOT NULL
            )
        """)

        # Per-chat, per-user, per-game session state (JSON in state_data)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS game_states (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                chat_id INTEGER NOT NULL,
                user_id INTEGER NOT NULL,
                game_type TEXT NOT NULL,
                state_data TEXT,
                created_at INTEGER NOT NULL,
                updated_at INTEGER NOT NULL,
                UNIQUE(chat_id, user_id, game_type)
            )
        """)

        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_chat_user
            ON game_states(chat_id, user_id)
        """)

        # Per-user, per-game win/loss/draw counters
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS game_stats (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id INTEGER NOT NULL,
                game_type TEXT NOT NULL,
                wins INTEGER DEFAULT 0,
                losses INTEGER DEFAULT 0,
                draws INTEGER DEFAULT 0,
                total_plays INTEGER DEFAULT 0,
                UNIQUE(user_id, game_type)
            )
        """)

        # User points (minimal schema).
        # NOTE: the FOREIGN KEY below is not enforced, because this class never
        # turns on PRAGMA foreign_keys (SQLite leaves it off by default).
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS user_points (
                user_id INTEGER PRIMARY KEY,
                points INTEGER DEFAULT 0,
                last_checkin_date TEXT,
                created_at INTEGER NOT NULL,
                updated_at INTEGER NOT NULL,
                FOREIGN KEY (user_id) REFERENCES users (user_id)
            )
        """)

        logger.info("数据库表初始化完成")

    # ===== User operations =====

    def get_or_create_user(self, user_id: int, username: Optional[str] = None) -> Dict:
        """Fetch a user, creating it if needed, and mark it as active now.

        Done as one UPSERT, so two callers racing on a new ``user_id`` cannot
        collide on the primary key. A non-``None`` ``username`` replaces the
        stored name. This lets users first created without a name (for example
        by :meth:`add_points`) get one later. ``None`` keeps the stored name.

        Args:
            user_id: User ID.
            username: Display name, or ``None`` to leave the stored one as is.

        Returns:
            The user's row as a dict (``user_id``, ``username``,
            ``created_at``, ``last_active``), including the refreshed
            ``last_active``.
        """
        cursor = self.conn.cursor()
        current_time = int(time.time())
        cursor.execute(
            """
            INSERT INTO users (user_id, username, created_at, last_active)
            VALUES (?, ?, ?, ?)
            ON CONFLICT(user_id) DO UPDATE SET
                last_active = excluded.last_active,
                username = COALESCE(excluded.username, users.username)
            """,
            (user_id, username, current_time, current_time),
        )
        cursor.execute("SELECT * FROM users WHERE user_id = ?", (user_id,))
        return dict(cursor.fetchone())

    # ===== Game state operations =====

    def get_game_state(self, chat_id: int, user_id: int, game_type: str) -> Optional[Dict]:
        """Load a stored game session.

        Args:
            chat_id: Chat ID.
            user_id: User ID.
            game_type: Game type key.

        Returns:
            The row as a dict with ``state_data`` decoded from JSON, or
            ``None`` if no session is stored.

        Raises:
            json.JSONDecodeError: If the stored ``state_data`` is not valid JSON.
        """
        cursor = self.conn.cursor()
        cursor.execute(
            "SELECT * FROM game_states WHERE chat_id = ? AND user_id = ? AND game_type = ?",
            (chat_id, user_id, game_type),
        )
        row = cursor.fetchone()
        if row is None:
            return None
        state = dict(row)
        if state.get('state_data'):
            state['state_data'] = json.loads(state['state_data'])
        return state

    def save_game_state(self, chat_id: int, user_id: int, game_type: str, state_data: Dict):
        """Insert or overwrite a game session's state.

        Args:
            chat_id: Chat ID.
            user_id: User ID.
            game_type: Game type key.
            state_data: JSON-serializable state; stored with ``ensure_ascii=False``.
        """
        cursor = self.conn.cursor()
        current_time = int(time.time())
        state_json = json.dumps(state_data, ensure_ascii=False)
        # excluded.* refers to the values from VALUES, so each one is bound once.
        # created_at is kept from the first insert.
        cursor.execute(
            """
            INSERT INTO game_states (chat_id, user_id, game_type, state_data, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?)
            ON CONFLICT(chat_id, user_id, game_type)
            DO UPDATE SET state_data = excluded.state_data, updated_at = excluded.updated_at
            """,
            (chat_id, user_id, game_type, state_json, current_time, current_time),
        )
        logger.debug(f"保存游戏状态: chat_id={chat_id}, user_id={user_id}, game_type={game_type}")

    def delete_game_state(self, chat_id: int, user_id: int, game_type: str):
        """Delete a stored game session (no-op if none exists).

        Args:
            chat_id: Chat ID.
            user_id: User ID.
            game_type: Game type key.
        """
        cursor = self.conn.cursor()
        cursor.execute(
            "DELETE FROM game_states WHERE chat_id = ? AND user_id = ? AND game_type = ?",
            (chat_id, user_id, game_type),
        )
        logger.debug(f"删除游戏状态: chat_id={chat_id}, user_id={user_id}, game_type={game_type}")

    def cleanup_old_sessions(self, timeout: int = 1800) -> int:
        """Delete game sessions not updated within the last ``timeout`` seconds.

        Args:
            timeout: Idle time in seconds after which a session is removed.

        Returns:
            Number of sessions deleted.
        """
        cursor = self.conn.cursor()
        cutoff_time = int(time.time()) - timeout
        cursor.execute("DELETE FROM game_states WHERE updated_at < ?", (cutoff_time,))
        deleted = cursor.rowcount
        if deleted > 0:
            logger.info(f"清理了 {deleted} 个过期游戏会话")
        return deleted

    # ===== Game statistics operations =====

    def get_game_stats(self, user_id: int, game_type: str) -> Dict:
        """Return a user's stats for one game, or all-zero defaults if none exist.

        Args:
            user_id: User ID.
            game_type: Game type key.

        Returns:
            Stats dict. A stored row includes its ``id``; the defaults do not.
        """
        cursor = self.conn.cursor()
        cursor.execute(
            "SELECT * FROM game_stats WHERE user_id = ? AND game_type = ?",
            (user_id, game_type),
        )
        row = cursor.fetchone()
        if row:
            return dict(row)
        return {
            'user_id': user_id,
            'game_type': game_type,
            'wins': 0,
            'losses': 0,
            'draws': 0,
            'total_plays': 0
        }

    def update_game_stats(self, user_id: int, game_type: str,
                          win: bool = False, loss: bool = False, draw: bool = False):
        """Record one played game, bumping the matching outcome counter.

        Args:
            user_id: User ID.
            game_type: Game type key.
            win: Count the game as a win.
            loss: Count the game as a loss.
            draw: Count the game as a draw.
        """
        cursor = self.conn.cursor()
        cursor.execute(
            """
            INSERT INTO game_stats (user_id, game_type, wins, losses, draws, total_plays)
            VALUES (?, ?, ?, ?, ?, 1)
            ON CONFLICT(user_id, game_type)
            DO UPDATE SET
                wins = wins + excluded.wins,
                losses = losses + excluded.losses,
                draws = draws + excluded.draws,
                total_plays = total_plays + 1
            """,
            (user_id, game_type, int(win), int(loss), int(draw)),
        )
        logger.debug(f"更新游戏统计: user_id={user_id}, game_type={game_type}")

    # ===== Points operations =====

    def get_user_points(self, user_id: int) -> Dict:
        """Return the user's points row, creating a zero-point row if missing.

        Args:
            user_id: User ID.

        Returns:
            Dict with ``user_id``, ``points``, ``last_checkin_date``,
            ``created_at`` and ``updated_at``.
        """
        cursor = self.conn.cursor()
        cursor.execute("SELECT * FROM user_points WHERE user_id = ?", (user_id,))
        row = cursor.fetchone()
        if row is None:
            current_time = int(time.time())
            # OR IGNORE: if a concurrent caller created the row first, keep it
            # instead of failing on the primary key.
            cursor.execute(
                "INSERT OR IGNORE INTO user_points (user_id, points, created_at, updated_at) "
                "VALUES (?, 0, ?, ?)",
                (user_id, current_time, current_time),
            )
            cursor.execute("SELECT * FROM user_points WHERE user_id = ?", (user_id,))
            row = cursor.fetchone()
        return dict(row)

    def add_points(self, user_id: int, points: int, source: str, description: str = None) -> bool:
        """Credit ``points`` to a user, creating the user/points rows if needed.

        Uses an UPSERT that only touches ``points`` and ``updated_at``. The
        previous ``INSERT OR REPLACE`` deleted and re-inserted the row, which
        wiped ``last_checkin_date`` (re-enabling that day's check-in) and
        reset ``created_at``.

        Args:
            user_id: User ID.
            points: Amount to add; must be positive.
            source: Origin of the points (used in the log message).
            description: Accepted for interface compatibility; not stored.

        Returns:
            True on success, False for a non-positive amount or a database error.
        """
        if points <= 0:
            logger.warning(f"积分数量无效: {points}")
            return False

        current_time = int(time.time())
        try:
            self.get_or_create_user(user_id)
            cursor = self.conn.cursor()
            cursor.execute(
                """
                INSERT INTO user_points (user_id, points, created_at, updated_at)
                VALUES (?, ?, ?, ?)
                ON CONFLICT(user_id) DO UPDATE SET
                    points = COALESCE(user_points.points, 0) + excluded.points,
                    updated_at = excluded.updated_at
                """,
                (user_id, points, current_time, current_time),
            )
            logger.info(f"用户 {user_id} 成功获得 {points} 积分,来源:{source}")
            return True
        except Exception as e:
            logger.error(f"添加积分失败: user_id={user_id}, points={points}, error={e}", exc_info=True)
            return False

    def consume_points(self, user_id: int, points: int, source: str, description: str = None) -> bool:
        """Deduct ``points`` from a user if, and only if, the balance covers it.

        The balance check and the deduction happen in one conditional UPDATE,
        so concurrent spends cannot push the balance below zero.

        Args:
            user_id: User ID.
            points: Amount to deduct; must be positive.
            source: What the points are spent on (used in the log message).
            description: Accepted for interface compatibility; not stored.

        Returns:
            True if the points were deducted. False for a non-positive amount,
            an unknown user, an insufficient balance, or a database error.
        """
        if points <= 0:
            return False

        current_time = int(time.time())
        try:
            cursor = self.conn.cursor()
            cursor.execute(
                "UPDATE user_points SET points = points - ?, updated_at = ? "
                "WHERE user_id = ? AND points >= ?",
                (points, current_time, user_id, points),
            )
            if cursor.rowcount == 0:
                # Nothing was deducted: look up the balance only to report it.
                cursor.execute("SELECT points FROM user_points WHERE user_id = ?", (user_id,))
                row = cursor.fetchone()
                logger.warning(f"用户 {user_id} 积分不足,需要 {points},当前可用 {row[0] if row else 0}")
                return False
            logger.debug(f"用户 {user_id} 消费 {points} 积分,来源:{source}")
            return True
        except Exception as e:
            logger.error(f"消费积分失败: {e}", exc_info=True)
            return False

    def check_daily_checkin(self, user_id: int, date: str) -> bool:
        """Tell whether the user's recorded check-in date equals ``date``.

        Args:
            user_id: User ID.
            date: Date string in ``YYYY-MM-DD`` form.

        Returns:
            True if already checked in on ``date``. False otherwise, including
            when the user has no points row (previously this returned ``None``).
        """
        cursor = self.conn.cursor()
        cursor.execute(
            "SELECT last_checkin_date FROM user_points WHERE user_id = ?",
            (user_id,),
        )
        row = cursor.fetchone()
        return row is not None and row[0] == date

    def daily_checkin(self, user_id: int, points: int) -> bool:
        """Perform today's check-in, crediting ``points`` at most once per day.

        "Today" is the local date from ``datetime.now()``. The credit and the
        date stamp are applied by one conditional UPDATE that only matches
        when today's check-in has not been recorded yet. Two concurrent
        check-ins therefore cannot both succeed, and other columns
        (``created_at``) are left untouched.

        Args:
            user_id: User ID.
            points: Points awarded for checking in; must be positive.

        Returns:
            True if the check-in was recorded. False if the amount is not
            positive, the user already checked in today, or a database error
            occurred.
        """
        from datetime import datetime

        if points <= 0:
            logger.warning(f"积分数量无效: {points}")
            return False

        today = datetime.now().strftime('%Y-%m-%d')
        current_time = int(time.time())
        try:
            self.get_or_create_user(user_id)
            self.get_user_points(user_id)  # make sure the points row exists
            cursor = self.conn.cursor()
            cursor.execute(
                """
                UPDATE user_points
                SET points = COALESCE(points, 0) + ?, last_checkin_date = ?, updated_at = ?
                WHERE user_id = ? AND (last_checkin_date IS NULL OR last_checkin_date <> ?)
                """,
                (points, today, current_time, user_id, today),
            )
            if cursor.rowcount == 0:
                logger.warning(f"用户 {user_id} 今日已签到")
                return False
            logger.info(f"用户 {user_id} 签到成功,积分增加: {points}")
            return True
        except Exception as e:
            logger.error(f"每日签到失败: {e}", exc_info=True)
            return False

    def get_points_leaderboard(self, limit: int = 10) -> List[Dict]:
        """Return users ranked by points, highest first.

        Users without a points row rank as 0 (their ``points`` value is
        ``None``). Ties are broken by ``user_id`` so the order is stable.

        Args:
            limit: Maximum number of entries.

        Returns:
            List of dicts with ``user_id``, ``username`` and ``points``.
        """
        cursor = self.conn.cursor()
        cursor.execute(
            """
            SELECT u.user_id, u.username, up.points
            FROM users u
            LEFT JOIN user_points up ON u.user_id = up.user_id
            ORDER BY COALESCE(up.points, 0) DESC, u.user_id ASC
            LIMIT ?
            """,
            (limit,),
        )
        return [dict(row) for row in cursor.fetchall()]

    def close(self):
        """Close the connection if open; the next ``conn`` access reopens it."""
        if self._conn:
            self._conn.close()
            self._conn = None
            logger.info("数据库连接已关闭")
# Lazily created, process-wide Database shared through get_db().
_db_instance: Optional[Database] = None


def get_db() -> Database:
    """Return the shared :class:`Database`, building it on the first call.

    Every later call hands back the very same object, so all callers share
    one underlying connection.
    """
    global _db_instance
    if _db_instance is not None:
        return _db_instance
    _db_instance = Database()
    return _db_instance