Files
WPSBot/games/ai_chat.py

454 lines
15 KiB
Python
Raw Normal View History

2025-10-29 23:56:57 +08:00
"""AI对话游戏模块"""
import json
import logging
import asyncio
import time
from pathlib import Path
from typing import Optional, Dict, Any, List
from games.base import BaseGame
from utils.parser import CommandParser
logger = logging.getLogger(__name__)
# 全局字典存储每个chat_id的延迟任务句柄
_pending_tasks: Dict[int, asyncio.Task] = {}
# 全局字典存储每个chat_id的待处理消息队列
_message_queues: Dict[int, List[Dict[str, Any]]] = {}
# 全局字典存储每个chat_id的ChatEngine实例
_chat_engines: Dict[int, Any] = {}
class AIChatGame(BaseGame):
"""AI对话游戏"""
def __init__(self):
"""初始化游戏"""
super().__init__()
self.config_file = Path(__file__).parent.parent / "data" / "ai_config.json"
self.wait_window = 10 # 固定10秒等待窗口
async def handle(self, command: str, chat_id: int, user_id: int) -> str:
"""处理AI对话指令
Args:
command: 指令 ".ai 问题" ".aiconfig host=xxx port=xxx model=xxx"
chat_id: 会话ID
user_id: 用户ID
Returns:
回复消息
"""
try:
# 提取指令和参数
cmd, args = CommandParser.extract_command_args(command)
args = args.strip()
# 判断是配置指令还是AI对话指令
if cmd == '.aiconfig':
return await self._handle_config(args, chat_id, user_id)
else:
# .ai 指令
return await self._handle_ai(args, chat_id, user_id)
except Exception as e:
logger.error(f"处理AI对话指令错误: {e}", exc_info=True)
return f"❌ 处理指令出错: {str(e)}"
async def _handle_ai(self, content: str, chat_id: int, user_id: int) -> str:
"""处理AI对话请求
Args:
content: 消息内容
chat_id: 会话ID
user_id: 用户ID
Returns:
回复消息
"""
# 如果内容为空,返回帮助信息
if not content:
return self.get_help()
# 将消息加入队列
self._add_to_queue(chat_id, user_id, content)
# 取消旧的延迟任务(如果存在)
if chat_id in _pending_tasks:
old_task = _pending_tasks[chat_id]
if not old_task.done():
old_task.cancel()
try:
await old_task
except asyncio.CancelledError:
pass
# 创建新的延迟任务
task = asyncio.create_task(self._delayed_response(chat_id))
_pending_tasks[chat_id] = task
# 不返回确认消息,静默处理
return ""
2025-10-29 23:56:57 +08:00
async def _handle_config(self, args: str, chat_id: int, user_id: int) -> str:
"""处理配置请求
Args:
args: 配置参数格式如 "host=localhost port=11434 model=llama3.1"
chat_id: 会话ID
user_id: 用户ID
Returns:
配置确认消息
"""
if not args:
return "❌ 请提供配置参数\n\n格式:`.aiconfig host=xxx port=xxx model=xxx`\n\n示例:`.aiconfig host=localhost port=11434 model=llama3.1`"
# 解析配置参数
config_updates = {}
parts = args.split()
for part in parts:
if '=' in part:
key, value = part.split('=', 1)
key = key.strip().lower()
value = value.strip()
if key == 'host':
config_updates['host'] = value
elif key == 'port':
try:
config_updates['port'] = int(value)
except ValueError:
return f"❌ 端口号必须是数字:{value}"
elif key == 'model':
config_updates['model'] = value
else:
return f"❌ 未知的配置项:{key}\n\n支持的配置项host, port, model"
if not config_updates:
return "❌ 未提供有效的配置参数"
# 加载现有配置
current_config = self._load_config()
# 更新配置
current_config.update(config_updates)
# 保存配置
if self._save_config(current_config):
# 清除所有ChatEngine缓存配置变更需要重新创建
_chat_engines.clear()
return f"✅ 配置已更新\n\n**当前配置**\n- 地址:{current_config['host']}\n- 端口:{current_config['port']}\n- 模型:{current_config['model']}"
else:
return "❌ 保存配置失败,请稍后重试"
def _add_to_queue(self, chat_id: int, user_id: int, content: str) -> None:
"""将消息加入等待队列
Args:
chat_id: 会话ID
user_id: 用户ID
content: 消息内容
"""
if chat_id not in _message_queues:
_message_queues[chat_id] = []
_message_queues[chat_id].append({
"user_id": user_id,
"content": content,
"timestamp": int(time.time())
})
async def _delayed_response(self, chat_id: int) -> None:
"""延迟回答任务
Args:
chat_id: 会话ID
"""
try:
# 等待固定时间窗口
await asyncio.sleep(self.wait_window)
# 检查队列中是否有消息
if chat_id in _message_queues and _message_queues[chat_id]:
# 生成回答
response = await self._generate_response(chat_id)
# 清空队列
_message_queues[chat_id] = []
# 发送回答
if response:
from utils.message import get_message_sender
sender = get_message_sender()
await sender.send_text(response)
# 从pending_tasks中移除任务句柄
if chat_id in _pending_tasks:
del _pending_tasks[chat_id]
except asyncio.CancelledError:
# 任务被取消,正常情况,不需要记录错误
logger.debug(f"延迟任务被取消: chat_id={chat_id}")
if chat_id in _pending_tasks:
del _pending_tasks[chat_id]
except Exception as e:
logger.error(f"延迟回答任务错误: {e}", exc_info=True)
if chat_id in _pending_tasks:
del _pending_tasks[chat_id]
async def _generate_response(self, chat_id: int) -> Optional[str]:
"""使用LLM生成回答
Args:
chat_id: 会话ID
Returns:
回答文本
"""
try:
# 获取队列消息
if chat_id not in _message_queues or not _message_queues[chat_id]:
return None
messages = _message_queues[chat_id].copy()
# 获取ChatEngine实例
chat_engine = self._get_chat_engine(chat_id)
if not chat_engine:
return "❌ AI服务初始化失败请检查配置"
# 将消息按用户角色格式化并添加到ChatMemoryBuffer
# 构建合并的消息内容(包含用户信息)
merged_content = ""
for msg in messages:
user_id = msg['user_id']
role = self._get_user_role(chat_id, user_id)
merged_content += f"[{role}]: {msg['content']}\n"
# 去掉最后的换行
merged_content = merged_content.strip()
# 调用ChatEngine生成回答
# chat_engine是一个字典包含llm, memory, system_prompt
llm = chat_engine['llm']
memory = chat_engine['memory']
system_prompt = chat_engine['system_prompt']
# 构建完整的消息(包含系统提示和历史对话)
full_message = f"{system_prompt}\n\n{merged_content}"
# 使用LLM生成回答同步调用在线程池中执行
response = await asyncio.to_thread(llm.complete, full_message)
# 返回回答文本
return str(response)
except Exception as e:
logger.error(f"生成AI回答错误: {e}", exc_info=True)
return f"❌ 生成回答时出错: {str(e)}"
def _get_chat_engine(self, chat_id: int) -> Any:
"""获取或创建ChatEngine实例
Args:
chat_id: 会话ID
Returns:
ChatEngine实例
"""
# 检查是否已存在
if chat_id in _chat_engines:
return _chat_engines[chat_id]
try:
# 加载配置
config = self._load_config()
# 导入llama_index模块
from llama_index.llms.ollama import Ollama
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core import ChatPromptTemplate, Settings
# 创建Ollama LLM实例
llm = Ollama(
model=config['model'],
base_url=f"http://{config['host']}:{config['port']}"
)
# 设置全局LLM
Settings.llm = llm
# 创建ChatMemoryBuffer设置足够的token_limit确保保留30+轮对话)
memory = ChatMemoryBuffer.from_defaults(token_limit=8000)
# 系统提示
system_prompt = (
"这是一个多用户对话场景,不同用户的发言会用不同的角色标识(如'用户1''用户2'等)。"
"你需要理解不同用户的发言内容,并根据上下文给出合适的回复。"
"请用自然、友好的方式与用户交流。"
)
# 创建对话引擎
# 由于llama_index的API可能在不同版本有变化这里使用基本的chat接口
# 实际使用时可能需要根据llama_index的版本调整
chat_engine = {
'llm': llm,
'memory': memory,
'system_prompt': system_prompt
}
# 存储到全局字典
_chat_engines[chat_id] = chat_engine
return chat_engine
except ImportError as e:
logger.error(f"导入llama_index模块失败: {e}")
return None
except Exception as e:
logger.error(f"创建ChatEngine失败: {e}", exc_info=True)
return None
def _get_user_role(self, chat_id: int, user_id: int) -> str:
"""获取用户角色名称(创建或获取映射)
Args:
chat_id: 会话ID
user_id: 用户ID
Returns:
角色名称
"""
# 获取现有映射
user_mapping, user_count = self._get_user_mapping(chat_id)
user_id_str = str(user_id)
# 如果用户已存在,返回角色名称
if user_id_str in user_mapping:
return user_mapping[user_id_str]
# 新用户,分配角色
user_count += 1
role_name = f"用户{user_count}"
user_mapping[user_id_str] = role_name
# 保存到数据库
state_data = {
"user_mapping": user_mapping,
"user_count": user_count
}
self.db.save_game_state(chat_id, 0, 'ai_chat', state_data)
return role_name
def _get_user_mapping(self, chat_id: int) -> tuple[Dict[str, str], int]:
"""获取用户角色映射和计数
Args:
chat_id: 会话ID
Returns:
(用户映射字典, 用户计数)
"""
# 从数据库获取映射
state = self.db.get_game_state(chat_id, 0, 'ai_chat')
if state and state.get('state_data'):
user_mapping = state['state_data'].get('user_mapping', {})
user_count = state['state_data'].get('user_count', 0)
else:
user_mapping = {}
user_count = 0
return user_mapping, user_count
def _load_config(self) -> Dict[str, Any]:
"""从JSON文件加载配置
Returns:
配置字典
"""
# 如果文件不存在,创建默认配置
if not self.config_file.exists():
default_config = {
"host": "localhost",
"port": 11434,
"model": "llama3.1"
}
self._save_config(default_config)
return default_config
try:
with open(self.config_file, 'r', encoding='utf-8') as f:
config = json.load(f)
# 确保所有必需的字段存在
if 'host' not in config:
config['host'] = "localhost"
if 'port' not in config:
config['port'] = 11434
if 'model' not in config:
config['model'] = "llama3.1"
return config
except Exception as e:
logger.error(f"加载配置文件失败: {e}", exc_info=True)
# 返回默认配置
return {
"host": "localhost",
"port": 11434,
"model": "llama3.1"
}
def _save_config(self, config: Dict[str, Any]) -> bool:
"""保存配置到JSON文件
Args:
config: 配置字典
Returns:
是否成功
"""
try:
# 确保目录存在
self.config_file.parent.mkdir(parents=True, exist_ok=True)
# 写入JSON文件
with open(self.config_file, 'w', encoding='utf-8') as f:
json.dump(config, f, indent=4, ensure_ascii=False)
return True
except Exception as e:
logger.error(f"保存配置文件失败: {e}", exc_info=True)
return False
def get_help(self) -> str:
"""获取帮助信息
Returns:
帮助文本
"""
return """## 🤖 AI对话系统帮助
### 基本用法
- `.ai <问题>` - 向AI提问支持多用户对话等待10秒后回答
- `.aiconfig host=xxx port=xxx model=xxx` - 配置Ollama服务地址和模型
### 配置示例
\`.aiconfig host=localhost port=11434 model=llama3.1\`
### 说明
- 多个用户可以在同一个会话中提问
- 系统会等待10秒收集所有问题后统一回答
- 如果在等待期间有新消息会重新计时
---
💡 提示确保Ollama服务已启动并配置正确
"""