11"""
22Security middleware for API protection
33"""
4- from typing import Callable
4+ import time
5+ import hashlib
6+ import hmac
7+ import json
8+ from typing import Callable , Dict , Set , Optional
59from starlette .middleware .base import BaseHTTPMiddleware
610from starlette .requests import Request
7- from starlette .responses import Response
11+ from starlette .responses import Response , JSONResponse
812from starlette .types import ASGIApp
13+ import structlog
14+
15+ logger = structlog .get_logger ()
916
1017
1118class SecurityHeadersMiddleware (BaseHTTPMiddleware ):
@@ -66,10 +73,19 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
6673 return response
6774
6875
class APIKeyQuota:
    """API key quota configuration.

    A plain value holder: the constructor stores the four limits as given and
    performs no validation or enforcement itself.

    Attributes:
        calls_per_hour: Request limit per hour.
        calls_per_day: Request limit per day.
        max_concurrent_jobs: Concurrent job limit.
        max_file_size_mb: File size limit, in megabytes (per the name).
    """

    def __init__(self, calls_per_hour: int = 1000, calls_per_day: int = 10000,
                 max_concurrent_jobs: int = 5, max_file_size_mb: int = 1000):
        self.calls_per_hour = calls_per_hour
        self.calls_per_day = calls_per_day
        self.max_concurrent_jobs = max_concurrent_jobs
        self.max_file_size_mb = max_file_size_mb

    def __repr__(self) -> str:
        """Return a constructor-style representation for logs and debugging."""
        return (
            f"{type(self).__name__}("
            f"calls_per_hour={self.calls_per_hour!r}, "
            f"calls_per_day={self.calls_per_day!r}, "
            f"max_concurrent_jobs={self.max_concurrent_jobs!r}, "
            f"max_file_size_mb={self.max_file_size_mb!r})"
        )
