diff --git a/deployment/fastapi_inference/README.md b/deployment/fastapi_inference/README.md new file mode 100644 index 000000000..f67d6d470 --- /dev/null +++ b/deployment/fastapi_inference/README.md @@ -0,0 +1,354 @@ +# MONAI + FastAPI Inference Deployment Tutorial + +This tutorial demonstrates how to deploy MONAI model bundles as production-ready REST APIs using FastAPI. + +## 📚 Overview + +Learn how to: +- Load and serve MONAI model bundles +- Create FastAPI endpoints for medical image inference +- Handle medical image uploads (NIfTI format) +- Deploy with Docker for production +- Test and monitor your deployed model + +## 🎯 What You'll Build + +A complete REST API service that: +- ✅ Loads a pre-trained MONAI model (spleen CT segmentation) +- ✅ Accepts medical image uploads via HTTP +- ✅ Returns inference results in JSON format +- ✅ Includes auto-generated API documentation +- ✅ Runs in Docker containers for easy deployment + +## 📋 Prerequisites + +- Python 3.9+ installed +- Docker installed (for containerization) +- Basic knowledge of Python and REST APIs +- Familiarity with medical imaging (helpful but not required) + +## 🚀 Quick Start + +### 1. Install Dependencies + +```bash +pip install -r requirements.txt +``` + +### 2. Run the API Locally + +```bash +# From the fastapi_inference directory +python -m uvicorn app.main:app --reload +``` + +The API will be available at `http://localhost:8000` + +### 3. 
Test the API + +**Health Check:** +```bash +curl http://localhost:8000/health +``` + +**View API Documentation:** +Open `http://localhost:8000/docs` in your browser + +**Make a Prediction:** +```bash +curl -X POST http://localhost:8000/predict \ + -F "file=@path/to/your/image.nii.gz" +``` + +## 📁 Project Structure + +``` +fastapi_inference/ +├── README.md # This file +├── requirements.txt # Python dependencies +├── app/ # FastAPI application +│ ├── __init__.py +│ ├── main.py # FastAPI app and routes +│ ├── model_loader.py # MONAI model loading (singleton) +│ ├── inference.py # Inference logic +│ └── schemas.py # Pydantic models for validation +├── tests/ # Unit tests +│ ├── __init__.py +│ └── test_api.py # API endpoint tests +├── docker/ # Docker configuration +│ ├── Dockerfile # Container definition +│ └── docker-compose.yml # Orchestration +├── notebooks/ # Interactive tutorials +│ └── fastapi_tutorial.ipynb # Step-by-step walkthrough +└── examples/ # Usage examples + ├── client.py # Python client example + └── sample_requests.http # HTTP request examples +``` + +## 🔧 API Endpoints + +### `GET /` +Returns API information + +### `GET /health` +Health check endpoint +- Returns service status +- Indicates if model is loaded +- Shows computation device (CPU/GPU) + +**Example Response:** +```json +{ + "status": "healthy", + "model_loaded": true, + "device": "cuda" +} +``` + +### `POST /predict` +Run inference on uploaded medical image + +**Request:** +- Method: POST +- Content-Type: multipart/form-data +- Body: file (NIfTI format: .nii or .nii.gz) + +**Response:** +```json +{ + "success": true, + "prediction": { + "shape": [1, 2, 96, 96, 96], + "min_value": 0.0, + "max_value": 1.0, + "unique_labels": [0, 1], + "num_labels": 2 + }, + "segmentation_shape": [1, 2, 96, 96, 96], + "metadata": { + "image_shape": [1, 1, 96, 96, 96], 
"processing_time": 2.345, + "device": "cuda" + }, + "message": "Inference completed successfully in 2.345s" +} +``` + +### `GET /docs` +Interactive API documentation (Swagger UI) + +### `GET /redoc` +Alternative API documentation (ReDoc) + +## 🐳 Docker Deployment + +### Build and Run with Docker + +```bash +# Build the image +docker build -t monai-fastapi -f docker/Dockerfile . + +# Run the container +docker run -p 8000:8000 monai-fastapi +``` + +### Or use Docker Compose + +```bash +# Start the service +docker-compose -f docker/docker-compose.yml up -d + +# View logs +docker-compose -f docker/docker-compose.yml logs -f + +# Stop the service +docker-compose -f docker/docker-compose.yml down +``` + +## 📝 Usage Examples + +### Python Client + +```python +from examples.client import MONAIClient + +# Initialize client +client = MONAIClient(base_url="http://localhost:8000") + +# Check health +health = client.health_check() +print(health) + +# Make prediction +result = client.predict("path/to/image.nii.gz") +print(result) +``` + +### Command Line + +```bash +# Check health +python examples/client.py --health + +# Run prediction +python examples/client.py --image path/to/image.nii.gz +``` + +### cURL Examples + +```bash +# Health check +curl http://localhost:8000/health + +# Prediction +curl -X POST http://localhost:8000/predict \ + -F "file=@tests/sample_image.nii.gz" +``` + +## 🧪 Running Tests + +```bash +# Install test dependencies +pip install pytest pytest-asyncio pytest-cov httpx + +# Run all tests +pytest tests/ + +# Run with coverage +pytest tests/ --cov=app --cov-report=html +``` + +## 🔍 Model Information + +**Default Model:** spleen_ct_segmentation + +This tutorial uses MONAI's spleen CT segmentation bundle, which: +- Segments spleen from CT scans +- Pre-trained on Medical Segmentation Decathlon dataset +- Fast inference (~2-3 seconds on GPU) +- Good starting point for learning deployment + +**To use a different model:** +Edit `app/main.py` and change the model 
name in the `lifespan` function: +```python +model_loader.load_model( + model_name="your_model_name", # Change this + bundle_dir="./models" +) +``` + +## ⚙️ Configuration + +### Environment Variables + +Create a `.env` file for configuration: + +```env +# Server configuration +HOST=0.0.0.0 +PORT=8000 +LOG_LEVEL=info + +# Model configuration +MODEL_NAME=spleen_ct_segmentation +MODEL_DIR=./models + +# Performance +WORKERS=1 +``` + +### GPU Support + +The application automatically detects and uses GPU if available: +- **With GPU:** Faster inference, handles larger images +- **Without GPU:** Runs on CPU (slower but works) + +## 🚦 Production Considerations + +### Security +- Add authentication (JWT, API keys) +- Validate file sizes and types +- Use HTTPS in production +- Set CORS origins explicitly + +### Performance +- Use multiple worker processes for scaling +- Add caching for frequently used models +- Implement request queuing for high load +- Consider model quantization for speed + +### Monitoring +- Add logging and metrics +- Track inference times +- Monitor memory usage +- Set up health check endpoints + +### Example Production Command + +```bash +uvicorn app.main:app \ + --host 0.0.0.0 \ + --port 8000 \ + --workers 4 \ + --log-level info \ + --proxy-headers \ + --forwarded-allow-ips='*' +``` + +## 🐛 Troubleshooting + +### Model Download Fails +``` +Error: Failed to download model bundle +Solution: Check internet connection and MONAI bundle name +``` + +### Out of Memory +``` +Error: CUDA out of memory +Solution: Reduce batch size or use CPU with smaller model +``` + +### File Format Error +``` +Error: Invalid file format +Solution: Ensure file is NIfTI format (.nii or .nii.gz) +``` + +### Port Already in Use +``` +Error: Address already in use +Solution: Change port or kill process using port 8000 +``` + +## 📚 Additional Resources + +- [FastAPI Documentation](https://fastapi.tiangolo.com/) +- [MONAI Documentation](https://docs.monai.io/) +- [MONAI 
Model Zoo](https://monai.io/model-zoo.html) +- [MONAI Bundle Guide](https://docs.monai.io/en/stable/bundle_intro.html) +- [Docker Documentation](https://docs.docker.com/) + +## 🤝 Contributing + +This tutorial is part of the MONAI tutorials collection. Contributions welcome! + +## 📄 License + +Copyright 2025 MONAI Consortium +Licensed under the Apache License, Version 2.0 + +## 🙋 Support + +For questions about this tutorial: +- Open an issue on GitHub +- Visit MONAI community forums +- Check existing tutorials for similar examples + +--- + +**Next Steps:** +1. ✅ Run through the tutorial +2. ✅ Experiment with different models +3. ✅ Deploy to your infrastructure +4. ✅ Build your own medical AI application! diff --git a/deployment/fastapi_inference/app/__init__.py b/deployment/fastapi_inference/app/__init__.py new file mode 100644 index 000000000..9b4fbf5bc --- /dev/null +++ b/deployment/fastapi_inference/app/__init__.py @@ -0,0 +1,7 @@ +""" +FastAPI Inference Service for MONAI Models + +This package provides a production-ready REST API for deploying MONAI model bundles. +""" + +__version__ = "1.0.0" diff --git a/deployment/fastapi_inference/app/inference.py b/deployment/fastapi_inference/app/inference.py new file mode 100644 index 000000000..6854cc1b3 --- /dev/null +++ b/deployment/fastapi_inference/app/inference.py @@ -0,0 +1,220 @@ +""" +Inference Logic + +This module handles the preprocessing, inference, and postprocessing +of medical images using MONAI models. 
class InferenceEngine:
    """Decode an uploaded NIfTI volume, run the shared model on it, and summarise the output.

    The engine keeps no per-request state: the model and device are read from
    the module-level ``model_loader`` each time they are needed.
    """

    def __init__(self):
        """Initialize the inference engine with preprocessing transforms."""
        # NOTE(review): no method of this class applies ``self.preprocess``;
        # uploads are decoded by ``_load_image`` and batched by
        # ``_preprocess_image``. The pipeline is kept so the attribute remains
        # available -- confirm whether it is meant to be wired in.
        self.preprocess = Compose(
            [
                LoadImage(image_only=True),
                EnsureChannelFirst(),
                Spacing(pixdim=(1.5, 1.5, 2.0)),
                ScaleIntensity(),
                EnsureType(dtype=torch.float32),
            ]
        )

    async def process_image(self, image_bytes: bytes, filename: str) -> Dict:
        """
        Process an uploaded image and return predictions.

        Args:
            image_bytes: Raw bytes of the uploaded image
            filename: Original filename (used for format detection and logging)

        Returns:
            Dictionary with ``success``, ``prediction``, ``segmentation_shape``,
            ``metadata`` and ``message`` keys.

        Raises:
            ValueError: If the upload is not a readable NIfTI image (client error)
            RuntimeError: If preprocessing or inference fails (server error)
        """
        start_time = time.time()

        try:
            image_buffer = BytesIO(image_bytes)

            logger.info(f"Processing image: {filename}")
            image_data = self._load_image(image_buffer, filename)
            image_tensor = self._preprocess_image(image_data)

            prediction = await self._run_inference(image_tensor)

            processing_time = time.time() - start_time

            result = {
                "success": True,
                "prediction": self._format_prediction(prediction),
                "segmentation_shape": (
                    list(prediction.shape) if isinstance(prediction, (np.ndarray, torch.Tensor)) else None
                ),
                "metadata": {
                    "image_shape": list(image_tensor.shape),
                    "processing_time": round(processing_time, 3),
                    "device": str(model_loader.device),
                },
                "message": f"Inference completed successfully in {processing_time:.3f}s",
            }

            logger.info(f"Inference completed in {processing_time:.3f}s")
            return result

        except ValueError as e:
            # Bad input must stay a ValueError so the caller can tell it apart
            # from a server-side failure; wrapping it would hide that.
            logger.warning(f"Rejected input: {str(e)}")
            raise
        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            raise RuntimeError(f"Inference error: {str(e)}") from e

    def _load_image(self, image_buffer: BytesIO, filename: str) -> np.ndarray:
        """
        Load image from bytes buffer.

        Args:
            image_buffer: BytesIO object containing image data
            filename: Original filename for format detection

        Returns:
            Loaded image as a float32 numpy array

        Raises:
            ValueError: If the extension is unsupported or the bytes cannot be
                decoded as NIfTI
        """
        import gzip  # local import: only needed for compressed uploads

        if not filename.lower().endswith((".nii", ".nii.gz")):
            raise ValueError(f"Unsupported image format: {filename}. " "Supported formats: .nii, .nii.gz")

        try:
            image_buffer.seek(0)
            payload = image_buffer.read()
            # ``nib.load`` only accepts filesystem paths, so decode in memory.
            # ``from_bytes`` wants the uncompressed stream; detect gzip by its
            # magic number rather than trusting the file extension.
            if payload[:2] == b"\x1f\x8b":
                payload = gzip.decompress(payload)
            img = nib.Nifti1Image.from_bytes(payload)
            # float32 matches the dtype network weights are normally stored in
            # and halves memory compared with the float64 default.
            return img.get_fdata(dtype=np.float32)
        except Exception as e:
            raise ValueError(f"Failed to load image: {str(e)}") from e

    def _preprocess_image(self, image_data: np.ndarray) -> torch.Tensor:
        """
        Preprocess image for inference.

        A 3-D array gains a channel and a batch axis; a 4-D array gains only a
        batch axis. Other ranks are passed through unchanged.

        Args:
            image_data: Raw image data as numpy array

        Returns:
            float32 tensor on the model loader's device

        Raises:
            RuntimeError: If tensor conversion or device transfer fails
        """
        try:
            image_tensor = torch.as_tensor(image_data, dtype=torch.float32)

            if image_tensor.ndim == 3:
                image_tensor = image_tensor.unsqueeze(0)  # Add channel
            if image_tensor.ndim == 4:
                image_tensor = image_tensor.unsqueeze(0)  # Add batch

            return image_tensor.to(model_loader.device)

        except Exception as e:
            raise RuntimeError(f"Preprocessing failed: {str(e)}") from e

    async def _run_inference(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """
        Run model inference.

        Args:
            image_tensor: Preprocessed image tensor

        Returns:
            Model prediction

        Raises:
            RuntimeError: If the model is missing, not callable, or the forward
                pass fails
        """
        # NOTE(review): the forward pass runs synchronously inside this
        # coroutine, so it occupies the event loop while it computes.
        try:
            model = model_loader.model

            if not callable(model):
                raise RuntimeError("Model is not callable")

            # Gradients are not needed for inference.
            with torch.no_grad():
                return model(image_tensor)

        except Exception as e:
            raise RuntimeError(f"Model inference failed: {str(e)}") from e

    def _format_prediction(self, prediction: torch.Tensor) -> Dict:
        """
        Format prediction output for JSON response.

        Args:
            prediction: Raw model output (tensor or array)

        Returns:
            Dictionary with shape, dtype and min/max/mean; arrays with three or
            more dimensions also get ``unique_labels`` and ``num_labels``. If
            summarising fails, a dictionary holding only ``raw_type``.
        """
        try:
            if isinstance(prediction, torch.Tensor):
                pred_np = prediction.detach().cpu().numpy()
            else:
                pred_np = prediction

            result = {
                "shape": list(pred_np.shape),
                "dtype": str(pred_np.dtype),
                "min_value": float(pred_np.min()),
                "max_value": float(pred_np.max()),
                "mean_value": float(pred_np.mean()),
            }

            # NOTE(review): labels come from truncating raw values to int, not
            # from an argmax over a class axis -- confirm this is what the
            # model's output needs.
            if pred_np.ndim >= 3:
                unique_labels = np.unique(pred_np.astype(int))
                result["unique_labels"] = unique_labels.tolist()
                result["num_labels"] = len(unique_labels)

            return result

        except Exception as e:
            # Best effort: a formatting problem should not fail the request.
            logger.warning(f"Failed to format prediction: {str(e)}")
            return {"raw_type": str(type(prediction))}
# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager for FastAPI app.

    On startup it tries to load the default bundle. A failure is logged and
    startup continues, so ``/health`` reports ``model_not_loaded`` and
    ``/predict`` answers 503 instead of the process exiting.
    """
    logger.info("Starting up: Loading MONAI model...")
    try:
        model_loader.load_model(model_name="spleen_ct_segmentation", bundle_dir="./models")
        logger.info("Model loaded successfully!")
    except Exception as e:
        # Deliberate best effort: keep serving without a model.
        logger.error(f"Failed to load model: {e}", exc_info=True)

    yield

    logger.info("Shutting down...")


# Initialize FastAPI app
app = FastAPI(
    title="MONAI Inference API",
    description="REST API for deploying MONAI model bundles for medical image inference",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify exact origins
    # Credentials must not be combined with a wildcard origin; this API uses
    # no cookies or auth headers, so they are switched off.
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Log an unhandled exception with its traceback and return a generic 500 body."""
    logger.error(f"Unexpected error: {str(exc)}", exc_info=exc)
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={
            "error": "InternalServerError",
            "detail": "An unexpected error occurred. Please try again.",
            "status_code": 500,
        },
    )


@app.get(
    "/",
    summary="Root endpoint",
    description="Returns basic API information",
)
async def root():
    """Root endpoint - API information."""
    return {
        "name": "MONAI Inference API",
        "version": "1.0.0",
        "description": "FastAPI deployment for MONAI models",
        "endpoints": {
            "health": "/health",
            "predict": "/predict",
            "docs": "/docs",
        },
    }


@app.get(
    "/health",
    response_model=HealthResponse,
    summary="Health check",
    description="Check if the service and model are ready",
)
async def health_check():
    """
    Health check endpoint.

    Returns:
        HealthResponse: ``healthy`` with the device name when a model is
        loaded, otherwise ``model_not_loaded`` with device ``unknown``.
    """
    is_loaded = model_loader.is_loaded()

    return HealthResponse(
        status="healthy" if is_loaded else "model_not_loaded",
        model_loaded=is_loaded,
        device=str(model_loader.device) if is_loaded else "unknown",
    )


@app.post(
    "/predict",
    response_model=PredictionResponse,
    summary="Run inference",
    description="Upload a medical image and get predictions",
    responses={
        200: {"description": "Successful prediction"},
        400: {"model": ErrorResponse, "description": "Bad request"},
        500: {"model": ErrorResponse, "description": "Internal server error"},
        503: {"description": "Model not loaded"},
    },
)
async def predict(file: UploadFile = File(..., description="Medical image file (NIfTI format: .nii or .nii.gz)")):
    """
    Run inference on uploaded medical image.

    Args:
        file: Uploaded image file (NIfTI format)

    Returns:
        PredictionResponse: Prediction results with metadata

    Raises:
        HTTPException: 400 for a missing/unsupported filename or unreadable
            image, 503 when no model is loaded, 500 when inference fails
    """
    # An upload may arrive without a filename; treat that as an invalid format
    # instead of failing on ``None.endswith``.
    filename = file.filename or ""
    if not filename.lower().endswith((".nii", ".nii.gz")):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid file format. Supported formats: .nii, .nii.gz"
        )

    if not model_loader.is_loaded():
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Model not loaded. Please try again later."
        )

    try:
        # NOTE(review): the whole upload is read into memory with no size cap.
        contents = await file.read()

        result = await inference_engine.process_image(image_bytes=contents, filename=filename)

        return PredictionResponse(**result)

    except ValueError as e:
        # Client error (bad input)
        logger.warning(f"Bad request: {str(e)}")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

    except RuntimeError as e:
        # Server error (inference failed)
        logger.error(f"Inference error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Inference failed: {str(e)}"
        ) from e

    except Exception as e:
        # Unexpected error: keep the traceback in the log, not in the response.
        logger.exception(f"Unexpected error during prediction: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during prediction"
        ) from e


