35 lines
1.2 KiB
Python
35 lines
1.2 KiB
Python
"""中间件模块"""
|
|
import asyncio
|
|
import logging
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
from config import MAX_CONCURRENT_REQUESTS
|
|
|
|
# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ConcurrencyLimitMiddleware(BaseHTTPMiddleware):
    """Concurrency-limiting middleware that caps in-flight requests to bound memory use.

    Each request must acquire a slot from an ``asyncio.Semaphore`` before it is
    forwarded downstream. Requests arriving while all slots are taken wait
    (they are queued, not rejected) until a slot is released.
    """

    def __init__(self, app, max_concurrent: int = MAX_CONCURRENT_REQUESTS):
        """Wrap *app* so that at most *max_concurrent* requests run at once.

        Args:
            app: The ASGI application being wrapped.
            max_concurrent: Maximum number of requests processed concurrently.
                Must be at least 1.

        Raises:
            ValueError: If ``max_concurrent`` is less than 1. A limit of 0
                would make every request wait forever on the semaphore.
        """
        if max_concurrent < 1:
            raise ValueError(
                f"max_concurrent must be >= 1, got {max_concurrent}"
            )
        super().__init__(app)
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.max_concurrent = max_concurrent
        logger.info("并发限制中间件已启用,最大并发数:%s", max_concurrent)

    async def dispatch(self, request: Request, call_next) -> Response:
        """Handle *request* while holding one concurrency slot.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request to the next
                middleware/endpoint and returns its response.

        Returns:
            The downstream response, or a generic 500 JSON response if the
            downstream handler raised an exception.
        """
        # Blocks (without a timeout) while all slots are in use.
        async with self.semaphore:
            # NOTE(review): with Starlette's BaseHTTPMiddleware, call_next is
            # believed to return once response headers are ready, so a streamed
            # body may be sent after this slot is released -- confirm this is
            # acceptable for the intended memory bound.
            try:
                return await call_next(request)
            except Exception as e:
                # Log the full traceback, then hide details from the client.
                logger.exception("请求处理错误: %s", e)
                return Response(
                    content='{"error": "Internal Server Error"}',
                    status_code=500,
                    media_type="application/json",
                )
|
|
|