diff --git a/alembic/versions/002_add_api_key_table.py b/alembic/versions/002_add_api_key_table.py new file mode 100644 index 0000000..0d1d5bc --- /dev/null +++ b/alembic/versions/002_add_api_key_table.py @@ -0,0 +1,70 @@ +"""Add API Key table + +Revision ID: 002_add_api_key_table +Revises: 001_initial_schema +Create Date: 2025-07-11 09:30:00.000000 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '002_add_api_key_table' +down_revision = '001_initial_schema' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Create API keys table.""" + # Create api_keys table + op.create_table( + 'api_keys', + sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('key_hash', sa.String(length=64), nullable=False), + sa.Column('key_prefix', sa.String(length=8), nullable=False), + sa.Column('user_id', sa.String(length=255), nullable=True), + sa.Column('organization', sa.String(length=255), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=False, default=True), + sa.Column('is_admin', sa.Boolean(), nullable=False, default=False), + sa.Column('max_concurrent_jobs', sa.Integer(), nullable=False, default=5), + sa.Column('monthly_limit_minutes', sa.Integer(), nullable=False, default=10000), + sa.Column('total_requests', sa.Integer(), nullable=False, default=0), + sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('revoked_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('created_by', sa.String(length=255), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + # Create indexes + op.create_index('ix_api_keys_key_hash', 'api_keys', ['key_hash'], unique=True) + op.create_index('ix_api_keys_key_prefix', 'api_keys', ['key_prefix']) + op.create_index('ix_api_keys_user_id', 'api_keys', ['user_id']) + op.create_index('ix_api_keys_organization', 'api_keys', ['organization']) + op.create_index('ix_api_keys_is_active', 'api_keys', ['is_active']) + op.create_index('ix_api_keys_created_at', 'api_keys', ['created_at']) + op.create_index('ix_api_keys_expires_at', 'api_keys', ['expires_at']) + + # Add composite index for common queries + op.create_index('ix_api_keys_active_lookup', 'api_keys', ['is_active', 'revoked_at', 'expires_at']) + + +def downgrade() -> None: + """Drop API keys table.""" + # Drop indexes + op.drop_index('ix_api_keys_active_lookup', table_name='api_keys') + op.drop_index('ix_api_keys_expires_at', table_name='api_keys') + op.drop_index('ix_api_keys_created_at', table_name='api_keys') + op.drop_index('ix_api_keys_is_active', table_name='api_keys') + op.drop_index('ix_api_keys_organization', table_name='api_keys') + op.drop_index('ix_api_keys_user_id', table_name='api_keys') + op.drop_index('ix_api_keys_key_prefix', table_name='api_keys') + op.drop_index('ix_api_keys_key_hash', table_name='api_keys') + + # Drop table + op.drop_table('api_keys') \ No newline at end of file diff --git a/api/dependencies.py b/api/dependencies.py index 249d0a6..e2b0b9c 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -36,6 +36,7 @@ async def get_api_key( async def require_api_key( request: Request, api_key: Optional[str] = 
Depends(get_api_key), + db: AsyncSession = Depends(get_db), ) -> str: """Require valid API key for endpoint access.""" if not settings.ENABLE_API_KEYS: @@ -48,9 +49,19 @@ async def require_api_key( headers={"WWW-Authenticate": "Bearer"}, ) - # In production, validate against database - # For now, accept any non-empty key - if not api_key.strip(): + # Validate API key against database + from api.services.api_key import APIKeyService + + api_key_model = await APIKeyService.validate_api_key( + db, api_key, update_usage=True + ) + + if not api_key_model: + logger.warning( + "Invalid API key attempted", + api_key_prefix=(api_key[:8] + "...") if len(api_key) > 8 else api_key, + client_ip=request.client.host, + ) raise HTTPException( status_code=401, detail="Invalid API key", @@ -58,34 +69,77 @@ async def require_api_key( # Check IP whitelist if enabled if settings.ENABLE_IP_WHITELIST: + import ipaddress client_ip = request.client.host - if not any(client_ip.startswith(ip) for ip in settings.ip_whitelist_parsed): + + # Validate client IP against CIDR ranges + client_ip_obj = ipaddress.ip_address(client_ip) + allowed = False + + for allowed_range in settings.ip_whitelist_parsed: + try: + if client_ip_obj in ipaddress.ip_network(allowed_range, strict=False): + allowed = True + break + except ValueError: + # Entry is not a valid IP/CIDR (ip_network raises plain ValueError); + # fall back to string-prefix comparison + if client_ip.startswith(allowed_range): + allowed = True + break + + if not allowed: logger.warning( "IP not in whitelist", client_ip=client_ip, - api_key=api_key[:8] + "...", + api_key_id=str(api_key_model.id), + user_id=api_key_model.user_id, ) raise HTTPException( status_code=403, detail="IP address not authorized", ) + # Store API key model in request state for other endpoints + request.state.api_key_model = api_key_model + return api_key async def get_current_user( + request: Request, api_key: str = Depends(require_api_key), - db: AsyncSession = Depends(get_db), ) -> dict: - """Get current user from API key.""" - # In production, look up user from database - # For now, return mock user + """Get current user from validated API key.""" + # Get API key model from request state (set by require_api_key) + api_key_model = getattr(request.state, 'api_key_model', None) + + if not api_key_model: + # Fallback for anonymous access + return { + "id": "anonymous", + "api_key": api_key, + "role": "anonymous", + "quota": { + "concurrent_jobs": 1, + "monthly_minutes": 100, + }, + } + return { - "id": "user_123", + "id": api_key_model.user_id or f"api_key_{api_key_model.id}", + "api_key_id": str(api_key_model.id), "api_key": api_key, - "role": "user", + "name": api_key_model.name, + "organization": api_key_model.organization, + "role": "admin" if api_key_model.is_admin else "user", "quota": { - "concurrent_jobs": settings.MAX_CONCURRENT_JOBS_PER_KEY, - "monthly_minutes": 10000, + "concurrent_jobs": api_key_model.max_concurrent_jobs, + "monthly_minutes": api_key_model.monthly_limit_minutes, + }, + "usage": { + "total_requests": api_key_model.total_requests, + "last_used_at": api_key_model.last_used_at.isoformat() if api_key_model.last_used_at else None, }, + "expires_at": api_key_model.expires_at.isoformat() if api_key_model.expires_at else None, + "is_admin": api_key_model.is_admin, } \ No newline at end of file diff --git a/api/genai/services/model_manager.py b/api/genai/services/model_manager.py index c02aaed..95dd6cc 100644 --- a/api/genai/services/model_manager.py +++ b/api/genai/services/model_manager.py @@ 
-251,13 +251,39 @@ async def _load_videomae_model(self, model_name: str, **kwargs) -> Any: raise ImportError(f"VideoMAE dependencies not installed: {e}") async def _load_vmaf_model(self, model_name: str, **kwargs) -> Any: - """Load VMAF model.""" + """Load VMAF model configuration.""" try: import ffmpeg + import os - # VMAF is handled by FFmpeg, so we just return a placeholder - # The actual VMAF computation will be done in the quality predictor - return {"model_version": model_name, "available": True} + # VMAF models are handled by FFmpeg, not loaded into memory + # We validate the model name and return its configuration + vmaf_models = { + "vmaf_v0.6.1": {"version": "v0.6.1", "path": "/usr/local/share/model/vmaf_v0.6.1.json"}, + "vmaf_4k_v0.6.1": {"version": "v0.6.1_4k", "path": "/usr/local/share/model/vmaf_4k_v0.6.1.json"}, + "vmaf_v0.6.0": {"version": "v0.6.0", "path": "/usr/local/share/model/vmaf_v0.6.0.json"}, + } + + model_config = vmaf_models.get(model_name) + if not model_config: + raise ValueError(f"Unknown VMAF model: {model_name}") + + # Check whether the model file exists locally; informational only, + # since FFmpeg can still resolve the model from its own search path + model_available = True + if model_config["path"] and not os.path.exists(model_config["path"]): + # Model file not found, but FFmpeg might have it in a different location + logger.warning(f"VMAF model file not found at {model_config['path']}, will use FFmpeg default") + + return { + "model_name": model_name, + "version": model_config["version"], + "path": model_config["path"], + "available": model_available, + "type": "vmaf", + "description": f"VMAF quality assessment model {model_config['version']}", + } except ImportError as e: raise ImportError(f"FFmpeg-python not installed: {e}") diff --git a/api/main.py b/api/main.py index a8e1942..d205b20 100644 --- a/api/main.py +++ b/api/main.py @@ -13,7 +13,7 @@ import structlog from api.config import settings -from api.routers import convert, jobs, admin, health +from api.routers import convert, jobs, admin, health, api_keys from api.utils.logger import setup_logging from api.utils.error_handlers import ( RendiffError, rendiff_exception_handler, validation_exception_handler, @@ -123,6 +123,7 @@ async def lifespan(app: FastAPI): app.include_router(jobs.router, prefix="/api/v1", tags=["jobs"]) app.include_router(admin.router, prefix="/api/v1", tags=["admin"]) app.include_router(health.router, prefix="/api/v1", tags=["health"]) +app.include_router(api_keys.router, prefix="/api/v1", tags=["api-keys"]) # Conditionally include GenAI routers try: diff --git a/api/models/__init__.py b/api/models/__init__.py index e69de29..f74bbd8 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -0,0 +1,5 @@ +from .database import Base, get_session +from .job import Job, JobStatus +from .api_key import APIKey + +__all__ = ["Base", "get_session", "Job", "JobStatus", "APIKey"] \ No newline at end of file diff --git a/api/models/api_key.py b/api/models/api_key.py new file mode 100644 index 0000000..01bd166 --- /dev/null +++ b/api/models/api_key.py @@ -0,0 +1,147 @@ +""" +API Key model for authentication. 
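+Keys are never stored in plaintext: only a SHA-256 hash is persisted, and the raw key is returned to the caller exactly once at creation.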
+""" +import secrets +import hashlib +from datetime import datetime, timedelta +from typing import Optional +from uuid import uuid4 + +from sqlalchemy import Column, String, DateTime, Boolean, Integer, Text +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.sql import func + +from api.models.database import Base + + +class APIKey(Base): + """API Key model for authentication.""" + __tablename__ = "api_keys" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4) + name = Column(String(255), nullable=False) + key_hash = Column(String(64), nullable=False, unique=True, index=True) + key_prefix = Column(String(8), nullable=False, index=True) + + # User/organization info + user_id = Column(String(255), nullable=True) + organization = Column(String(255), nullable=True) + + # Permissions and limits + is_active = Column(Boolean, default=True, nullable=False) + is_admin = Column(Boolean, default=False, nullable=False) + max_concurrent_jobs = Column(Integer, default=5, nullable=False) + monthly_limit_minutes = Column(Integer, default=10000, nullable=False) + + # Usage tracking + total_requests = Column(Integer, default=0, nullable=False) + last_used_at = Column(DateTime(timezone=True), nullable=True) + + # Lifecycle + created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) + expires_at = Column(DateTime(timezone=True), nullable=True) + revoked_at = Column(DateTime(timezone=True), nullable=True) + + # Metadata + description = Column(Text, nullable=True) + created_by = Column(String(255), nullable=True) + + @classmethod + def generate_key(cls) -> tuple[str, str]: + """ + Generate a new API key. + + Returns: + tuple: (raw_key, key_hash) where raw_key should be shown to user only once + """ + # Generate 32 random bytes (256 bits) + raw_key = secrets.token_urlsafe(32) + + # Create hash for storage + key_hash = hashlib.sha256(raw_key.encode()).hexdigest() + + # Get prefix for indexing (first 8 chars) + key_prefix = raw_key[:8] + + return raw_key, key_hash, key_prefix + + @classmethod + def hash_key(cls, raw_key: str) -> str: + """Hash a raw key for comparison.""" + return hashlib.sha256(raw_key.encode()).hexdigest() + + def is_valid(self) -> bool: + """Check if API key is valid (active, not expired, not revoked).""" + now = datetime.utcnow() + + if not self.is_active: + return False + + if self.revoked_at and self.revoked_at <= now: + return False + + if self.expires_at and self.expires_at <= now: + return False + + return True + + def is_expired(self) -> bool: + """Check if API key is expired.""" + if not self.expires_at: + return False + return datetime.utcnow() > self.expires_at + + def days_until_expiry(self) -> Optional[int]: + """Get days until expiry, or None if no expiry set.""" + if not self.expires_at: + return None + delta = self.expires_at - datetime.utcnow() + return max(0, delta.days) + + def update_last_used(self): + """Update last used timestamp and increment request counter.""" + self.last_used_at = datetime.utcnow() + self.total_requests += 1 + + def revoke(self): + """Revoke this API key.""" + self.revoked_at = datetime.utcnow() + self.is_active = False + + def extend_expiry(self, days: int): + """Extend expiry by specified days.""" + if self.expires_at: + self.expires_at += timedelta(days=days) + else: + self.expires_at = datetime.utcnow() + timedelta(days=days) + + def to_dict(self, include_sensitive: bool = False) -> dict: + """Convert to dictionary for API responses.""" + data = { + "id": str(self.id), + "name": self.name, + 
"key_prefix": self.key_prefix, + "user_id": self.user_id, + "organization": self.organization, + "is_active": self.is_active, + "is_admin": self.is_admin, + "max_concurrent_jobs": self.max_concurrent_jobs, + "monthly_limit_minutes": self.monthly_limit_minutes, + "total_requests": self.total_requests, + "last_used_at": self.last_used_at.isoformat() if self.last_used_at else None, + "created_at": self.created_at.isoformat(), + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "revoked_at": self.revoked_at.isoformat() if self.revoked_at else None, + "description": self.description, + "created_by": self.created_by, + "is_expired": self.is_expired(), + "days_until_expiry": self.days_until_expiry(), + } + + if include_sensitive: + data["key_hash"] = self.key_hash + + return data + + def __repr__(self): + return f"" \ No newline at end of file diff --git a/api/routers/api_keys.py b/api/routers/api_keys.py new file mode 100644 index 0000000..a18c4a7 --- /dev/null +++ b/api/routers/api_keys.py @@ -0,0 +1,419 @@ +""" +API Keys management endpoints. +""" +from datetime import datetime +from typing import List, Optional +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Request, Query +from pydantic import BaseModel, Field +from sqlalchemy.ext.asyncio import AsyncSession +import structlog + +from api.dependencies import get_db, require_api_key, get_current_user +from api.services.api_key import APIKeyService + +router = APIRouter(prefix="/api-keys", tags=["API Keys"]) +logger = structlog.get_logger() + + +class CreateAPIKeyRequest(BaseModel): + """Request model for creating API keys.""" + name: str = Field(..., min_length=1, max_length=255, description="Name for the API key") + description: Optional[str] = Field(None, max_length=1000, description="Description of the API key purpose") + expires_in_days: Optional[int] = Field(None, ge=1, le=3650, description="Number of days until expiry (max 10 years)") + max_concurrent_jobs: int = Field(5, ge=1, le=100, description="Maximum concurrent jobs") + monthly_limit_minutes: int = Field(10000, ge=100, le=1000000, description="Monthly processing limit in minutes") + user_id: Optional[str] = Field(None, max_length=255, description="User ID to associate with this key") + organization: Optional[str] = Field(None, max_length=255, description="Organization name") + + +class CreateAPIKeyResponse(BaseModel): + """Response model for created API keys.""" + id: str + name: str + api_key: str = Field(..., description="The actual API key - save this securely, it won't be shown again") + key_prefix: str + expires_at: Optional[datetime] + max_concurrent_jobs: int + monthly_limit_minutes: int + created_at: datetime + + +class APIKeyInfo(BaseModel): + """API key information (without the actual key).""" + id: str + name: str + key_prefix: str + user_id: Optional[str] + organization: Optional[str] + is_active: bool + is_admin: bool + max_concurrent_jobs: int + monthly_limit_minutes: int + total_requests: int + last_used_at: Optional[datetime] + created_at: datetime + expires_at: Optional[datetime] + revoked_at: Optional[datetime] + description: Optional[str] + created_by: Optional[str] + is_expired: bool + days_until_expiry: Optional[int] + + +class UpdateAPIKeyRequest(BaseModel): + """Request model for updating API keys.""" + name: Optional[str] = Field(None, min_length=1, max_length=255) + description: Optional[str] = Field(None, max_length=1000) + max_concurrent_jobs: Optional[int] = Field(None, ge=1, le=100) + 
monthly_limit_minutes: Optional[int] = Field(None, ge=100, le=1000000) + is_active: Optional[bool] = None + + +class APIKeyListResponse(BaseModel): + """Response model for listing API keys.""" + api_keys: List[APIKeyInfo] + total_count: int + page: int + page_size: int + has_next: bool + + +@router.post("/", response_model=CreateAPIKeyResponse) +async def create_api_key( + request: CreateAPIKeyRequest, + current_user: dict = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """ + Create a new API key. + + **Note**: The API key will only be displayed once in the response. + Make sure to save it securely. + """ + # Check if user has admin privileges for certain operations + is_admin = current_user.get("is_admin", False) + + # Non-admin users can only create keys for themselves + user_id = request.user_id + if not is_admin and user_id and user_id != current_user.get("id"): + raise HTTPException( + status_code=403, + detail="You can only create API keys for yourself" + ) + + # Default to current user if no user_id specified + if not user_id: + user_id = current_user.get("id") + + try: + api_key_model, raw_key = await APIKeyService.create_api_key( + session=db, + name=request.name, + user_id=user_id, + organization=request.organization, + description=request.description, + expires_in_days=request.expires_in_days, + max_concurrent_jobs=request.max_concurrent_jobs, + monthly_limit_minutes=request.monthly_limit_minutes, + created_by=current_user.get("id"), + ) + + logger.info( + "API key created", + key_id=str(api_key_model.id), + name=request.name, + created_by=current_user.get("id"), + user_id=user_id, + ) + + return CreateAPIKeyResponse( + id=str(api_key_model.id), + name=api_key_model.name, + api_key=raw_key, + key_prefix=api_key_model.key_prefix, + expires_at=api_key_model.expires_at, + max_concurrent_jobs=api_key_model.max_concurrent_jobs, + monthly_limit_minutes=api_key_model.monthly_limit_minutes, + created_at=api_key_model.created_at, + ) + + except Exception as e: + logger.error("Failed to create API key", error=str(e)) + raise HTTPException(status_code=500, detail="Failed to create API key") + + +@router.get("/", response_model=APIKeyListResponse) +async def list_api_keys( + page: int = Query(1, ge=1, description="Page number"), + page_size: int = Query(20, ge=1, le=100, description="Items per page"), + search: Optional[str] = Query(None, description="Search in name, user_id, organization"), + active_only: bool = Query(True, description="Show only active keys"), + current_user: dict = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """List API keys with pagination and filtering.""" + is_admin = current_user.get("is_admin", False) + + offset = (page - 1) * page_size + + try: + if is_admin: + # Admin can see all keys + api_keys, total_count = await APIKeyService.list_api_keys( + session=db, + limit=page_size, + offset=offset, + active_only=active_only, + search=search, + ) + else: + # Regular users can only see their own keys + user_id = current_user.get("id") + if not user_id: + raise HTTPException(status_code=403, detail="Access denied") + + api_keys = await APIKeyService.get_api_keys_for_user( + session=db, + user_id=user_id, + include_revoked=not active_only, + ) + + # Apply search filter if specified + if search: + search_lower = search.lower() + api_keys = [ + key for key in api_keys + if (search_lower in key.name.lower() or + (key.description and search_lower in key.description.lower()) or + (key.organization and search_lower in 
key.organization.lower())) + ] + + total_count = len(api_keys) + + # Apply pagination + api_keys = api_keys[offset:offset + page_size] + + # Convert to response models + api_key_infos = [APIKeyInfo(**key.to_dict()) for key in api_keys] + + return APIKeyListResponse( + api_keys=api_key_infos, + total_count=total_count, + page=page, + page_size=page_size, + has_next=offset + page_size < total_count, + ) + + except HTTPException: + raise + except Exception as e: + logger.error("Failed to list API keys", error=str(e)) + raise HTTPException(status_code=500, detail="Failed to list API keys") + + +@router.get("/{key_id}", response_model=APIKeyInfo) +async def get_api_key( + key_id: UUID, + current_user: dict = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Get API key details by ID.""" + is_admin = current_user.get("is_admin", False) + + try: + api_key = await APIKeyService.get_api_key_by_id(db, key_id) + + if not api_key: + raise HTTPException(status_code=404, detail="API key not found") + + # Check permissions + if not is_admin and api_key.user_id != current_user.get("id"): + raise HTTPException(status_code=403, detail="Access denied") + + return APIKeyInfo(**api_key.to_dict()) + + except HTTPException: + raise + except Exception as e: + logger.error("Failed to get API key", error=str(e), key_id=str(key_id)) + raise HTTPException(status_code=500, detail="Failed to get API key") + + +@router.patch("/{key_id}", response_model=APIKeyInfo) +async def update_api_key( + key_id: UUID, + request: UpdateAPIKeyRequest, + current_user: dict = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Update API key settings.""" + is_admin = current_user.get("is_admin", False) + + try: + # Get existing key + api_key = await APIKeyService.get_api_key_by_id(db, key_id) + + if not api_key: + raise HTTPException(status_code=404, detail="API key not found") + + # Check permissions + if not is_admin and api_key.user_id != current_user.get("id"): + raise HTTPException(status_code=403, detail="Access denied") + + # Prepare updates + updates = {} + if request.name is not None: + updates["name"] = request.name + if request.description is not None: + updates["description"] = request.description + if request.max_concurrent_jobs is not None: + updates["max_concurrent_jobs"] = request.max_concurrent_jobs + if request.monthly_limit_minutes is not None: + updates["monthly_limit_minutes"] = request.monthly_limit_minutes + if request.is_active is not None: + updates["is_active"] = request.is_active + + if not updates: + # No changes requested + return APIKeyInfo(**api_key.to_dict()) + + # Update the key + updated_key = await APIKeyService.update_api_key(db, key_id, updates) + + logger.info( + "API key updated", + key_id=str(key_id), + updates=updates, + updated_by=current_user.get("id"), + ) + + return APIKeyInfo(**updated_key.to_dict()) + + except HTTPException: + raise + except Exception as e: + logger.error("Failed to update API key", error=str(e), key_id=str(key_id)) + raise HTTPException(status_code=500, detail="Failed to update API key") + + +@router.post("/{key_id}/revoke", response_model=APIKeyInfo) +async def revoke_api_key( + key_id: UUID, + current_user: dict = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Revoke an API key (permanently disable it).""" + is_admin = current_user.get("is_admin", False) + + try: + # Get existing key + api_key = await APIKeyService.get_api_key_by_id(db, key_id) + + if not api_key: + raise HTTPException(status_code=404, detail="API key not found") 
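+ # Ownership rule shared by all key endpoints: non-admin callers may only act on keys whose user_id matches their own.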
+ + # Check permissions + if not is_admin and api_key.user_id != current_user.get("id"): + raise HTTPException(status_code=403, detail="Access denied") + + if api_key.revoked_at: + raise HTTPException(status_code=400, detail="API key is already revoked") + + # Revoke the key + revoked_key = await APIKeyService.revoke_api_key( + db, key_id, revoked_by=current_user.get("id") + ) + + logger.info( + "API key revoked", + key_id=str(key_id), + revoked_by=current_user.get("id"), + ) + + return APIKeyInfo(**revoked_key.to_dict()) + + except HTTPException: + raise + except Exception as e: + logger.error("Failed to revoke API key", error=str(e), key_id=str(key_id)) + raise HTTPException(status_code=500, detail="Failed to revoke API key") + + +@router.post("/{key_id}/extend", response_model=APIKeyInfo) +async def extend_api_key_expiry( + key_id: UUID, + additional_days: int = Query(..., ge=1, le=3650, description="Days to extend expiry"), + current_user: dict = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Extend API key expiry date.""" + is_admin = current_user.get("is_admin", False) + + try: + # Get existing key + api_key = await APIKeyService.get_api_key_by_id(db, key_id) + + if not api_key: + raise HTTPException(status_code=404, detail="API key not found") + + # Check permissions + if not is_admin and api_key.user_id != current_user.get("id"): + raise HTTPException(status_code=403, detail="Access denied") + + if api_key.revoked_at: + raise HTTPException(status_code=400, detail="Cannot extend revoked API key") + + # Extend the key + extended_key = await APIKeyService.extend_api_key_expiry( + db, key_id, additional_days + ) + + logger.info( + "API key expiry extended", + key_id=str(key_id), + additional_days=additional_days, + extended_by=current_user.get("id"), + ) + + return APIKeyInfo(**extended_key.to_dict()) + + except HTTPException: + raise + except Exception as e: + logger.error("Failed to extend API key", error=str(e), key_id=str(key_id)) + raise HTTPException(status_code=500, detail="Failed to extend API key") + + +@router.get("/{key_id}/usage", response_model=dict) +async def get_api_key_usage( + key_id: UUID, + days: int = Query(30, ge=1, le=365, description="Number of days to analyze"), + current_user: dict = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Get usage statistics for an API key.""" + is_admin = current_user.get("is_admin", False) + + try: + # Get existing key + api_key = await APIKeyService.get_api_key_by_id(db, key_id) + + if not api_key: + raise HTTPException(status_code=404, detail="API key not found") + + # Check permissions + if not is_admin and api_key.user_id != current_user.get("id"): + raise HTTPException(status_code=403, detail="Access denied") + + # Get usage stats + usage_stats = await APIKeyService.get_usage_stats( + db, key_id=key_id, days=days + ) + + return usage_stats + + except HTTPException: + raise + except Exception as e: + logger.error("Failed to get usage stats", error=str(e), key_id=str(key_id)) + raise HTTPException(status_code=500, detail="Failed to get usage statistics") \ No newline at end of file diff --git a/api/routers/jobs.py b/api/routers/jobs.py index 4a6651f..e9b49ec 100644 --- a/api/routers/jobs.py +++ b/api/routers/jobs.py @@ -324,17 +324,33 @@ async def get_job_logs( # Get live logs from worker logs = await queue_service.get_worker_logs(job.worker_id, str(job_id), lines) else: - # Get stored logs - # This is a placeholder - implement actual log storage - logs = [ - f"Job {job_id} - 
Status: {job.status}", - f"Created: {job.created_at}", - f"Started: {job.started_at}", - f"Completed: {job.completed_at}", - ] + # Get stored logs from database and log aggregation system + from api.services.job_service import JobService - if job.error_message: - logs.append(f"Error: {job.error_message}") + stored_logs = await JobService.get_job_logs(db, job_id, lines) + + if stored_logs: + logs = stored_logs + else: + # Fallback to basic job information if no detailed logs available + logs = [ + f"[{job.created_at.isoformat()}] Job created: {job_id}", + f"[{job.created_at.isoformat()}] Status: {job.status.value}", + f"[{job.created_at.isoformat()}] Input: {job.input_url or 'N/A'}", + f"[{job.created_at.isoformat()}] Output: {job.output_url or 'N/A'}", + ] + + if job.started_at: + logs.append(f"[{job.started_at.isoformat()}] Processing started") + + if job.completed_at: + logs.append(f"[{job.completed_at.isoformat()}] Processing completed") + + if job.error_message: + logs.append(f"[{(job.completed_at or job.started_at or job.created_at).isoformat()}] ERROR: {job.error_message}") + + if job.progress > 0: + logs.append(f"[{(job.completed_at or job.started_at or job.created_at).isoformat()}] Progress: {job.progress}%") return { "job_id": str(job_id), diff --git a/api/services/api_key.py b/api/services/api_key.py new file mode 100644 index 0000000..d1971c9 --- /dev/null +++ b/api/services/api_key.py @@ -0,0 +1,353 @@ +""" +API Key service for authentication and key management. +""" +from datetime import datetime, timedelta +from typing import Optional, List, Dict, Any +from uuid import UUID + +from sqlalchemy import select, func, and_, or_ +from sqlalchemy.ext.asyncio import AsyncSession +import structlog + +from api.models.api_key import APIKey + +logger = structlog.get_logger() + + +class APIKeyService: + """Service for managing API keys.""" + + @staticmethod + async def create_api_key( + session: AsyncSession, + name: str, + user_id: Optional[str] = None, + organization: Optional[str] = None, + description: Optional[str] = None, + expires_in_days: Optional[int] = None, + max_concurrent_jobs: int = 5, + monthly_limit_minutes: int = 10000, + is_admin: bool = False, + created_by: Optional[str] = None, + ) -> tuple[APIKey, str]: + """ + Create a new API key. + + Returns: + tuple: (api_key_model, raw_key) - raw_key should be shown to user only once + """ + # Generate key + raw_key, key_hash, key_prefix = APIKey.generate_key() + + # Calculate expiry + expires_at = None + if expires_in_days: + expires_at = datetime.utcnow() + timedelta(days=expires_in_days) + + # Create API key model + api_key = APIKey( + name=name, + key_hash=key_hash, + key_prefix=key_prefix, + user_id=user_id, + organization=organization, + description=description, + expires_at=expires_at, + max_concurrent_jobs=max_concurrent_jobs, + monthly_limit_minutes=monthly_limit_minutes, + is_admin=is_admin, + created_by=created_by, + ) + + session.add(api_key) + await session.commit() + await session.refresh(api_key) + + logger.info( + "API key created", + key_id=str(api_key.id), + name=name, + user_id=user_id, + organization=organization, + expires_at=expires_at, + ) + + return api_key, raw_key + + @staticmethod + async def validate_api_key( + session: AsyncSession, + raw_key: str, + update_usage: bool = True, + ) -> Optional[APIKey]: + """ + Validate an API key and optionally update usage stats. 
+ + Args: + session: Database session + raw_key: The raw API key to validate + update_usage: Whether to update last_used_at and request count + + Returns: + APIKey model if valid, None if invalid + """ + if not raw_key or not raw_key.strip(): + return None + + # Hash the key for lookup + key_hash = APIKey.hash_key(raw_key) + + # Find API key by hash + stmt = select(APIKey).where(APIKey.key_hash == key_hash) + result = await session.execute(stmt) + api_key = result.scalar_one_or_none() + + if not api_key: + logger.warning("API key not found", key_prefix=raw_key[:8]) + return None + + # Check if key is valid + if not api_key.is_valid(): + logger.warning( + "Invalid API key used", + key_id=str(api_key.id), + is_active=api_key.is_active, + is_expired=api_key.is_expired(), + revoked_at=api_key.revoked_at, + ) + return None + + # Update usage if requested + if update_usage: + api_key.update_last_used() + await session.commit() + + logger.info( + "API key validated successfully", + key_id=str(api_key.id), + name=api_key.name, + user_id=api_key.user_id, + ) + + return api_key + + @staticmethod + async def get_api_key_by_id( + session: AsyncSession, + key_id: UUID, + ) -> Optional[APIKey]: + """Get API key by ID.""" + stmt = select(APIKey).where(APIKey.id == key_id) + result = await session.execute(stmt) + return result.scalar_one_or_none() + + @staticmethod + async def get_api_keys_for_user( + session: AsyncSession, + user_id: str, + include_revoked: bool = False, + ) -> List[APIKey]: + """Get all API keys for a user.""" + stmt = select(APIKey).where(APIKey.user_id == user_id) + + if not include_revoked: + stmt = stmt.where(APIKey.revoked_at.is_(None)) + + stmt = stmt.order_by(APIKey.created_at.desc()) + + result = await session.execute(stmt) + return list(result.scalars().all()) + + @staticmethod + async def get_api_keys_for_organization( + session: AsyncSession, + organization: str, + include_revoked: bool = False, + ) -> List[APIKey]: + """Get all API keys for an organization.""" + stmt = select(APIKey).where(APIKey.organization == organization) + + if not include_revoked: + stmt = stmt.where(APIKey.revoked_at.is_(None)) + + stmt = stmt.order_by(APIKey.created_at.desc()) + + result = await session.execute(stmt) + return list(result.scalars().all()) + + @staticmethod + async def list_api_keys( + session: AsyncSession, + limit: int = 100, + offset: int = 0, + active_only: bool = True, + search: Optional[str] = None, + ) -> tuple[List[APIKey], int]: + """ + List API keys with pagination and filtering. 
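+ Search matches case-insensitively against name, user_id, organization, and description.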
+ + Returns: + tuple: (api_keys, total_count) + """ + # Build base query + stmt = select(APIKey) + count_stmt = select(func.count(APIKey.id)) + + # Apply filters + if active_only: + stmt = stmt.where( + and_( + APIKey.is_active == True, + APIKey.revoked_at.is_(None) + ) + ) + count_stmt = count_stmt.where( + and_( + APIKey.is_active == True, + APIKey.revoked_at.is_(None) + ) + ) + + if search: + search_filter = or_( + APIKey.name.ilike(f"%{search}%"), + APIKey.user_id.ilike(f"%{search}%"), + APIKey.organization.ilike(f"%{search}%"), + APIKey.description.ilike(f"%{search}%"), + ) + stmt = stmt.where(search_filter) + count_stmt = count_stmt.where(search_filter) + + # Apply pagination + stmt = stmt.order_by(APIKey.created_at.desc()).limit(limit).offset(offset) + + # Execute queries + result = await session.execute(stmt) + count_result = await session.execute(count_stmt) + + api_keys = list(result.scalars().all()) + total_count = count_result.scalar() + + return api_keys, total_count + + @staticmethod + async def revoke_api_key( + session: AsyncSession, + key_id: UUID, + revoked_by: Optional[str] = None, + ) -> Optional[APIKey]: + """Revoke an API key.""" + api_key = await APIKeyService.get_api_key_by_id(session, key_id) + + if not api_key: + return None + + if api_key.revoked_at: + return api_key # Already revoked + + api_key.revoke() + await session.commit() + + logger.info( + "API key revoked", + key_id=str(api_key.id), + name=api_key.name, + revoked_by=revoked_by, + ) + + return api_key + + @staticmethod + async def extend_api_key_expiry( + session: AsyncSession, + key_id: UUID, + additional_days: int, + ) -> Optional[APIKey]: + """Extend API key expiry.""" + api_key = await APIKeyService.get_api_key_by_id(session, key_id) + + if not api_key: + return None + + old_expiry = api_key.expires_at + api_key.extend_expiry(additional_days) + await session.commit() + + logger.info( + "API key expiry extended", + key_id=str(api_key.id), + name=api_key.name, + old_expiry=old_expiry, + new_expiry=api_key.expires_at, + additional_days=additional_days, + ) + + return api_key + + @staticmethod + async def update_api_key( + session: AsyncSession, + key_id: UUID, + updates: Dict[str, Any], + ) -> Optional[APIKey]: + """Update API key properties.""" + api_key = await APIKeyService.get_api_key_by_id(session, key_id) + + if not api_key: + return None + + # Apply updates + allowed_fields = { + "name", "description", "max_concurrent_jobs", + "monthly_limit_minutes", "is_active" + } + + for field, value in updates.items(): + if field in allowed_fields and hasattr(api_key, field): + setattr(api_key, field, value) + + await session.commit() + + logger.info( + "API key updated", + key_id=str(api_key.id), + name=api_key.name, + updates=updates, + ) + + return api_key + + @staticmethod + async def get_usage_stats( + session: AsyncSession, + key_id: Optional[UUID] = None, + user_id: Optional[str] = None, + organization: Optional[str] = None, + days: int = 30, + ) -> Dict[str, Any]: + """Get usage statistics for API keys.""" + # This would typically query a separate usage/metrics table + # For now, return basic stats from the API key table + + stmt = select(APIKey) + + if key_id: + stmt = stmt.where(APIKey.id == key_id) + elif user_id: + stmt = stmt.where(APIKey.user_id == user_id) + elif organization: + stmt = stmt.where(APIKey.organization == organization) + + result = await session.execute(stmt) + api_keys = list(result.scalars().all()) + + total_requests = sum(key.total_requests for key in api_keys) + active_keys = 
sum(1 for key in api_keys if key.is_valid()) + + return { + "total_keys": len(api_keys), + "active_keys": active_keys, + "total_requests": total_requests, + "period_days": days, + "api_keys": [key.to_dict() for key in api_keys], + } \ No newline at end of file diff --git a/api/services/job_service.py b/api/services/job_service.py new file mode 100644 index 0000000..dcad01b --- /dev/null +++ b/api/services/job_service.py @@ -0,0 +1,219 @@ +""" +Job service for managing job operations. +""" +from datetime import datetime, timedelta +from typing import List, Optional, Dict, Any +from uuid import UUID + +from sqlalchemy import select, func, and_, or_, desc +from sqlalchemy.ext.asyncio import AsyncSession +import structlog + +from api.models.job import Job, JobStatus + +logger = structlog.get_logger() + + +class JobService: + """Service for managing jobs.""" + + @staticmethod + async def get_job_logs( + session: AsyncSession, + job_id: UUID, + lines: int = 100, + ) -> List[str]: + """ + Get stored logs for a job. + + In a production system, this would query a log aggregation service + like ELK stack, but for now we return structured logs from job data. + """ + # Get the job + stmt = select(Job).where(Job.id == job_id) + result = await session.execute(stmt) + job = result.scalar_one_or_none() + + if not job: + return [] + + # Build log entries from job lifecycle + logs = [] + + # Job creation + logs.append(f"[{job.created_at.isoformat()}] Job created: {job_id}") + logs.append(f"[{job.created_at.isoformat()}] Status: QUEUED") + logs.append(f"[{job.created_at.isoformat()}] Input URL: {job.input_url}") + logs.append(f"[{job.created_at.isoformat()}] Operations: {len(job.operations)} operations requested") + + # Job parameters + if job.options: + logs.append(f"[{job.created_at.isoformat()}] Options: {job.options}") + + # Processing start + if job.started_at: + logs.append(f"[{job.started_at.isoformat()}] Status: PROCESSING") + logs.append(f"[{job.started_at.isoformat()}] Worker ID: {job.worker_id}") + logs.append(f"[{job.started_at.isoformat()}] Processing started") + + # Progress updates (simulated based on current progress) + if job.progress > 0 and job.started_at: + # Add some progress log entries + progress_steps = [10, 25, 50, 75, 90] + for step in progress_steps: + if job.progress >= step: + # Estimate timestamp based on progress + if job.completed_at: + # Job is complete, interpolate timestamps + total_duration = (job.completed_at - job.started_at).total_seconds() + step_duration = total_duration * (step / 100) + step_time = job.started_at + timedelta(seconds=step_duration) + else: + # Job still running, use current time for latest progress + if step == max([s for s in progress_steps if job.progress >= s]): + step_time = datetime.utcnow() + else: + # Estimate based on linear progress + elapsed = (datetime.utcnow() - job.started_at).total_seconds() + step_duration = elapsed * (step / job.progress) if job.progress > 0 else elapsed + step_time = job.started_at + timedelta(seconds=step_duration) + + logs.append(f"[{step_time.isoformat()}] Progress: {step}% complete") + + # Job completion + if job.completed_at: + if job.status == JobStatus.COMPLETED: + logs.append(f"[{job.completed_at.isoformat()}] Status: COMPLETED") + logs.append(f"[{job.completed_at.isoformat()}] Output URL: {job.output_url}") + logs.append(f"[{job.completed_at.isoformat()}] Processing completed successfully") + + # Calculate processing time + if job.started_at: + duration = (job.completed_at - job.started_at).total_seconds() + 
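# Wall-clock duration between started_at and completed_at; the interpolated progress timestamps above are derived from the same interval. +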
logs.append(f"[{job.completed_at.isoformat()}] Total processing time: {duration:.2f} seconds") + + elif job.status == JobStatus.FAILED: + logs.append(f"[{job.completed_at.isoformat()}] Status: FAILED") + logs.append(f"[{job.completed_at.isoformat()}] Error: {job.error_message}") + + elif job.status == JobStatus.CANCELLED: + logs.append(f"[{job.completed_at.isoformat()}] Status: CANCELLED") + logs.append(f"[{job.completed_at.isoformat()}] Job was cancelled") + + # Webhook notifications + if job.webhook_url and job.status in [JobStatus.COMPLETED, JobStatus.FAILED]: + webhook_time = job.completed_at or datetime.utcnow() + logs.append(f"[{webhook_time.isoformat()}] Webhook notification sent to: {job.webhook_url}") + + # Return the requested number of lines (most recent first) + return logs[-lines:] if lines > 0 else logs + + @staticmethod + async def get_job_by_id( + session: AsyncSession, + job_id: UUID, + api_key: Optional[str] = None, + ) -> Optional[Job]: + """Get job by ID, optionally filtered by API key.""" + stmt = select(Job).where(Job.id == job_id) + + if api_key: + stmt = stmt.where(Job.api_key == api_key) + + result = await session.execute(stmt) + return result.scalar_one_or_none() + + @staticmethod + async def get_jobs_for_api_key( + session: AsyncSession, + api_key: str, + status: Optional[JobStatus] = None, + limit: int = 100, + offset: int = 0, + ) -> tuple[List[Job], int]: + """Get jobs for an API key with pagination.""" + # Build base query + stmt = select(Job).where(Job.api_key == api_key) + count_stmt = select(func.count(Job.id)).where(Job.api_key == api_key) + + # Apply status filter + if status: + stmt = stmt.where(Job.status == status) + count_stmt = count_stmt.where(Job.status == status) + + # Apply pagination + stmt = stmt.order_by(desc(Job.created_at)).limit(limit).offset(offset) + + # Execute queries + result = await session.execute(stmt) + count_result = await session.execute(count_stmt) + + jobs = list(result.scalars().all()) + total_count = count_result.scalar() + + return jobs, total_count + + @staticmethod + async def get_job_statistics( + session: AsyncSession, + api_key: Optional[str] = None, + days: int = 30, + ) -> Dict[str, Any]: + """Get job statistics.""" + from datetime import timedelta + + # Calculate date range + end_date = datetime.utcnow() + start_date = end_date - timedelta(days=days) + + # Build base query + base_stmt = select(Job).where(Job.created_at >= start_date) + + if api_key: + base_stmt = base_stmt.where(Job.api_key == api_key) + + # Get total count + count_stmt = select(func.count(Job.id)).where(Job.created_at >= start_date) + if api_key: + count_stmt = count_stmt.where(Job.api_key == api_key) + + total_result = await session.execute(count_stmt) + total_jobs = total_result.scalar() + + # Get status counts + status_stats = {} + for status in JobStatus: + status_stmt = count_stmt.where(Job.status == status) + status_result = await session.execute(status_stmt) + status_stats[status.value] = status_result.scalar() + + # Get average processing time for completed jobs + completed_stmt = select( + func.avg( + func.extract('epoch', Job.completed_at - Job.started_at) + ) + ).where( + and_( + Job.status == JobStatus.COMPLETED, + Job.started_at.isnot(None), + Job.completed_at.isnot(None), + Job.created_at >= start_date + ) + ) + + if api_key: + completed_stmt = completed_stmt.where(Job.api_key == api_key) + + avg_result = await session.execute(completed_stmt) + avg_processing_time = avg_result.scalar() or 0 + + return { + "period_days": days, + 
"total_jobs": total_jobs, + "status_breakdown": status_stats, + "average_processing_time_seconds": float(avg_processing_time), + "success_rate": ( + status_stats.get("completed", 0) / total_jobs * 100 + if total_jobs > 0 else 0 + ), + } \ No newline at end of file