if __name__ == "__main__":
    import uvicorn

    # For development only - use proper ASGI server in production.
    # This module uses package-relative imports, so it has to be started as
    # ``python -m app.main`` and referenced by its package path.
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True, log_level="info")
class ModelLoader:
    """
    Process-wide holder for one MONAI model bundle.

    ``__new__`` always hands back the same instance, and the model and device
    are stored on it, so every importer shares a single loaded model.
    """

    _instance: Optional["ModelLoader"] = None
    _model = None
    _device = None

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Pick the computation device the first time the loader is constructed."""
        # Python calls __init__ on every ``ModelLoader()`` even though __new__
        # returns the same object, so guard against repeating the detection.
        if self._device is None:
            self._setup_device()

    def _setup_device(self):
        """Determine and set up the computation device (CPU or GPU)."""
        if torch.cuda.is_available():
            self._device = torch.device("cuda")
            logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
        else:
            self._device = torch.device("cpu")
            logger.info("Using CPU for inference")

    def load_model(self, model_name: str = "spleen_ct_segmentation", bundle_dir: str = "./models") -> None:
        """
        Load a MONAI model bundle.

        The bundle is downloaded when ``bundle_dir/model_name`` does not exist.
        The model is published on the loader only after it has been moved to
        the device and switched to eval mode, so a failure part-way leaves the
        previous state (``is_loaded()`` included) untouched.

        Args:
            model_name: Name of the MONAI bundle to load
            bundle_dir: Directory to store/load the bundle

        Raises:
            RuntimeError: If download or loading fails, or the loaded object
                cannot be called
        """
        try:
            bundle_path = Path(bundle_dir) / model_name

            if not bundle_path.exists():
                logger.info(f"Downloading model bundle: {model_name}")
                download(name=model_name, bundle_dir=bundle_dir)
                logger.info(f"Model downloaded successfully to {bundle_path}")
            else:
                logger.info(f"Using existing model bundle at {bundle_path}")

            logger.info("Loading model into memory...")
            candidate = load(name=model_name, bundle_dir=bundle_dir, source="monaihosting")

            # NOTE(review): depending on the MONAI version and arguments,
            # ``load`` may hand back weights instead of a network -- verify
            # against the pinned release. Such an object cannot serve
            # requests, so fail here instead of reporting "loaded".
            if not callable(candidate):
                raise TypeError(
                    f"bundle loader returned {type(candidate).__name__}, which is not a callable network"
                )

            if hasattr(candidate, "to"):
                candidate = candidate.to(self._device)

            if hasattr(candidate, "eval"):
                candidate.eval()

            self._model = candidate
            logger.info("Model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise RuntimeError(f"Model loading failed: {str(e)}") from e

    @property
    def model(self):
        """Return the loaded model; raise ``RuntimeError`` when none is loaded."""
        if self._model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        return self._model

    @property
    def device(self):
        """Return the selected ``torch.device`` (``None`` before device setup has run)."""
        return self._device

    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self._model is not None
class HealthResponse(BaseModel):
    """Health check response model."""

    # Pydantic v2 reserves the "model_" prefix by default and warns when a
    # field such as ``model_loaded`` uses it; clear the reservation for this
    # model so the field name (part of the API) can stay as it is.
    model_config = {"protected_namespaces": ()}

    status: str = Field(..., description="Service status")
    model_loaded: bool = Field(..., description="Whether model is loaded")
    device: str = Field(..., description="Computation device (CPU/GPU)")


class PredictionMetadata(BaseModel):
    """Metadata about the prediction."""

    image_shape: List[int] = Field(..., description="Input image dimensions")
    processing_time: float = Field(..., description="Processing time in seconds")
    device: str = Field(..., description="Device used for inference")


class PredictionResponse(BaseModel):
    """Response model for inference predictions."""

    success: bool = Field(..., description="Whether prediction was successful")
    # Free-form summary; its keys depend on what the model returned.
    prediction: Optional[Dict] = Field(None, description="Prediction results (format depends on model output)")
    segmentation_shape: Optional[List[int]] = Field(None, description="Shape of segmentation mask if applicable")
    metadata: PredictionMetadata = Field(..., description="Prediction metadata")
    message: Optional[str] = Field(None, description="Additional information or error message")


class ErrorResponse(BaseModel):
    """Error response model."""

    error: str = Field(..., description="Error type")
    detail: str = Field(..., description="Detailed error message")
    status_code: int = Field(..., description="HTTP status code")
/var/lib/apt/lists/* + +# Copy requirements and install Python dependencies +COPY requirements.txt . +RUN pip install --user --no-cache-dir -r requirements.txt + +# Stage 2: Runtime +FROM python:3.10-slim + +WORKDIR /app + +# Install runtime dependencies +RUN apt-get update && apt-get install -y \ + libgomp1 \ + && rm -rf /var/lib/apt/lists/* + +# Copy Python dependencies from builder +COPY --from=builder /root/.local /root/.local + +# Make sure scripts in .local are usable +ENV PATH=/root/.local/bin:$PATH + +# Copy application code +COPY app/ /app/app/ + +# Create directory for model storage +RUN mkdir -p /app/models + +# Expose port +EXPOSE 8000 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \ + CMD python -c "import requests; requests.get('http://localhost:8000/health')" || exit 1 + +# Run the application +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/deployment/fastapi_inference/docker/docker-compose.yml b/deployment/fastapi_inference/docker/docker-compose.yml new file mode 100644 index 000000000..0d4e13b28 --- /dev/null +++ b/deployment/fastapi_inference/docker/docker-compose.yml @@ -0,0 +1,25 @@ +version: '3.8' + +services: + monai-api: + build: + context: .. 
class MONAIClient:
    """Client for interacting with MONAI FastAPI inference service."""

    def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 300.0):
        """
        Initialize the client.

        Args:
            base_url: Base URL of the API (default: http://localhost:8000);
                a trailing slash is removed
            timeout: Seconds to wait for each HTTP request before giving up.
                Without one, a stalled server would block the caller forever.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    def health_check(self) -> dict:
        """
        Check if the service is healthy.

        Returns:
            Health status dictionary

        Raises:
            requests.RequestException: If the request fails or returns an
                error status
        """
        response = requests.get(f"{self.base_url}/health", timeout=self.timeout)
        response.raise_for_status()
        return response.json()

    def predict(self, image_path: str) -> dict:
        """
        Send an image for inference.

        Args:
            image_path: Path to the medical image file (.nii or .nii.gz)

        Returns:
            Prediction results dictionary

        Raises:
            FileNotFoundError: If image file doesn't exist
            requests.RequestException: If the request fails or returns an
                error status
        """
        image_path = Path(image_path)

        # Checked before any network call so a bad path fails fast.
        if not image_path.exists():
            raise FileNotFoundError(f"Image not found: {image_path}")

        with open(image_path, "rb") as f:
            files = {"file": (image_path.name, f, "application/octet-stream")}
            response = requests.post(f"{self.base_url}/predict", files=files, timeout=self.timeout)

        response.raise_for_status()
        return response.json()


def main():
    """Run the command-line client and return a process exit code (0 ok, 1 error)."""
    parser = argparse.ArgumentParser(description="MONAI FastAPI Inference Client")
    parser.add_argument("--url", default="http://localhost:8000", help="API base URL (default: http://localhost:8000)")
    parser.add_argument("--health", action="store_true", help="Check API health status")
    parser.add_argument("--image", type=str, help="Path to medical image file for prediction")

    args = parser.parse_args()

    # Nothing requested: show usage instead of exiting silently.
    if not args.health and not args.image:
        parser.print_help()
        return 0

    client = MONAIClient(base_url=args.url)

    if args.health:
        print("Checking API health...")
        try:
            health = client.health_check()
            print(json.dumps(health, indent=2))
        except requests.RequestException as e:
            print(f"Error: {e}")
            return 1

    if args.image:
        print(f"Sending image for prediction: {args.image}")
        try:
            result = client.predict(args.image)
            print("\nPrediction Results:")
            print(json.dumps(result, indent=2))
        except FileNotFoundError as e:
            print(f"Error: {e}")
            return 1
        except requests.HTTPError as e:
            print(f"API Error: {e}")
            print(f"Response: {e.response.text}")
            return 1
        except requests.RequestException as e:
            # Connection refused, timeout, ... -- there is no response body.
            print(f"Error: {e}")
            return 1

    return 0


if __name__ == "__main__":
    # ``exit`` is injected by the ``site`` module and is not always available.
    raise SystemExit(main())
b/deployment/fastapi_inference/examples/sample_requests.http @@ -0,0 +1,36 @@ +### MONAI FastAPI Inference Service - Example Requests +### Use with REST Client extensions in VS Code or similar tools + +@baseUrl = http://localhost:8000 + +### 1. Root endpoint - Get API information +GET {{baseUrl}}/ HTTP/1.1 + +### 2. Health check - Check service and model status +GET {{baseUrl}}/health HTTP/1.1 + +### 3. API Documentation - Open in browser +# GET {{baseUrl}}/docs + +### 4. Predict - Upload medical image for inference +# Replace 'path/to/image.nii.gz' with actual file path +POST {{baseUrl}}/predict HTTP/1.1 +Content-Type: multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW + +------WebKitFormBoundary7MA4YWxkTrZu0gW +Content-Disposition: form-data; name="file"; filename="sample.nii.gz" +Content-Type: application/octet-stream + +< ./tests/sample_image.nii.gz +------WebKitFormBoundary7MA4YWxkTrZu0gW-- + +### 5. Invalid file format - Should return 400 error +POST {{baseUrl}}/predict HTTP/1.1 +Content-Type: multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW + +------WebKitFormBoundary7MA4YWxkTrZu0gW +Content-Disposition: form-data; name="file"; filename="invalid.txt" +Content-Type: text/plain + +Invalid file content +------WebKitFormBoundary7MA4YWxkTrZu0gW-- diff --git a/deployment/fastapi_inference/requirements.txt b/deployment/fastapi_inference/requirements.txt new file mode 100644 index 000000000..af3010cf4 --- /dev/null +++ b/deployment/fastapi_inference/requirements.txt @@ -0,0 +1,23 @@ +# FastAPI and server +fastapi==0.104.1 +uvicorn[standard]==0.24.0 +python-multipart==0.0.6 + +# MONAI and ML +monai[all]==1.3.0 +torch==2.1.0 +torchvision==0.16.0 + +# Image processing +nibabel==5.1.0 +SimpleITK==2.3.0 +Pillow==10.1.0 + +# Utilities +pydantic==2.5.0 +python-dotenv==1.0.0 + +# Testing (optional for dev) +pytest==7.4.3 +pytest-asyncio==0.21.1 +httpx==0.25.1 diff --git a/deployment/fastapi_inference/tests/__init__.py 
# The client is built without entering the app's lifespan, so no model is
# loaded while these tests run.
client = TestClient(app)


class TestRootEndpoint:
    """Tests for root endpoint."""

    def test_root_returns_200(self):
        """Test that root endpoint returns 200 OK."""
        response = client.get("/")
        assert response.status_code == 200

    def test_root_returns_api_info(self):
        """Test that root endpoint returns API information."""
        response = client.get("/")
        data = response.json()

        assert "name" in data
        assert "version" in data
        assert "endpoints" in data
        assert data["name"] == "MONAI Inference API"


class TestHealthEndpoint:
    """Tests for health check endpoint."""

    def test_health_returns_200(self):
        """Test that health endpoint returns 200 OK."""
        response = client.get("/health")
        assert response.status_code == 200

    def test_health_returns_status(self):
        """Test that health endpoint returns status information."""
        response = client.get("/health")
        data = response.json()

        assert "status" in data
        assert "model_loaded" in data
        assert "device" in data

    def test_health_status_format(self):
        """Test that health response has expected format."""
        response = client.get("/health")
        data = response.json()

        assert isinstance(data["model_loaded"], bool)
        assert isinstance(data["device"], str)
        assert data["status"] in ["healthy", "model_not_loaded"]


class TestPredictEndpoint:
    """Tests for prediction endpoint."""

    def test_predict_requires_file(self):
        """Test that predict endpoint requires a file."""
        response = client.post("/predict")
        assert response.status_code == 422  # Unprocessable Entity

    def test_predict_rejects_invalid_format(self):
        """Test that predict endpoint rejects non-NIfTI files."""
        files = {"file": ("test.txt", b"fake content", "text/plain")}
        response = client.post("/predict", files=files)

        assert response.status_code == 400
        assert "Invalid file format" in response.json()["detail"]

    def test_predict_accepts_nifti_extension(self):
        """Test that a .nii upload gets past the extension check."""
        files = {"file": ("test.nii", b"fake nifti data", "application/octet-stream")}
        response = client.post("/predict", files=files)

        # Past file validation the request stops at 503 (model not loaded) or
        # 500 (invalid data), never at the 400 extension check.
        assert response.status_code in [500, 503]


class TestDocumentation:
    """Tests for API documentation endpoints."""

    def test_docs_available(self):
        """Test that Swagger UI docs are available."""
        response = client.get("/docs")
        assert response.status_code == 200

    def test_redoc_available(self):
        """Test that ReDoc documentation is available."""
        response = client.get("/redoc")
        assert response.status_code == 200


class TestCORS:
    """Tests for CORS middleware."""

    def test_cors_headers_present(self):
        """Test that a cross-origin request receives CORS headers."""
        # The middleware only decorates requests that carry an Origin header;
        # without one the response has no CORS headers to check.
        response = client.get("/health", headers={"Origin": "http://example.com"})

        assert "access-control-allow-origin" in response.headers