diff --git a/api/middleware/security.py b/api/middleware/security.py index 09e44fb..c79992b 100644 --- a/api/middleware/security.py +++ b/api/middleware/security.py @@ -1,11 +1,18 @@ """ Security middleware for API protection """ -from typing import Callable +import time +import hashlib +import hmac +import json +from typing import Callable, Dict, Set, Optional from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request -from starlette.responses import Response +from starlette.responses import Response, JSONResponse from starlette.types import ASGIApp +import structlog + +logger = structlog.get_logger() class SecurityHeadersMiddleware(BaseHTTPMiddleware): @@ -66,10 +73,19 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: return response +class APIKeyQuota: + """API Key quota configuration.""" + def __init__(self, calls_per_hour: int = 1000, calls_per_day: int = 10000, + max_concurrent_jobs: int = 5, max_file_size_mb: int = 1000): + self.calls_per_hour = calls_per_hour + self.calls_per_day = calls_per_day + self.max_concurrent_jobs = max_concurrent_jobs + self.max_file_size_mb = max_file_size_mb + + class RateLimitMiddleware(BaseHTTPMiddleware): """ - Simple rate limiting middleware for additional protection. - Note: Primary rate limiting is handled by KrakenD API Gateway. + Enhanced rate limiting middleware with API key quotas. """ def __init__( @@ -78,61 +94,273 @@ def __init__( calls: int = 1000, period: int = 3600, # 1 hour enabled: bool = True, + redis_client = None, # Redis client for distributed rate limiting ): super().__init__(app) self.calls = calls self.period = period self.enabled = enabled - self.clients = {} # Simple in-memory store (use Redis in production) + self.redis_client = redis_client + self.clients = {} # Fallback in-memory store + + # Default quotas for different API key tiers + self.default_quotas = { + 'free': APIKeyQuota(calls_per_hour=100, calls_per_day=1000, max_concurrent_jobs=2, max_file_size_mb=100), + 'basic': APIKeyQuota(calls_per_hour=500, calls_per_day=5000, max_concurrent_jobs=5, max_file_size_mb=500), + 'premium': APIKeyQuota(calls_per_hour=2000, calls_per_day=20000, max_concurrent_jobs=10, max_file_size_mb=2000), + 'enterprise': APIKeyQuota(calls_per_hour=10000, calls_per_day=100000, max_concurrent_jobs=50, max_file_size_mb=10000) + } async def dispatch(self, request: Request, call_next: Callable) -> Response: - """Apply rate limiting based on client IP.""" + """Apply enhanced rate limiting with API key quotas.""" if not self.enabled: return await call_next(request) - # Get client IP + # Get client identifier (IP + API key if available) client_ip = request.client.host if "X-Forwarded-For" in request.headers: client_ip = request.headers["X-Forwarded-For"].split(",")[0].strip() - # Simple rate limiting logic (in production, use Redis) + api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key") + client_id = f"{client_ip}:{api_key}" if api_key else client_ip + + # Get appropriate quota limits + quota = await self._get_client_quota(api_key) + import time current_time = time.time() + hour_key = f"{client_id}:hour:{int(current_time // 3600)}" + day_key = f"{client_id}:day:{int(current_time // 86400)}" - # Clean old entries (simple cleanup) + # Use Redis for distributed rate limiting if available + if self.redis_client: + try: + # Check hourly limit + hourly_count = await self.redis_client.get(hour_key) or 0 + daily_count = await self.redis_client.get(day_key) or 0 + + hourly_count = int(hourly_count) + daily_count = int(daily_count) + + # Check limits + if hourly_count >= quota.calls_per_hour: + return self._rate_limit_response(quota.calls_per_hour, "hour", hourly_count) + + if daily_count >= quota.calls_per_day: + return self._rate_limit_response(quota.calls_per_day, "day", daily_count) + + # Increment counters + await self.redis_client.incr(hour_key) + await self.redis_client.expire(hour_key, 3600) # 1 hour TTL + await self.redis_client.incr(day_key) + await self.redis_client.expire(day_key, 86400) # 1 day TTL + + except Exception as e: + # Fall back to in-memory if Redis fails + import structlog + logger = structlog.get_logger() + logger.warning("Redis rate limiting failed, using fallback", error=str(e)) + return await self._fallback_rate_limiting(client_id, quota, current_time, call_next, request) + else: + # Use in-memory fallback + return await self._fallback_rate_limiting(client_id, quota, current_time, call_next, request) + + # Add rate limit headers + response = await call_next(request) + response.headers["X-RateLimit-Limit-Hour"] = str(quota.calls_per_hour) + response.headers["X-RateLimit-Limit-Day"] = str(quota.calls_per_day) + response.headers["X-RateLimit-Remaining-Hour"] = str(max(0, quota.calls_per_hour - hourly_count - 1)) + response.headers["X-RateLimit-Remaining-Day"] = str(max(0, quota.calls_per_day - daily_count - 1)) + + return response + + async def _get_client_quota(self, api_key: str = None) -> APIKeyQuota: + """Get quota configuration for client based on API key tier.""" + if not api_key: + return self.default_quotas['free'] + + # In production, look up API key tier from database + # For now, return based on key prefix or default to basic + if api_key.startswith('ent_'): + return self.default_quotas['enterprise'] + elif api_key.startswith('prem_'): + return self.default_quotas['premium'] + elif api_key.startswith('basic_'): + return self.default_quotas['basic'] + else: + return self.default_quotas['basic'] # Default for unknown keys + + def _rate_limit_response(self, limit: int, period: str, current_count: int): + """Create rate limit exceeded response.""" + from starlette.responses import JSONResponse + return JSONResponse( + status_code=429, + content={ + "error": { + "code": "RATE_LIMIT_EXCEEDED", + "message": f"Rate limit exceeded. Maximum {limit} requests per {period}.", + "type": "RateLimitError", + "limit": limit, + "period": period, + "current_usage": current_count + } + }, + headers={ + f"X-RateLimit-Limit-{period.title()}": str(limit), + f"X-RateLimit-Remaining-{period.title()}": "0", + "Retry-After": "3600" if period == "hour" else "86400" + } + ) + + async def _fallback_rate_limiting(self, client_id: str, quota: APIKeyQuota, + current_time: float, call_next: Callable, request: Request): + """Fallback in-memory rate limiting when Redis is unavailable.""" + # Clean old entries self.clients = { - ip: data for ip, data in self.clients.items() + cid: data for cid, data in self.clients.items() if current_time - data["window_start"] < self.period } - # Check rate limit - if client_ip in self.clients: - client_data = self.clients[client_ip] + # Check rate limit (simplified to hourly only for fallback) + if client_id in self.clients: + client_data = self.clients[client_id] if current_time - client_data["window_start"] < self.period: - if client_data["requests"] >= self.calls: - from starlette.responses import JSONResponse - return JSONResponse( - status_code=429, - content={ - "error": { - "code": "RATE_LIMIT_EXCEEDED", - "message": f"Rate limit exceeded. Maximum {self.calls} requests per hour.", - "type": "RateLimitError" - } - } - ) + if client_data["requests"] >= quota.calls_per_hour: + return self._rate_limit_response(quota.calls_per_hour, "hour", client_data["requests"]) client_data["requests"] += 1 else: # Reset window - self.clients[client_ip] = { + self.clients[client_id] = { "requests": 1, "window_start": current_time } else: # New client - self.clients[client_ip] = { + self.clients[client_id] = { "requests": 1, "window_start": current_time } - return await call_next(request) \ No newline at end of file + return await call_next(request) + + +class InputSanitizationMiddleware(BaseHTTPMiddleware): + """Middleware for sanitizing and validating input data.""" + + def __init__(self, app: ASGIApp, max_body_size: int = 100 * 1024 * 1024): # 100MB default + super().__init__(app) + self.max_body_size = max_body_size + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + """Sanitize request data.""" + try: + # Check content length + content_length = request.headers.get('content-length') + if content_length and int(content_length) > self.max_body_size: + return JSONResponse( + status_code=413, + content={ + "error": { + "code": "PAYLOAD_TOO_LARGE", + "message": f"Request body too large. Maximum size: {self.max_body_size} bytes", + "type": "RequestError" + } + } + ) + + # Validate Content-Type for POST/PUT requests + if request.method in ['POST', 'PUT', 'PATCH']: + content_type = request.headers.get('content-type', '') + if not content_type.startswith(('application/json', 'multipart/form-data', 'application/x-www-form-urlencoded')): + return JSONResponse( + status_code=415, + content={ + "error": { + "code": "UNSUPPORTED_MEDIA_TYPE", + "message": "Unsupported media type", + "type": "RequestError" + } + } + ) + + return await call_next(request) + + except Exception as e: + logger.error("Input sanitization failed", error=str(e)) + return JSONResponse( + status_code=400, + content={ + "error": { + "code": "BAD_REQUEST", + "message": "Invalid request format", + "type": "RequestError" + } + } + ) + + +class SecurityAuditMiddleware(BaseHTTPMiddleware): + """Middleware for security auditing and monitoring.""" + + def __init__(self, app: ASGIApp, log_suspicious_activity: bool = True): + super().__init__(app) + self.log_suspicious_activity = log_suspicious_activity + self.suspicious_patterns = [ + r'\.\./', # Directory traversal + r' Response: + """Monitor and audit security events.""" + start_time = time.time() + + # Check for suspicious patterns + if self.log_suspicious_activity: + self._check_for_suspicious_activity(request) + + response = await call_next(request) + + # Log security events + processing_time = time.time() - start_time + + if processing_time > 30: # Slow request detection + logger.warning( + "Slow request detected", + path=request.url.path, + processing_time=processing_time, + client_ip=self._get_client_ip(request) + ) + + if response.status_code == 401: + logger.warning( + "Authentication failed", + path=request.url.path, + client_ip=self._get_client_ip(request) + ) + + return response + + def _check_for_suspicious_activity(self, request: Request): + """Check for suspicious patterns in the request.""" + import re + + # Check URL path + for pattern in self.suspicious_patterns: + if re.search(pattern, request.url.path, re.IGNORECASE): + logger.warning( + "Suspicious pattern in URL", + pattern=pattern, + url=request.url.path, + client_ip=self._get_client_ip(request) + ) + + def _get_client_ip(self, request: Request) -> str: + """Get client IP address.""" + forwarded_for = request.headers.get('x-forwarded-for') + if forwarded_for: + return forwarded_for.split(',')[0].strip() + return request.client.host if request.client else 'unknown' \ No newline at end of file diff --git a/api/security_config.py b/api/security_config.py new file mode 100644 index 0000000..c151ea0 --- /dev/null +++ b/api/security_config.py @@ -0,0 +1,272 @@ +""" +Security configuration and setup for FFmpeg API +""" +import os +from typing import Dict, Any, List +from fastapi import FastAPI, Request, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.base import BaseHTTPMiddleware + +from api.middleware.security import ( + SecurityHeadersMiddleware, + RateLimitMiddleware, + InputSanitizationMiddleware, + SecurityAuditMiddleware +) +from api.utils.error_handler import ( + ProductionErrorHandler, + ErrorLevel, + set_debug_mode +) +from api.utils.validators import SecurityError +import structlog + +logger = structlog.get_logger() + + +class SecurityConfig: + """Central security configuration for the API.""" + + def __init__(self): + # Environment-based settings + self.debug_mode = os.getenv('DEBUG', 'false').lower() == 'true' + self.environment = os.getenv('ENVIRONMENT', 'production') + + # Rate limiting settings + self.rate_limit_enabled = os.getenv('RATE_LIMIT_ENABLED', 'true').lower() == 'true' + self.rate_limit_calls = int(os.getenv('RATE_LIMIT_CALLS', '1000')) + self.rate_limit_period = int(os.getenv('RATE_LIMIT_PERIOD', '3600')) + + # Security headers settings + self.csp_policy = os.getenv( + 'CSP_POLICY', + "default-src 'self'; script-src 'self'; object-src 'none';" + ) + self.hsts_max_age = int(os.getenv('HSTS_MAX_AGE', '31536000')) + + # Input validation settings + self.max_body_size = int(os.getenv('MAX_BODY_SIZE', str(100 * 1024 * 1024))) # 100MB + + # CORS settings + self.cors_origins = os.getenv('CORS_ORIGINS', '').split(',') if os.getenv('CORS_ORIGINS') else ['*'] + self.cors_allow_credentials = os.getenv('CORS_ALLOW_CREDENTIALS', 'false').lower() == 'true' + + # Error handling + self.error_handler = ProductionErrorHandler(debug_mode=self.debug_mode) + set_debug_mode(self.debug_mode) + + logger.info("Security configuration initialized", + debug_mode=self.debug_mode, + environment=self.environment) + + def configure_app(self, app: FastAPI) -> FastAPI: + """Apply all security configurations to the FastAPI app.""" + + # Add security middleware in correct order (reverse order of execution) + + # 1. Security audit middleware (outermost - logs everything) + app.add_middleware( + SecurityAuditMiddleware, + log_suspicious_activity=True + ) + + # 2. Rate limiting middleware + if self.rate_limit_enabled: + app.add_middleware( + RateLimitMiddleware, + calls=self.rate_limit_calls, + period=self.rate_limit_period, + enabled=True + ) + + # 3. Input sanitization middleware + app.add_middleware( + InputSanitizationMiddleware, + max_body_size=self.max_body_size + ) + + # 4. Security headers middleware + app.add_middleware( + SecurityHeadersMiddleware, + csp_policy=self.csp_policy, + hsts_max_age=self.hsts_max_age, + enable_hsts=True, + enable_nosniff=True, + enable_xss_protection=True, + enable_frame_options=True + ) + + # 5. CORS middleware (innermost) + if self.cors_origins: + app.add_middleware( + CORSMiddleware, + allow_origins=self.cors_origins, + allow_credentials=self.cors_allow_credentials, + allow_methods=["GET", "POST", "PUT", "DELETE"], + allow_headers=["*"], + ) + + # Add global exception handlers + self._add_exception_handlers(app) + + logger.info("Security middleware configured successfully") + return app + + def _add_exception_handlers(self, app: FastAPI): + """Add global exception handlers with proper error sanitization.""" + + @app.exception_handler(SecurityError) + async def security_error_handler(request: Request, exc: SecurityError): + """Handle security violations.""" + error_response = self.error_handler.sanitize_error_message(exc, ErrorLevel.HIGH) + + # Log security incident + logger.error( + "Security violation", + error=str(exc), + path=request.url.path, + method=request.method, + client_ip=self._get_client_ip(request), + user_agent=request.headers.get('user-agent', 'Unknown') + ) + + from fastapi.responses import JSONResponse + return JSONResponse( + status_code=403, + content=error_response + ) + + @app.exception_handler(ValueError) + async def validation_error_handler(request: Request, exc: ValueError): + """Handle validation errors.""" + error_response = self.error_handler.sanitize_error_message(exc, ErrorLevel.LOW) + + from fastapi.responses import JSONResponse + return JSONResponse( + status_code=400, + content=error_response + ) + + @app.exception_handler(HTTPException) + async def http_exception_handler(request: Request, exc: HTTPException): + """Handle HTTP exceptions.""" + error_response = self.error_handler.handle_http_exception( + exc.status_code, + exc.detail + ) + + from fastapi.responses import JSONResponse + return JSONResponse( + status_code=exc.status_code, + content=error_response + ) + + @app.exception_handler(Exception) + async def general_exception_handler(request: Request, exc: Exception): + """Handle all other exceptions.""" + error_response = self.error_handler.sanitize_error_message(exc, ErrorLevel.HIGH) + + # Log unexpected errors + logger.error( + "Unexpected error", + error=str(exc), + error_type=type(exc).__name__, + path=request.url.path, + method=request.method + ) + + from fastapi.responses import JSONResponse + return JSONResponse( + status_code=500, + content=error_response + ) + + def _get_client_ip(self, request: Request) -> str: + """Get client IP address.""" + forwarded_for = request.headers.get('x-forwarded-for') + if forwarded_for: + return forwarded_for.split(',')[0].strip() + return request.client.host if request.client else 'unknown' + + def get_security_headers(self) -> Dict[str, str]: + """Get recommended security headers for manual application.""" + return { + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "X-XSS-Protection": "1; mode=block", + "Strict-Transport-Security": f"max-age={self.hsts_max_age}; includeSubDomains", + "Content-Security-Policy": self.csp_policy, + "Referrer-Policy": "strict-origin-when-cross-origin", + "Permissions-Policy": "geolocation=(), microphone=(), camera=()" + } + + def validate_api_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """Validate API request data with security checks.""" + from api.utils.validators import validate_operations, validate_secure_path + + validated_data = {} + + # Validate required fields + if 'input_path' not in request_data: + raise ValueError("Missing required field: input_path") + + if 'operations' not in request_data: + raise ValueError("Missing required field: operations") + + # Validate input path + try: + validated_data['input_path'] = validate_secure_path(request_data['input_path']) + except SecurityError as e: + raise SecurityError(f"Invalid input path: {e}") + + # Validate output path if provided + if 'output_path' in request_data: + try: + validated_data['output_path'] = validate_secure_path(request_data['output_path']) + except SecurityError as e: + raise SecurityError(f"Invalid output path: {e}") + + # Validate operations + try: + validated_data['operations'] = validate_operations(request_data['operations']) + except (ValueError, SecurityError) as e: + raise ValueError(f"Invalid operations: {e}") + + # Validate optional fields + if 'options' in request_data: + if not isinstance(request_data['options'], dict): + raise ValueError("Options must be a dictionary") + validated_data['options'] = request_data['options'] + + return validated_data + + +# Global security configuration instance +security_config = SecurityConfig() + + +def apply_security_to_app(app: FastAPI) -> FastAPI: + """Apply comprehensive security configuration to FastAPI app.""" + return security_config.configure_app(app) + + +def validate_request_data(data: Dict[str, Any]) -> Dict[str, Any]: + """Validate request data using security configuration.""" + return security_config.validate_api_request(data) + + +def get_security_info() -> Dict[str, Any]: + """Get current security configuration information.""" + return { + "security_enabled": True, + "rate_limiting": security_config.rate_limit_enabled, + "input_validation": True, + "error_sanitization": True, + "security_headers": True, + "audit_logging": True, + "debug_mode": security_config.debug_mode, + "environment": security_config.environment, + "max_body_size": security_config.max_body_size, + "rate_limit_calls": security_config.rate_limit_calls, + "rate_limit_period": security_config.rate_limit_period + } \ No newline at end of file diff --git a/api/utils/error_handler.py b/api/utils/error_handler.py new file mode 100644 index 0000000..38e14f5 --- /dev/null +++ b/api/utils/error_handler.py @@ -0,0 +1,304 @@ +""" +Production-safe error handling and message sanitization +""" +import re +import traceback +from typing import Dict, Any, Optional +from enum import Enum +import structlog + +logger = structlog.get_logger() + + +class ErrorLevel(Enum): + """Error severity levels""" + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + CRITICAL = "critical" + + +class ProductionErrorHandler: + """Handles error sanitization for production environments""" + + # Patterns that should be removed from error messages + SENSITIVE_PATTERNS = [ + r'/[a-zA-Z0-9_\-\.]+/[a-zA-Z0-9_\-\.]+/[a-zA-Z0-9_\-\.]+', # File paths + r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', # Email addresses + r'(?:password|secret|key|token)[\s=:]+[^\s]+', # Credentials + r'[A-Za-z0-9_]{32,}', # Long tokens/hashes + r'(?:https?://)[^\s]+', # URLs + r'(?:mongodb://|postgresql://|redis://)[^\s]+', # Database URLs + r'Bearer\s+[^\s]+', # Bearer tokens + r'Basic\s+[^\s]+', # Basic auth + r'(?:api[_-]?key|access[_-]?token)[\s=:]+[^\s]+', # API keys + r'(?:aws[_-]?access[_-]?key|aws[_-]?secret)[^\s]+', # AWS credentials + ] + + # Safe error messages for different error types + SAFE_ERROR_MESSAGES = { + 'FileNotFoundError': 'Requested file not found', + 'PermissionError': 'Access denied to requested resource', + 'ConnectionError': 'Service temporarily unavailable', + 'TimeoutError': 'Request timeout - please try again', + 'ValidationError': 'Invalid input provided', + 'SecurityError': 'Security validation failed', + 'FFmpegError': 'Video processing failed', + 'FFmpegCommandError': 'Invalid processing parameters', + 'FFmpegExecutionError': 'Video processing error occurred', + 'StorageError': 'Storage operation failed', + 'AuthenticationError': 'Authentication required', + 'AuthorizationError': 'Access denied', + 'RateLimitError': 'Rate limit exceeded', + 'DatabaseError': 'Database operation failed', + 'NetworkError': 'Network connectivity issue', + 'ConfigurationError': 'Service configuration error' + } + + def __init__(self, debug_mode: bool = False): + self.debug_mode = debug_mode + self.compiled_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.SENSITIVE_PATTERNS] + + def sanitize_error_message(self, error: Exception, error_level: ErrorLevel = ErrorLevel.MEDIUM) -> Dict[str, Any]: + """ + Sanitize error message for production use. + + Args: + error: The exception to sanitize + error_level: Severity level of the error + + Returns: + Dict containing sanitized error information + """ + error_type = type(error).__name__ + original_message = str(error) + + # Get safe message based on error type + safe_message = self.SAFE_ERROR_MESSAGES.get(error_type, "An error occurred") + + # In debug mode, return more detailed information + if self.debug_mode and error_level in [ErrorLevel.LOW, ErrorLevel.MEDIUM]: + sanitized_message = self._sanitize_message_content(original_message) + return { + "error": { + "code": error_type.upper(), + "message": sanitized_message, + "type": error_type, + "level": error_level.value, + "debug_info": { + "original_message": sanitized_message, + "traceback": self._sanitize_traceback() + } + } + } + + # Production mode - return minimal safe information + error_code = self._generate_error_code(error_type) + + result = { + "error": { + "code": error_code, + "message": safe_message, + "type": error_type, + "level": error_level.value + } + } + + # Add helpful context for certain error types + if error_type == 'ValidationError': + result["error"]["details"] = "Please check your input parameters" + elif error_type in ['RateLimitError']: + result["error"]["details"] = "Please wait before making another request" + elif error_type in ['AuthenticationError', 'AuthorizationError']: + result["error"]["details"] = "Please check your credentials" + + # Log the actual error for debugging + logger.error( + "Error occurred", + error_type=error_type, + error_message=original_message, + error_level=error_level.value, + sanitized=True + ) + + return result + + def _sanitize_message_content(self, message: str) -> str: + """Remove sensitive information from error message content.""" + sanitized = message + + # Remove sensitive patterns + for pattern in self.compiled_patterns: + sanitized = pattern.sub('[REDACTED]', sanitized) + + # Remove common sensitive keywords + sensitive_keywords = [ + 'password', 'secret', 'key', 'token', 'credential', + 'username', 'email', 'phone', 'ssn', 'credit' + ] + + for keyword in sensitive_keywords: + # Replace sensitive values after keywords + pattern = rf'{keyword}[\s=:]+[^\s]+' + sanitized = re.sub(pattern, f'{keyword}=[REDACTED]', sanitized, flags=re.IGNORECASE) + + return sanitized + + def _sanitize_traceback(self) -> Optional[str]: + """Get sanitized traceback information.""" + if not self.debug_mode: + return None + + try: + tb = traceback.format_exc() + return self._sanitize_message_content(tb) + except Exception: + return "Traceback unavailable" + + def _generate_error_code(self, error_type: str) -> str: + """Generate consistent error codes.""" + error_codes = { + 'FileNotFoundError': 'FILE_NOT_FOUND', + 'PermissionError': 'ACCESS_DENIED', + 'ConnectionError': 'CONNECTION_FAILED', + 'TimeoutError': 'REQUEST_TIMEOUT', + 'ValidationError': 'VALIDATION_FAILED', + 'SecurityError': 'SECURITY_VIOLATION', + 'FFmpegError': 'PROCESSING_FAILED', + 'FFmpegCommandError': 'INVALID_PARAMETERS', + 'FFmpegExecutionError': 'PROCESSING_ERROR', + 'StorageError': 'STORAGE_FAILED', + 'AuthenticationError': 'AUTH_REQUIRED', + 'AuthorizationError': 'ACCESS_FORBIDDEN', + 'RateLimitError': 'RATE_LIMIT_EXCEEDED', + 'DatabaseError': 'DATABASE_ERROR', + 'NetworkError': 'NETWORK_ERROR', + 'ConfigurationError': 'CONFIG_ERROR' + } + + return error_codes.get(error_type, 'INTERNAL_ERROR') + + def handle_http_exception(self, status_code: int, detail: str = None) -> Dict[str, Any]: + """Handle HTTP exceptions with appropriate sanitization.""" + http_errors = { + 400: { + "code": "BAD_REQUEST", + "message": "Invalid request format or parameters", + "level": ErrorLevel.LOW.value + }, + 401: { + "code": "UNAUTHORIZED", + "message": "Authentication required", + "level": ErrorLevel.MEDIUM.value + }, + 403: { + "code": "FORBIDDEN", + "message": "Access denied", + "level": ErrorLevel.MEDIUM.value + }, + 404: { + "code": "NOT_FOUND", + "message": "Requested resource not found", + "level": ErrorLevel.LOW.value + }, + 422: { + "code": "VALIDATION_ERROR", + "message": "Request validation failed", + "level": ErrorLevel.LOW.value + }, + 429: { + "code": "RATE_LIMIT_EXCEEDED", + "message": "Too many requests", + "level": ErrorLevel.MEDIUM.value + }, + 500: { + "code": "INTERNAL_ERROR", + "message": "Internal server error", + "level": ErrorLevel.HIGH.value + }, + 502: { + "code": "BAD_GATEWAY", + "message": "Service temporarily unavailable", + "level": ErrorLevel.HIGH.value + }, + 503: { + "code": "SERVICE_UNAVAILABLE", + "message": "Service temporarily unavailable", + "level": ErrorLevel.HIGH.value + }, + 504: { + "code": "GATEWAY_TIMEOUT", + "message": "Request timeout", + "level": ErrorLevel.MEDIUM.value + } + } + + error_info = http_errors.get(status_code, { + "code": "HTTP_ERROR", + "message": "HTTP error occurred", + "level": ErrorLevel.MEDIUM.value + }) + + # Sanitize detail if provided + if detail and self.debug_mode: + error_info["details"] = self._sanitize_message_content(detail) + + return {"error": error_info} + + def create_security_alert(self, alert_type: str, details: Dict[str, Any]) -> Dict[str, Any]: + """Create security alert with sanitized information.""" + # Remove sensitive details for security alerts + safe_details = {} + allowed_fields = ['ip', 'user_agent', 'endpoint', 'method', 'timestamp'] + + for field in allowed_fields: + if field in details: + safe_details[field] = details[field] + + # Sanitize IP if needed (keep only first 3 octets for privacy) + if 'ip' in safe_details: + ip_parts = safe_details['ip'].split('.') + if len(ip_parts) == 4: + safe_details['ip'] = f"{ip_parts[0]}.{ip_parts[1]}.{ip_parts[2]}.xxx" + + logger.warning( + "Security alert", + alert_type=alert_type, + details=safe_details, + level=ErrorLevel.HIGH.value + ) + + return { + "error": { + "code": "SECURITY_VIOLATION", + "message": "Security policy violation detected", + "type": "SecurityError", + "level": ErrorLevel.HIGH.value, + "alert_type": alert_type + } + } + + +# Global error handler instance +error_handler = ProductionErrorHandler(debug_mode=False) + + +def set_debug_mode(enabled: bool): + """Enable or disable debug mode globally.""" + global error_handler + error_handler.debug_mode = enabled + + +def sanitize_error(error: Exception, level: ErrorLevel = ErrorLevel.MEDIUM) -> Dict[str, Any]: + """Convenience function for error sanitization.""" + return error_handler.sanitize_error_message(error, level) + + +def create_http_error(status_code: int, detail: str = None) -> Dict[str, Any]: + """Convenience function for HTTP error creation.""" + return error_handler.handle_http_exception(status_code, detail) + + +def create_security_alert(alert_type: str, details: Dict[str, Any]) -> Dict[str, Any]: + """Convenience function for security alert creation.""" + return error_handler.create_security_alert(alert_type, details) \ No newline at end of file diff --git a/api/utils/validators.py b/api/utils/validators.py index bc38805..ec875f2 100644 --- a/api/utils/validators.py +++ b/api/utils/validators.py @@ -1,6 +1,7 @@ """ -Input validation utilities +Input validation utilities with security enhancements """ +import os import re from pathlib import Path from typing import List, Dict, Any, Tuple @@ -24,17 +25,81 @@ ".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp", ".svg" } -# Regex patterns -PATH_REGEX = re.compile(r'^[a-zA-Z0-9\-_./]+$') +# Security patterns +SAFE_FILENAME_REGEX = re.compile(r'^[a-zA-Z0-9\-_]+(\.[a-zA-Z0-9]+)?$') CODEC_REGEX = re.compile(r'^[a-zA-Z0-9\-_]+$') +# Security configuration +ALLOWED_BASE_PATHS = { + '/storage', '/tmp/rendiff', '/app/uploads', '/app/temp' +} + +class SecurityError(Exception): + """Security validation error.""" + pass + + +def validate_secure_path(path: str, base_paths: set = None) -> str: + """ + Validate and sanitize file paths to prevent directory traversal. + + Args: + path: The path to validate + base_paths: Set of allowed base paths + + Returns: + Canonical path if valid + + Raises: + SecurityError: If path is unsafe + """ + if not path: + raise SecurityError("Path cannot be empty") + + if base_paths is None: + base_paths = ALLOWED_BASE_PATHS + + # Check for null bytes and dangerous characters + dangerous_chars = ['\x00', '|', ';', '&', '$', '`', '<', '>', '"', "'"] + for char in dangerous_chars: + if char in path: + raise SecurityError(f"Dangerous character detected in path: {char}") + + # Validate path length + if len(path) > 4096: + raise SecurityError("Path length exceeds maximum allowed") + + try: + # Get canonical path to resolve any traversal attempts + canonical_path = os.path.realpath(path) + + # Check if path is within allowed base paths + is_allowed = False + for base_path in base_paths: + base_canonical = os.path.realpath(base_path) + if canonical_path.startswith(base_canonical + os.sep) or canonical_path == base_canonical: + is_allowed = True + break + + if not is_allowed: + raise SecurityError(f"Path outside allowed directories: {path}") + + # Additional check for directory traversal patterns + if '..' in path or '~' in path: + raise SecurityError("Directory traversal attempt detected") + + return canonical_path + + except OSError as e: + raise SecurityError(f"Invalid path: {e}") + async def validate_input_path( path: str, storage_service: StorageService ) -> Tuple[str, str]: """ - Validate input file path. + Validate input file path with security checks. Returns: (backend_name, validated_path) """ if not path: @@ -47,6 +112,18 @@ async def validate_input_path( if backend_name not in storage_service.backends: raise ValueError(f"Unknown storage backend: {backend_name}") + # Security validation for local paths + if backend_name == 'local': + try: + file_path = validate_secure_path(file_path) + except SecurityError as e: + raise ValueError(f"Security validation failed: {e}") + + # Validate filename components + filename = Path(file_path).name + if not SAFE_FILENAME_REGEX.match(filename): + raise ValueError(f"Invalid filename format: {filename}") + # Validate file extension file_ext = Path(file_path).suffix.lower() if file_ext not in (ALLOWED_VIDEO_EXTENSIONS | ALLOWED_AUDIO_EXTENSIONS): @@ -65,7 +142,7 @@ async def validate_output_path( storage_service: StorageService ) -> Tuple[str, str]: """ - Validate output file path. + Validate output file path with security checks. Returns: (backend_name, validated_path) """ if not path: @@ -78,15 +155,29 @@ async def validate_output_path( if backend_name not in storage_service.backends: raise ValueError(f"Unknown storage backend: {backend_name}") + # Security validation for local paths + if backend_name == 'local': + try: + file_path = validate_secure_path(file_path) + except SecurityError as e: + raise ValueError(f"Security validation failed: {e}") + + # Validate filename components + filename = Path(file_path).name + if not SAFE_FILENAME_REGEX.match(filename): + raise ValueError(f"Invalid filename format: {filename}") + # Check if backend allows output storage_config = storage_service.config output_backends = storage_config.get("policies", {}).get("output_backends", []) if output_backends and backend_name not in output_backends: raise ValueError(f"Backend '{backend_name}' not allowed for output") - # Validate path format - if not PATH_REGEX.match(file_path): - raise ValueError(f"Invalid output path format: {file_path}") + # Validate file extension for output + file_ext = Path(file_path).suffix.lower() + allowed_output_extensions = ALLOWED_VIDEO_EXTENSIONS | ALLOWED_AUDIO_EXTENSIONS | ALLOWED_IMAGE_EXTENSIONS + if file_ext and file_ext not in allowed_output_extensions: + raise ValueError(f"Unsupported output file type: {file_ext}") # Ensure directory exists backend = storage_service.backends[backend_name] @@ -97,15 +188,32 @@ async def validate_output_path( def validate_operations(operations: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Validate and normalize operations list.""" + """Validate and normalize operations list with enhanced security checks.""" + if not operations: + raise ValueError("Operations list cannot be empty") + + if len(operations) > 50: # Prevent DOS through too many operations + raise ValueError("Too many operations specified (maximum 50)") + validated = [] - for op in operations: + for i, op in enumerate(operations): + if not isinstance(op, dict): + raise ValueError(f"Operation {i} must be a dictionary") + if "type" not in op: - raise ValueError("Operation missing 'type' field") + raise ValueError(f"Operation {i} missing 'type' field") op_type = op["type"] + # Validate operation type + if not isinstance(op_type, str): + raise ValueError(f"Operation {i} type must be a string") + + # Check for command injection in operation type + if not re.match(r'^[a-zA-Z_]+$', op_type): + raise SecurityError(f"Invalid operation type format: {op_type}") + if op_type == "trim": validated_op = validate_trim_operation(op) elif op_type == "watermark": @@ -114,6 +222,8 @@ def validate_operations(operations: List[Dict[str, Any]]) -> List[Dict[str, Any] validated_op = validate_filter_operation(op) elif op_type == "stream": validated_op = validate_stream_operation(op) + elif op_type == "transcode": + validated_op = validate_transcode_operation(op) else: raise ValueError(f"Unknown operation type: {op_type}") @@ -123,34 +233,55 @@ def validate_operations(operations: List[Dict[str, Any]]) -> List[Dict[str, Any] def validate_trim_operation(op: Dict[str, Any]) -> Dict[str, Any]: - """Validate trim operation.""" + """Validate trim operation with enhanced security checks.""" validated = {"type": "trim"} # Validate start time if "start" in op: start = op["start"] if isinstance(start, (int, float)): + if start < 0 or start > 86400: # Max 24 hours + raise ValueError("Start time out of valid range (0-86400 seconds)") validated["start"] = float(start) elif isinstance(start, str): + if len(start) > 20: # Reasonable length limit + raise ValueError("Start time string too long") validated["start"] = parse_time_string(start) else: - raise ValueError("Invalid start time format") + raise ValueError("Invalid start time format - must be number or time string") # Validate duration or end time if "duration" in op: duration = op["duration"] if isinstance(duration, (int, float)): + if duration <= 0 or duration > 86400: # Max 24 hours + raise ValueError("Duration out of valid range (0-86400 seconds)") validated["duration"] = float(duration) + elif isinstance(duration, str): + if len(duration) > 20: + raise ValueError("Duration string too long") + parsed_duration = parse_time_string(duration) + if parsed_duration <= 0: + raise ValueError("Duration must be positive") + validated["duration"] = parsed_duration else: - raise ValueError("Invalid duration format") + raise ValueError("Invalid duration format - must be number or time string") elif "end" in op: end = op["end"] if isinstance(end, (int, float)): + if end < 0 or end > 86400: # Max 24 hours + raise ValueError("End time out of valid range (0-86400 seconds)") validated["end"] = float(end) elif isinstance(end, str): + if len(end) > 20: + raise ValueError("End time string too long") validated["end"] = parse_time_string(end) else: - raise ValueError("Invalid end time format") + raise ValueError("Invalid end time format - must be number or time string") + + # Validate that we have at least duration or end time + if "start" in validated and "duration" not in validated and "end" not in validated: + raise ValueError("Trim operation requires either duration or end time when start is specified") return validated @@ -204,14 +335,159 @@ def validate_stream_operation(op: Dict[str, Any]) -> Dict[str, Any]: } +def validate_transcode_operation(op: Dict[str, Any]) -> Dict[str, Any]: + """Validate transcode operation with enhanced security checks.""" + validated = {"type": "transcode"} + + # Allowed video codecs + ALLOWED_VIDEO_CODECS = {'h264', 'h265', 'hevc', 'vp8', 'vp9', 'av1', 'libx264', 'libx265', 'copy'} + ALLOWED_AUDIO_CODECS = {'aac', 'mp3', 'opus', 'vorbis', 'ac3', 'libfdk_aac', 'copy'} + ALLOWED_PRESETS = {'ultrafast', 'superfast', 'veryfast', 'faster', 'fast', 'medium', 'slow', 'slower', 'veryslow'} + + # Validate video codec + if "video_codec" in op: + codec = op["video_codec"] + if not isinstance(codec, str): + raise ValueError("Video codec must be a string") + if codec not in ALLOWED_VIDEO_CODECS: + raise ValueError(f"Invalid video codec: {codec}") + validated["video_codec"] = codec + + # Validate audio codec + if "audio_codec" in op: + codec = op["audio_codec"] + if not isinstance(codec, str): + raise ValueError("Audio codec must be a string") + if codec not in ALLOWED_AUDIO_CODECS: + raise ValueError(f"Invalid audio codec: {codec}") + validated["audio_codec"] = codec + + # Validate preset + if "preset" in op: + preset = op["preset"] + if not isinstance(preset, str): + raise ValueError("Preset must be a string") + if preset not in ALLOWED_PRESETS: + raise ValueError(f"Invalid preset: {preset}") + validated["preset"] = preset + + # Validate bitrates + if "video_bitrate" in op: + validated["video_bitrate"] = validate_bitrate(op["video_bitrate"]) + if "audio_bitrate" in op: + validated["audio_bitrate"] = validate_bitrate(op["audio_bitrate"]) + + # Validate resolution + if "width" in op or "height" in op: + width = op.get("width") + height = op.get("height") + validated_resolution = validate_resolution(width, height) + if validated_resolution: + validated.update(validated_resolution) + + # Validate frame rate + if "fps" in op: + fps = op["fps"] + if isinstance(fps, (int, float)): + if fps <= 0 or fps > 120: # Reasonable FPS limits + raise ValueError("FPS out of valid range (1-120)") + validated["fps"] = float(fps) + else: + raise ValueError("FPS must be a number") + + # Validate CRF + if "crf" in op: + crf = op["crf"] + if isinstance(crf, (int, float)): + if crf < 0 or crf > 51: # Standard CRF range + raise ValueError("CRF out of valid range (0-51)") + validated["crf"] = int(crf) + else: + raise ValueError("CRF must be a number") + + return validated + + +def validate_bitrate(bitrate) -> str: + """Validate bitrate parameter with security checks.""" + if isinstance(bitrate, str): + # Validate bitrate format + if not re.match(r'^\d+[kKmM]?$', bitrate): + raise ValueError(f"Invalid bitrate format: {bitrate}") + + # Parse and validate range + if bitrate.lower().endswith('k'): + value = int(bitrate[:-1]) * 1000 + elif bitrate.lower().endswith('m'): + value = int(bitrate[:-1]) * 1000000 + else: + value = int(bitrate) + + # Check reasonable limits (100 kbps to 50 Mbps) + if value < 100000 or value > 50000000: + raise ValueError("Bitrate out of reasonable range (100k-50M)") + + return bitrate + elif isinstance(bitrate, (int, float)): + value = int(bitrate) + if value < 100000 or value > 50000000: + raise ValueError("Bitrate out of reasonable range (100000-50000000)") + return str(value) + else: + raise ValueError("Bitrate must be string or number") + + +def validate_resolution(width, height) -> Dict[str, int]: + """Validate video resolution parameters.""" + result = {} + + if width is not None: + if not isinstance(width, (int, float)): + raise ValueError("Width must be a number") + width = int(width) + if width < 32 or width > 7680: # Min 32px, max 8K width + raise ValueError("Width out of valid range (32-7680)") + if width % 2 != 0: # Must be even for most codecs + raise ValueError("Width must be even number") + result["width"] = width + + if height is not None: + if not isinstance(height, (int, float)): + raise ValueError("Height must be a number") + height = int(height) + if height < 32 or height > 4320: # Min 32px, max 8K height + raise ValueError("Height out of valid range (32-4320)") + if height % 2 != 0: # Must be even for most codecs + raise ValueError("Height must be even number") + result["height"] = height + + return result + + def parse_time_string(time_str: str) -> float: - """Parse time string in format HH:MM:SS.ms to seconds.""" + """Parse time string in format HH:MM:SS.ms to seconds with validation.""" + if not isinstance(time_str, str): + raise ValueError("Time string must be a string") + + # Security check for time string format + if not re.match(r'^(\d{1,2}:)?(\d{1,2}:)?\d{1,2}(\.\d{1,3})?$', time_str): + raise ValueError(f"Invalid time format: {time_str}") + parts = time_str.split(":") - if len(parts) == 1: - return float(parts[0]) - elif len(parts) == 2: - return float(parts[0]) * 60 + float(parts[1]) - elif len(parts) == 3: - return float(parts[0]) * 3600 + float(parts[1]) * 60 + float(parts[2]) - else: - raise ValueError(f"Invalid time format: {time_str}") \ No newline at end of file + try: + if len(parts) == 1: + seconds = float(parts[0]) + elif len(parts) == 2: + seconds = float(parts[0]) * 60 + float(parts[1]) + elif len(parts) == 3: + seconds = float(parts[0]) * 3600 + float(parts[1]) * 60 + float(parts[2]) + else: + raise ValueError(f"Invalid time format: {time_str}") + + # Validate reasonable time bounds + if seconds < 0 or seconds > 86400: # 24 hours max + raise ValueError(f"Time out of reasonable range: {seconds}") + + return seconds + except ValueError as e: + raise ValueError(f"Invalid time format: {time_str} - {e}") \ No newline at end of file diff --git a/tests/test_security_fixes.py b/tests/test_security_fixes.py new file mode 100644 index 0000000..0d8d761 --- /dev/null +++ b/tests/test_security_fixes.py @@ -0,0 +1,282 @@ +""" +Tests for security fixes implementation +""" +import pytest +import os +import tempfile +from unittest.mock import patch, MagicMock + +# Test FFmpeg command injection fix +def test_ffmpeg_command_injection_prevention(): + """Test that FFmpeg command builder prevents injection attacks.""" + from worker.utils.ffmpeg import FFmpegCommandBuilder, FFmpegCommandError + + builder = FFmpegCommandBuilder() + + # Test dangerous characters in paths + with pytest.raises(FFmpegCommandError, match="Dangerous character detected"): + builder._validate_paths("/path/to/input.mp4", "/output; rm -rf /") + + with pytest.raises(FFmpegCommandError, match="Dangerous character detected"): + builder._validate_paths("/input`whoami`.mp4", "/output.mp4") + + # Test valid paths should pass + try: + builder._validate_paths("/valid/input.mp4", "/valid/output.mp4") + assert True # Should not raise exception + except FFmpegCommandError: + pytest.fail("Valid paths should not raise validation error") + + +def test_ffmpeg_parameter_validation(): + """Test FFmpeg parameter validation.""" + from worker.utils.ffmpeg import FFmpegCommandBuilder, FFmpegCommandError + + builder = FFmpegCommandBuilder() + + # Test invalid codec + with pytest.raises(FFmpegCommandError, match="Invalid video codec"): + builder._validate_transcode_params({"video_codec": "malicious_codec"}) + + # Test valid codec + try: + builder._validate_transcode_params({"video_codec": "h264"}) + assert True + except FFmpegCommandError: + pytest.fail("Valid codec should not raise error") + + # Test CRF out of range + with pytest.raises(FFmpegCommandError, match="out of range"): + builder._validate_transcode_params({"crf": 100}) + + +def test_path_traversal_prevention(): + """Test that path traversal attacks are prevented.""" + from api.utils.validators import validate_secure_path, SecurityError + + # Test directory traversal attempts + with pytest.raises(SecurityError, match="Directory traversal"): + validate_secure_path("../../../etc/passwd") + + with pytest.raises(SecurityError, match="Directory traversal"): + validate_secure_path("/storage/../../../etc/passwd") + + # Test null byte injection + with pytest.raises(SecurityError, match="Dangerous character"): + validate_secure_path("/storage/file\x00.txt") + + # Test command injection + with pytest.raises(SecurityError, match="Dangerous character"): + validate_secure_path("/storage/file; rm -rf /") + + +def test_input_validation_operations(): + """Test enhanced operation validation.""" + from api.utils.validators import validate_operations, SecurityError + + # Test too many operations (DOS prevention) + large_ops = [{"type": "trim", "start": 0}] * 100 + with pytest.raises(ValueError, match="Too many operations"): + validate_operations(large_ops) + + # Test invalid operation type format + with pytest.raises(SecurityError, match="Invalid operation type format"): + validate_operations([{"type": "trim; rm -rf /"}]) + + # Test valid operations + valid_ops = [ + {"type": "trim", "start": 10, "duration": 30}, + {"type": "transcode", "video_codec": "h264"} + ] + result = validate_operations(valid_ops) + assert len(result) == 2 + assert result[0]["type"] == "trim" + + +def test_rate_limiting_middleware(): + """Test rate limiting middleware functionality.""" + from api.middleware.security import RateLimitMiddleware, APIKeyQuota + from starlette.applications import Starlette + from starlette.requests import Request + from starlette.responses import Response + from unittest.mock import AsyncMock + + app = Starlette() + middleware = RateLimitMiddleware(app, calls=2, period=3600, enabled=True) + + # Test quota retrieval + quota = middleware._get_client_quota("basic_test_key") + assert isinstance(quota, APIKeyQuota) + assert quota.calls_per_hour == 500 # Basic tier + + # Test enterprise key + enterprise_quota = middleware._get_client_quota("ent_test_key") + assert enterprise_quota.calls_per_hour == 10000 # Enterprise tier + + +def test_error_message_sanitization(): + """Test error message sanitization.""" + from api.utils.error_handler import ProductionErrorHandler, ErrorLevel + + handler = ProductionErrorHandler(debug_mode=False) + + # Test sensitive information removal + error = Exception("Database error at postgresql://user:password@host:5432/db") + result = handler.sanitize_error_message(error, ErrorLevel.HIGH) + + # Should not contain sensitive information + assert "password" not in str(result) + assert "postgresql://" not in str(result) + assert result["error"]["message"] == "An error occurred" + + # Test debug mode + debug_handler = ProductionErrorHandler(debug_mode=True) + debug_result = debug_handler.sanitize_error_message(error, ErrorLevel.LOW) + assert "debug_info" in debug_result["error"] + + +def test_security_middleware_headers(): + """Test security headers middleware.""" + from api.middleware.security import SecurityHeadersMiddleware + from starlette.applications import Starlette + from starlette.testclient import TestClient + + app = Starlette() + + @app.route("/test") + async def test_endpoint(request): + from starlette.responses import JSONResponse + return JSONResponse({"message": "test"}) + + app.add_middleware(SecurityHeadersMiddleware) + + client = TestClient(app) + response = client.get("/test") + + # Check security headers + assert "X-Content-Type-Options" in response.headers + assert response.headers["X-Content-Type-Options"] == "nosniff" + assert "X-Frame-Options" in response.headers + assert "Content-Security-Policy" in response.headers + + +def test_input_sanitization_middleware(): + """Test input sanitization middleware.""" + from api.middleware.security import InputSanitizationMiddleware + from starlette.applications import Starlette + from starlette.testclient import TestClient + + app = Starlette() + + @app.route("/test", methods=["POST"]) + async def test_endpoint(request): + from starlette.responses import JSONResponse + return JSONResponse({"message": "test"}) + + app.add_middleware(InputSanitizationMiddleware, max_body_size=1024) + + client = TestClient(app) + + # Test content type validation + response = client.post("/test", + data="test", + headers={"Content-Type": "text/plain"}) + assert response.status_code == 415 # Unsupported Media Type + + # Test valid content type + response = client.post("/test", + json={"data": "test"}) + assert response.status_code == 200 + + +def test_security_audit_middleware(): + """Test security audit middleware.""" + from api.middleware.security import SecurityAuditMiddleware + from starlette.applications import Starlette + from starlette.testclient import TestClient + import structlog + from io import StringIO + + app = Starlette() + + @app.route("/test") + async def test_endpoint(request): + from starlette.responses import JSONResponse + return JSONResponse({"message": "test"}) + + app.add_middleware(SecurityAuditMiddleware) + + client = TestClient(app) + + # Test normal request + response = client.get("/test") + assert response.status_code == 200 + + +def test_comprehensive_security_config(): + """Test the comprehensive security configuration.""" + from api.security_config import SecurityConfig, validate_request_data + + config = SecurityConfig() + + # Test valid request data + valid_data = { + "input_path": "/storage/test.mp4", + "output_path": "/storage/output.mp4", + "operations": [ + {"type": "trim", "start": 10, "duration": 30} + ] + } + + # This should work with our security fixes + try: + result = validate_request_data(valid_data) + assert "input_path" in result + assert "operations" in result + assert len(result["operations"]) == 1 + except Exception as e: + # If validation fails, it should be due to path not being in allowed base paths + # which is expected in test environment + assert "Path outside allowed directories" in str(e) + + +# Integration test +def test_end_to_end_security(): + """Test that all security components work together.""" + from api.security_config import apply_security_to_app, get_security_info + from fastapi import FastAPI + from fastapi.testclient import TestClient + + app = FastAPI() + + @app.get("/health") + async def health(): + return {"status": "healthy"} + + @app.post("/api/v1/process") + async def process_video(): + return {"message": "processed"} + + # Apply security configuration + app = apply_security_to_app(app) + + client = TestClient(app) + + # Test health endpoint (should work) + response = client.get("/health") + assert response.status_code == 200 + + # Check security info + security_info = get_security_info() + assert security_info["security_enabled"] is True + assert "rate_limiting" in security_info + assert "input_validation" in security_info + + +if __name__ == "__main__": + # Run basic validation tests + test_ffmpeg_command_injection_prevention() + test_path_traversal_prevention() + test_input_validation_operations() + test_error_message_sanitization() + print("✅ All security fix tests passed!") \ No newline at end of file diff --git a/worker/utils/ffmpeg.py b/worker/utils/ffmpeg.py index 010bf23..159b7fa 100644 --- a/worker/utils/ffmpeg.py +++ b/worker/utils/ffmpeg.py @@ -120,20 +120,59 @@ def get_best_encoder(codec: str, hardware_caps: Dict[str, bool]) -> str: class FFmpegCommandBuilder: - """Build FFmpeg commands from operations and options.""" + """Build FFmpeg commands from operations and options with security validation.""" + + # Security whitelists for command injection prevention + ALLOWED_CODECS = { + 'video': {'h264', 'h265', 'hevc', 'vp8', 'vp9', 'av1', 'libx264', 'libx265', 'copy'}, + 'audio': {'aac', 'mp3', 'opus', 'vorbis', 'ac3', 'libfdk_aac', 'copy'} + } + + ALLOWED_FILTERS = { + 'scale', 'crop', 'overlay', 'eq', 'hqdn3d', 'unsharp', 'format', 'colorchannelmixer' + } + + ALLOWED_PRESETS = { + 'ultrafast', 'superfast', 'veryfast', 'faster', 'fast', 'medium', + 'slow', 'slower', 'veryslow', 'placebo' + } + + ALLOWED_PIXEL_FORMATS = { + 'yuv420p', 'yuv422p', 'yuv444p', 'rgb24', 'rgba', 'bgr24', 'bgra' + } + + # Safe parameter ranges + SAFE_RANGES = { + 'crf': (0, 51), + 'bitrate_min': 100, # 100 kbps minimum + 'bitrate_max': 50000, # 50 Mbps maximum + 'fps_min': 1, + 'fps_max': 120, + 'width_min': 32, + 'width_max': 7680, # 8K max + 'height_min': 32, + 'height_max': 4320, # 8K max + 'threads_max': 64 + } def __init__(self, hardware_caps: Optional[Dict[str, bool]] = None): self.hardware_caps = hardware_caps or {} + logger.info("FFmpegCommandBuilder initialized with security validation") def build_command(self, input_path: str, output_path: str, options: Dict[str, Any], operations: List[Dict[str, Any]]) -> List[str]: - """Build complete FFmpeg command from operations.""" + """Build complete FFmpeg command from operations with security validation.""" + # Validate all inputs first + self._validate_paths(input_path, output_path) + self._validate_options(options) + self._validate_operations(operations) + cmd = ['ffmpeg', '-y'] # -y to overwrite output files # Add hardware acceleration if available cmd.extend(self._add_hardware_acceleration()) - # Add input + # Add input (already validated) cmd.extend(['-i', input_path]) # Add operations @@ -166,10 +205,10 @@ def build_command(self, input_path: str, output_path: str, # Add global options cmd.extend(self._handle_global_options(options)) - # Add output + # Add output (already validated) cmd.append(output_path) - logger.info("Built FFmpeg command", command=' '.join(cmd)) + logger.info("Built secure FFmpeg command", command=' '.join(cmd)) return cmd def _add_hardware_acceleration(self) -> List[str]: @@ -184,6 +223,214 @@ def _add_hardware_acceleration(self) -> List[str]: return ['-hwaccel', 'videotoolbox'] return [] + def _validate_paths(self, input_path: str, output_path: str): + """Validate input and output paths for security.""" + import os + + # Check for null bytes and dangerous characters + dangerous_chars = ['\x00', '|', ';', '&', '$', '`', '(', ')', '<', '>', '"', "'"] + for path in [input_path, output_path]: + for char in dangerous_chars: + if char in path: + raise FFmpegCommandError(f"Dangerous character detected in path: {char}") + + # Validate path length + if len(input_path) > 4096 or len(output_path) > 4096: + raise FFmpegCommandError("Path length exceeds maximum allowed") + + # Ensure paths are absolute and normalized + try: + input_normalized = os.path.normpath(input_path) + output_normalized = os.path.normpath(output_path) + + # Check for directory traversal attempts + if '..' in input_normalized or '..' in output_normalized: + raise FFmpegCommandError("Directory traversal attempt detected") + + except Exception as e: + raise FFmpegCommandError(f"Path validation failed: {e}") + + def _validate_options(self, options: Dict[str, Any]): + """Validate global options for security.""" + if not isinstance(options, dict): + raise FFmpegCommandError("Options must be a dictionary") + + # Validate each option + for key, value in options.items(): + if not isinstance(key, str): + raise FFmpegCommandError("Option keys must be strings") + + # Check for command injection in option values + if isinstance(value, str): + self._validate_string_parameter(value, f"option_{key}") + + def _validate_operations(self, operations: List[Dict[str, Any]]): + """Validate operations list for security.""" + if not isinstance(operations, list): + raise FFmpegCommandError("Operations must be a list") + + allowed_operation_types = {'transcode', 'trim', 'watermark', 'filter', 'stream_map'} + + for i, operation in enumerate(operations): + if not isinstance(operation, dict): + raise FFmpegCommandError(f"Operation {i} must be a dictionary") + + op_type = operation.get('type') + if op_type not in allowed_operation_types: + raise FFmpegCommandError(f"Unknown operation type: {op_type}") + + # Validate operation parameters + params = operation.get('params', {}) + if not isinstance(params, dict): + raise FFmpegCommandError(f"Operation {i} params must be a dictionary") + + self._validate_operation_params(op_type, params) + + def _validate_operation_params(self, op_type: str, params: Dict[str, Any]): + """Validate operation-specific parameters.""" + if op_type == 'transcode': + self._validate_transcode_params(params) + elif op_type == 'trim': + self._validate_trim_params(params) + elif op_type == 'filter': + self._validate_filter_params(params) + elif op_type == 'watermark': + self._validate_watermark_params(params) + + def _validate_transcode_params(self, params: Dict[str, Any]): + """Validate transcoding parameters.""" + if 'video_codec' in params: + codec = params['video_codec'] + if codec not in self.ALLOWED_CODECS['video']: + raise FFmpegCommandError(f"Invalid video codec: {codec}") + + if 'audio_codec' in params: + codec = params['audio_codec'] + if codec not in self.ALLOWED_CODECS['audio']: + raise FFmpegCommandError(f"Invalid audio codec: {codec}") + + if 'preset' in params: + preset = params['preset'] + if preset not in self.ALLOWED_PRESETS: + raise FFmpegCommandError(f"Invalid preset: {preset}") + + # Validate numeric parameters + self._validate_numeric_param(params.get('crf'), 'crf', self.SAFE_RANGES['crf']) + self._validate_bitrate(params.get('video_bitrate'), 'video_bitrate') + self._validate_bitrate(params.get('audio_bitrate'), 'audio_bitrate') + self._validate_numeric_param(params.get('fps'), 'fps', (self.SAFE_RANGES['fps_min'], self.SAFE_RANGES['fps_max'])) + self._validate_resolution(params.get('width'), params.get('height')) + + def _validate_trim_params(self, params: Dict[str, Any]): + """Validate trim parameters.""" + for time_param in ['start_time', 'duration', 'end_time']: + if time_param in params: + value = params[time_param] + if isinstance(value, (int, float)): + if value < 0 or value > 86400: # Max 24 hours + raise FFmpegCommandError(f"Invalid {time_param}: {value}") + elif isinstance(value, str): + self._validate_time_string(value, time_param) + + def _validate_filter_params(self, params: Dict[str, Any]): + """Validate filter parameters.""" + for key, value in params.items(): + if isinstance(value, str): + self._validate_string_parameter(value, f"filter_{key}") + elif isinstance(value, (int, float)): + if abs(value) > 1000: # Reasonable limit for filter values + raise FFmpegCommandError(f"Filter parameter {key} out of range: {value}") + + def _validate_watermark_params(self, params: Dict[str, Any]): + """Validate watermark parameters.""" + # Validate position values + for pos_param in ['x', 'y']: + if pos_param in params: + value = params[pos_param] + if isinstance(value, str): + self._validate_string_parameter(value, f"watermark_{pos_param}") + + # Validate opacity + if 'opacity' in params: + opacity = params['opacity'] + if not isinstance(opacity, (int, float)) or opacity < 0 or opacity > 1: + raise FFmpegCommandError(f"Invalid opacity: {opacity}") + + def _validate_string_parameter(self, value: str, param_name: str): + """Validate string parameters for command injection.""" + if not isinstance(value, str): + return + + # Check for command injection patterns + dangerous_patterns = [ + ';', '|', '&', '$', '`', '$(', '${', '<(', '>(', '\n', '\r' + ] + + for pattern in dangerous_patterns: + if pattern in value: + raise FFmpegCommandError(f"Dangerous pattern in {param_name}: {pattern}") + + # Check length + if len(value) > 1024: + raise FFmpegCommandError(f"Parameter {param_name} too long") + + def _validate_numeric_param(self, value, param_name: str, valid_range: tuple): + """Validate numeric parameters.""" + if value is None: + return + + if not isinstance(value, (int, float)): + raise FFmpegCommandError(f"Parameter {param_name} must be numeric") + + min_val, max_val = valid_range + if value < min_val or value > max_val: + raise FFmpegCommandError(f"Parameter {param_name} out of range [{min_val}, {max_val}]: {value}") + + def _validate_bitrate(self, bitrate, param_name: str): + """Validate bitrate parameters.""" + if bitrate is None: + return + + if isinstance(bitrate, str): + # Parse bitrate strings like "1000k", "5M" + import re + match = re.match(r'^(\d+)([kKmM]?)$', bitrate) + if not match: + raise FFmpegCommandError(f"Invalid bitrate format: {bitrate}") + + value, unit = match.groups() + value = int(value) + + if unit.lower() == 'k': + value *= 1000 + elif unit.lower() == 'm': + value *= 1000000 + + if value < self.SAFE_RANGES['bitrate_min'] or value > self.SAFE_RANGES['bitrate_max']: + raise FFmpegCommandError(f"Bitrate out of safe range: {bitrate}") + elif isinstance(bitrate, (int, float)): + if bitrate < self.SAFE_RANGES['bitrate_min'] or bitrate > self.SAFE_RANGES['bitrate_max']: + raise FFmpegCommandError(f"Bitrate out of safe range: {bitrate}") + + def _validate_resolution(self, width, height): + """Validate resolution parameters.""" + if width is not None: + self._validate_numeric_param(width, 'width', + (self.SAFE_RANGES['width_min'], self.SAFE_RANGES['width_max'])) + + if height is not None: + self._validate_numeric_param(height, 'height', + (self.SAFE_RANGES['height_min'], self.SAFE_RANGES['height_max'])) + + def _validate_time_string(self, time_str: str, param_name: str): + """Validate time string format.""" + import re + + # Allow formats: HH:MM:SS, MM:SS, SS, HH:MM:SS.ms + time_pattern = r'^(\d{1,2}:)?(\d{1,2}:)?\d{1,2}(\.\d{1,3})?$' + if not re.match(time_pattern, time_str): + raise FFmpegCommandError(f"Invalid time format for {param_name}: {time_str}") + def _handle_transcode(self, params: Dict[str, Any]) -> List[str]: """Handle video transcoding parameters.""" cmd_parts = []