84+
85+
6986class RateLimitMiddleware (BaseHTTPMiddleware ):
7087 """
71- Simple rate limiting middleware for additional protection.
72- Note: Primary rate limiting is handled by KrakenD API Gateway.
88+ Enhanced rate limiting middleware with API key quotas.
7389 """
7490
7591 def __init__ (
@@ -78,61 +94,273 @@ def __init__(
7894 calls : int = 1000 ,
7995 period : int = 3600 , # 1 hour
8096 enabled : bool = True ,
97+ redis_client = None , # Redis client for distributed rate limiting
8198 ):
8299 super ().__init__ (app )
83100 self .calls = calls
84101 self .period = period
85102 self .enabled = enabled
86- self .clients = {} # Simple in-memory store (use Redis in production)
103+ self .redis_client = redis_client
104+ self .clients = {} # Fallback in-memory store
105+
106+ # Default quotas for different API key tiers
107+ self .default_quotas = {
108+ 'free' : APIKeyQuota (calls_per_hour = 100 , calls_per_day = 1000 , max_concurrent_jobs = 2 , max_file_size_mb = 100 ),
109+ 'basic' : APIKeyQuota (calls_per_hour = 500 , calls_per_day = 5000 , max_concurrent_jobs = 5 , max_file_size_mb = 500 ),
110+ 'premium' : APIKeyQuota (calls_per_hour = 2000 , calls_per_day = 20000 , max_concurrent_jobs = 10 , max_file_size_mb = 2000 ),
111+ 'enterprise' : APIKeyQuota (calls_per_hour = 10000 , calls_per_day = 100000 , max_concurrent_jobs = 50 , max_file_size_mb = 10000 )
112+ }
87113
88114 async def dispatch (self , request : Request , call_next : Callable ) -> Response :
89- """Apply rate limiting based on client IP ."""
115+ """Apply enhanced rate limiting with API key quotas ."""
90116 if not self .enabled :
91117 return await call_next (request )
92118
93- # Get client IP
119+ # Get client identifier (IP + API key if available)
94120 client_ip = request .client .host
95121 if "X-Forwarded-For" in request .headers :
96122 client_ip = request .headers ["X-Forwarded-For" ].split ("," )[0 ].strip ()
97123
98- # Simple rate limiting logic (in production, use Redis)
124+ api_key = request .headers .get ("X-API-Key" ) or request .query_params .get ("api_key" )
125+ client_id = f"{ client_ip } :{ api_key } " if api_key else client_ip
126+
127+ # Get appropriate quota limits
128+ quota = await self ._get_client_quota (api_key )
129+
99130 import time
100131 current_time = time .time ()
132+ hour_key = f"{ client_id } :hour:{ int (current_time // 3600 )} "
133+ day_key = f"{ client_id } :day:{ int (current_time // 86400 )} "
101134
102- # Clean old entries (simple cleanup)
135+ # Use Redis for distributed rate limiting if available
136+ if self .redis_client :
137+ try :
138+ # Check hourly limit
139+ hourly_count = await self .redis_client .get (hour_key ) or 0
140+ daily_count = await self .redis_client .get (day_key ) or 0
141+
142+ hourly_count = int (hourly_count )
143+ daily_count = int (daily_count )
144+
145+ # Check limits
146+ if hourly_count >= quota .calls_per_hour :
147+ return self ._rate_limit_response (quota .calls_per_hour , "hour" , hourly_count )
148+
149+ if daily_count >= quota .calls_per_day :
150+ return self ._rate_limit_response (quota .calls_per_day , "day" , daily_count )
151+
152+ # Increment counters
153+ await self .redis_client .incr (hour_key )
154+ await self .redis_client .expire (hour_key , 3600 ) # 1 hour TTL
155+ await self .redis_client .incr (day_key )
156+ await self .redis_client .expire (day_key , 86400 ) # 1 day TTL
157+
158+ except Exception as e :
159+ # Fall back to in-memory if Redis fails
160+ import structlog
161+ logger = structlog .get_logger ()
162+ logger .warning ("Redis rate limiting failed, using fallback" , error = str (e ))
163+ return await self ._fallback_rate_limiting (client_id , quota , current_time , call_next , request )
164+ else :
165+ # Use in-memory fallback
166+ return await self ._fallback_rate_limiting (client_id , quota , current_time , call_next , request )
167+
168+ # Add rate limit headers
169+ response = await call_next (request )
170+ response .headers ["X-RateLimit-Limit-Hour" ] = str (quota .calls_per_hour )
171+ response .headers ["X-RateLimit-Limit-Day" ] = str (quota .calls_per_day )
172+ response .headers ["X-RateLimit-Remaining-Hour" ] = str (max (0 , quota .calls_per_hour - hourly_count - 1 ))
173+ response .headers ["X-RateLimit-Remaining-Day" ] = str (max (0 , quota .calls_per_day - daily_count - 1 ))
174+
175+ return response
176+
177+ async def _get_client_quota (self , api_key : str = None ) -> APIKeyQuota :
178+ """Get quota configuration for client based on API key tier."""
179+ if not api_key :
180+ return self .default_quotas ['free' ]
181+
182+ # In production, look up API key tier from database
183+ # For now, return based on key prefix or default to basic
184+ if api_key .startswith ('ent_' ):
185+ return self .default_quotas ['enterprise' ]
186+ elif api_key .startswith ('prem_' ):
187+ return self .default_quotas ['premium' ]
188+ elif api_key .startswith ('basic_' ):
189+ return self .default_quotas ['basic' ]
190+ else :
191+ return self .default_quotas ['basic' ] # Default for unknown keys
192+
193+ def _rate_limit_response (self , limit : int , period : str , current_count : int ):
194+ """Create rate limit exceeded response."""
195+ from starlette .responses import JSONResponse
196+ return JSONResponse (
197+ status_code = 429 ,
198+ content = {
199+ "error" : {
200+ "code" : "RATE_LIMIT_EXCEEDED" ,
201+ "message" : f"Rate limit exceeded. Maximum { limit } requests per { period } ." ,
202+ "type" : "RateLimitError" ,
203+ "limit" : limit ,
204+ "period" : period ,
205+ "current_usage" : current_count
206+ }
207+ },
208+ headers = {
209+ f"X-RateLimit-Limit-{ period .title ()} " : str (limit ),
210+ f"X-RateLimit-Remaining-{ period .title ()} " : "0" ,
211+ "Retry-After" : "3600" if period == "hour" else "86400"
212+ }
213+ )
214+
215+ async def _fallback_rate_limiting (self , client_id : str , quota : APIKeyQuota ,
216+ current_time : float , call_next : Callable , request : Request ):
217+ """Fallback in-memory rate limiting when Redis is unavailable."""
218+ # Clean old entries
103219 self .clients = {
104- ip : data for ip , data in self .clients .items ()
220+ cid : data for cid , data in self .clients .items ()
105221 if current_time - data ["window_start" ] < self .period
106222 }
107223
108- # Check rate limit
109- if client_ip in self .clients :
110- client_data = self .clients [client_ip ]
224+ # Check rate limit (simplified to hourly only for fallback)
225+ if client_id in self .clients :
226+ client_data = self .clients [client_id ]
111227 if current_time - client_data ["window_start" ] < self .period :
112- if client_data ["requests" ] >= self .calls :
113- from starlette .responses import JSONResponse
114- return JSONResponse (
115- status_code = 429 ,
116- content = {
117- "error" : {
118- "code" : "RATE_LIMIT_EXCEEDED" ,
119- "message" : f"Rate limit exceeded. Maximum { self .calls } requests per hour." ,
120- "type" : "RateLimitError"
121- }
122- }
123- )
228+ if client_data ["requests" ] >= quota .calls_per_hour :
229+ return self ._rate_limit_response (quota .calls_per_hour , "hour" , client_data ["requests" ])
124230 client_data ["requests" ] += 1
125231 else :
126232 # Reset window
127- self .clients [client_ip ] = {
233+ self .clients [client_id ] = {
128234 "requests" : 1 ,
129235 "window_start" : current_time
130236 }
131237 else :
132238 # New client
133- self .clients [client_ip ] = {
239+ self .clients [client_id ] = {
134240 "requests" : 1 ,
135241 "window_start" : current_time
136242 }
137243
138- return await call_next (request )
244+ return await call_next (request )
245+
246+
class InputSanitizationMiddleware(BaseHTTPMiddleware):
    """Middleware for validating request size and media type.

    Rejects requests whose declared Content-Length exceeds ``max_body_size``
    (413), whose Content-Length header is not a non-negative integer (400),
    and POST/PUT/PATCH requests whose Content-Type is not one of the accepted
    media types (415). Only headers are inspected; the body is not read.
    """

    # Accepted media-type prefixes, lower-case for case-insensitive matching.
    _ALLOWED_CONTENT_TYPES = (
        'application/json',
        'multipart/form-data',
        'application/x-www-form-urlencoded',
    )
    _METHODS_WITH_BODY = ('POST', 'PUT', 'PATCH')

    def __init__(self, app: ASGIApp, max_body_size: int = 100 * 1024 * 1024):  # 100MB default
        """Store the wrapped app and the maximum accepted Content-Length in bytes."""
        super().__init__(app)
        self.max_body_size = max_body_size

    @staticmethod
    def _error_response(status_code: int, code: str, message: str) -> JSONResponse:
        """Build the JSON error envelope shared by all rejections."""
        return JSONResponse(
            status_code=status_code,
            content={
                "error": {
                    "code": code,
                    "message": message,
                    "type": "RequestError"
                }
            }
        )

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Validate request headers, then forward the request.

        Exceptions raised by downstream handlers are deliberately not caught
        here: reporting an application failure as a 400 "Invalid request
        format" would blame the client for a server-side error.
        """
        # Check content length. This relies on the declared header only.
        content_length = request.headers.get('content-length')
        if content_length:
            try:
                declared_size = int(content_length)
                if declared_size < 0:
                    raise ValueError(f"negative Content-Length: {declared_size}")
            except ValueError as e:
                logger.error("Input sanitization failed", error=str(e))
                return self._error_response(400, "BAD_REQUEST", "Invalid request format")

            if declared_size > self.max_body_size:
                return self._error_response(
                    413,
                    "PAYLOAD_TOO_LARGE",
                    f"Request body too large. Maximum size: {self.max_body_size} bytes",
                )

        # Validate Content-Type for POST/PUT/PATCH requests.
        if request.method in self._METHODS_WITH_BODY:
            # Media types are case-insensitive, so normalise before matching.
            content_type = request.headers.get('content-type', '').lower()
            if not content_type.startswith(self._ALLOWED_CONTENT_TYPES):
                return self._error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Unsupported media type")

        return await call_next(request)
300+
301+
class SecurityAuditMiddleware(BaseHTTPMiddleware):
    """Middleware for security auditing and monitoring.

    Logs a warning when the request path matches a suspicious pattern, when a
    request takes longer than the slow-request threshold, and when the
    response status is 401. It only logs; requests are never blocked.
    """

    def __init__(self, app: ASGIApp, log_suspicious_activity: bool = True,
                 slow_request_threshold: float = 30.0):
        """Configure the auditor.

        Args:
            app: The wrapped ASGI application.
            log_suspicious_activity: Whether to scan request paths for the
                suspicious patterns.
            slow_request_threshold: Processing time in seconds above which a
                request is logged as slow.
        """
        super().__init__(app)
        self.log_suspicious_activity = log_suspicious_activity
        self.slow_request_threshold = slow_request_threshold
        self.suspicious_patterns = [
            r'\.\./',  # Directory traversal
            r'<script',  # XSS attempts
            r'union\s+select',  # SQL injection
            r'javascript:',  # XSS
            r'eval\s*\(',  # Code injection
            r'/etc/passwd',  # File access attempts
        ]

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Monitor and audit security events around a single request."""
        # Monotonic clock: a wall-clock adjustment during the request must not
        # produce a bogus (or negative) duration.
        start_time = time.monotonic()

        # Check for suspicious patterns
        if self.log_suspicious_activity:
            self._check_for_suspicious_activity(request)

        response = await call_next(request)

        # Log security events
        processing_time = time.monotonic() - start_time

        if processing_time > self.slow_request_threshold:  # Slow request detection
            logger.warning(
                "Slow request detected",
                path=request.url.path,
                processing_time=processing_time,
                client_ip=self._get_client_ip(request)
            )

        if response.status_code == 401:
            logger.warning(
                "Authentication failed",
                path=request.url.path,
                client_ip=self._get_client_ip(request)
            )

        return response

    def _check_for_suspicious_activity(self, request: Request):
        """Log a warning for each suspicious pattern found in the URL path.

        Only ``request.url.path`` is scanned; matching is case-insensitive.
        """
        import re

        path = request.url.path
        for pattern in self.suspicious_patterns:
            if re.search(pattern, path, re.IGNORECASE):
                logger.warning(
                    "Suspicious pattern in URL",
                    pattern=pattern,
                    url=path,
                    client_ip=self._get_client_ip(request)
                )

    def _get_client_ip(self, request: Request) -> str:
        """Get client IP address.

        Uses the first entry of X-Forwarded-For when present, else the
        connection's peer host, else 'unknown'.
        """
        forwarded_for = request.headers.get('x-forwarded-for')
        if forwarded_for:
            return forwarded_for.split(',')[0].strip()
        return request.client.host if request.client else 'unknown'
0 commit comments