diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7aa85a6f..1068be8b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,6 +24,7 @@ repos: name: check-added-large-files entry: check-added-large-files language: system + args: ['--maxkb=2048'] - id: ruff-check name: ruff-check entry: ruff check diff --git a/auto_tutorial_source/Bayesian_Methods/tuto_ood.py b/auto_tutorial_source/Bayesian_Methods/tuto_ood.py new file mode 100644 index 00000000..05f914df --- /dev/null +++ b/auto_tutorial_source/Bayesian_Methods/tuto_ood.py @@ -0,0 +1,224 @@ +""" +Simple Ood Evaluation +================================================ + + +In this tutorial, we’ll demonstrate how to perform out-of-distribution (OOD) evaluation using TorchUncertainty’s datamodules and routines. You’ll learn to: + +1. **Set up a CIFAR-100 datamodule** that automatically handles in-distribution, near-OOD, and far-OOD splits. +2. **Run the `ClassificationRoutine`** to compute both in-distribution accuracy and OOD metrics (AUROC, AUPR, FPR95). +3. **Plug in your own OOD datasets** for fully custom evaluation. + +Foreword on Out-of-Distribution Detection +----------------------------------------- + +Out-of-Distribution (OOD) detection measures a model’s ability to recognize inputs that differ from its training distribution. TorchUncertainty integrates common OOD metrics directly into the Lightning test loop, including: + +- **AUROC** (Area Under the ROC Curve) +- **AUPR** (Area Under the Precision-Recall Curve) +- **FPR95** (False Positive Rate at 95% True Positive Rate) + +With just a few lines of code you can compare in-distribution performance to OOD detection performance under both “near” and “far” shifts. Per default, TorchUncertainty uses the +popular OpenOOD library to define the near and far OOD datasets and splits. You can also use your own datasets by passing them to the datamodule. + +Supported Datamodules and Default OOD Splits +-------------------------------------------- + +.. list-table:: Datamodules & Default OOD Splits + :header-rows: 1 + :widths: 20 15 20 20 + + * - **Datamodule** + - **In-Domain** + - **Default Near-OOD (Hard)** + - **Default Far-OOD (Easy)** + * - ``CIFAR10DataModule`` + - CIFAR-10 + - CIFAR-100, Tiny ImageNet + - MNIST, SVHN, Textures, Places365 + * - ``CIFAR100DataModule`` + - CIFAR-100 + - CIFAR-10, Tiny ImageNet + - MNIST, SVHN, Textures, Places365 + * - ``ImageNetDataModule`` + - ImageNet-1K + - SSB-hard, NINCO + - iNaturalist, Textures, OpenImage-O + * - ``ImageNet200DataModule`` + - ImageNet200 + - SSB-hard, NINCO + - iNaturalist, Textures, OpenImage-O + +Supported OOD Criteria +---------------------- + +.. list-table:: Supported OOD Criteria + :header-rows: 1 + :widths: 15 50 + + * - **Criterion** + - **Original Reference (Year, Venue)** + * - ``msp`` + - Hendrycks & Gimpel, A Baseline for Detecting Misclassified and Out-of-Distribution Examples in Neural Networks `ICLR Workshop 2017 `_. + * - ``Maxlogit`` + - / + * - ``energy`` + - Liu et al., Energy-based Out-of-Distribution Detection `NeurIPS 2020 `_. + * - ``odin`` + - Liang, Li & Srikant, Enhancing The Reliability of Out-of-Distribution Image Detection in Neural Networks `ICML 2018 `_. + * - ``entropy`` + - / + * - ``mutual_information`` + - / + * - ``variation_ratio`` + - / + * - ``scale`` + - Scaling Out-of-Distribution Detection for Real-World Settings Hendrycks et al. `ICML 2022 `_. + * - ``ash`` + - AASH: Extremely Simple Activation Shaping for OOD Detection, Djurisic et al. `ICLR 2023 `_. + * - ``react`` + - ReAct: Out-of-distribution Detection with Rectified Activations, Sun et al. `NeurIPS 2021 `_. + * - ``adascale_a`` + - AdaSCALE: Adaptive Scaling for OOD Detection `Regmi et al. `_. + * - ``vim`` + - ViM: Out-of-Distribution with Virtual-Logit Matching, Wang et al. `CVPR 2022 `_. + * - ``knn`` + - Out-of-Distribution Detection with Deep Nearest Neighbors, Sun et al. `ICML 2022 `_. + * - ``gen`` + - GEN: Generalized ENtropy Score for OOD Detection, Liu et al. `CVPR 2023 `_. + * - ``nnguide`` + - NNGuide: Nearest-Neighbor Guidance for OOD Detection, Park et al. `ICCV 2023 `_. + +.. note:: + + - All of these criteria can be passed as the `ood_criterion` argument to + `ClassificationRoutine`. + - Methods marked “ensemble-only” will require multiple stochastic passes. + + + +.. note:: + + - **Near-OOD** splits are semantically similar to the in-domain data. + - **Far-OOD** splits come from more distant distributions (e.g., ImageNet variants). + - Override defaults by passing your own ``near_ood_datasets`` / ``far_ood_datasets``. + + +1. Loading the utilities +~~~~~~~~~~~~~~~~~~~~~~~~ + +To eval ood using TorchUncertainty, we have to load the following: + +- the model:ResNet18_32x32 trained on in-domain data cifar100 +- the classification routine from torch_uncertainty.routines +- the datamodule that handles dataloaders: CIFAR100DataModule from torch_uncertainty.datamodules. +""" + +# %% +from pathlib import Path + +# %% +# 2. Load the trained model +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# In this tutorial we will be loading a pretrained model, but you can also train your own using the same classification routine and still get ood related metrics at test phase. + + +import torch +from torch_uncertainty.models.resnet import resnet +from huggingface_hub import hf_hub_download + +net = resnet(in_channels=3, arch=18, num_classes=100, style="cifar", conv_bias=False) + +# load the model +path = hf_hub_download(repo_id="torch-uncertainty/resnet18_c100", filename="resnet18_c100.ckpt") +net.load_state_dict(torch.load(path)) + +net.cuda() +net.eval() + + +# %% +# 3. Defining the necessary datamodules +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# In the following, we instantiate our trainer, define the root of the datasets and the logs. +# We also create the datamodule that handles the cifar100 dataset, dataloaders and transforms. +# Datamodules can also handle OOD detection by setting the eval_ood parameter to True. + +from torch_uncertainty.datamodules import CIFAR100DataModule +from torch_uncertainty.routines import ClassificationRoutine +import torch.nn as nn +from pathlib import Path +from torch_uncertainty import TUTrainer + + +root = Path("data1") +datamodule = CIFAR100DataModule(root=root, batch_size=200, eval_ood=True, eval_shift=True) +trainer = TUTrainer(accelerator="gpu", enable_progress_bar=True) + + +# %% +# 4. Define the classification routine and launch the test +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Define the classification routine for evaluation. We use the CrossEntropyLoss +# as the loss function since we are working on a classification task. +# The routine is configured to handle OOD detection and distributional shifts using the specified model, loss function, and evaluation criteria. + +routine = ClassificationRoutine( + num_classes=datamodule.num_classes, + eval_ood=True, + model=net, + loss=nn.CrossEntropyLoss(), + eval_shift=True, + ood_criterion="ash", +) + +# Perform testing using the defined routine and datamodule. +results = trainer.test(model=routine, datamodule=datamodule) + + +# %% +# Here, we show the various test metrics along with the ood eval metrics, auroc,aupr and fpr95 on Near and far ood datasets defined per defualt according to OpenOOD splits (link to library) + + +# %% +# 5. Defining custom ood datasets +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# If you don't want to use the open ood datasets or dataset splits, you can pass your own datasets in a list to near_ood_datasets or far_ood_datasets datamodule arguments +# and use them for ood evaluation but make sure they inherit from the +# Dataset class from torch.utils.data, below is an example of such a case. + +from torchvision.datasets import CIFAR10, MNIST +from torchvision.transforms import v2 + + +test_transform = v2.Compose( + [ + v2.ToImage(), + v2.Resize(32), + v2.CenterCrop(32), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.5071, 0.4867, 0.4408)), + ] +) + +custom_dataset1 = CIFAR10(root=root, train=False, download=True, transform=test_transform) +custom_dataset2 = MNIST(root=root, train=False, download=True, transform=test_transform) + +datamodule = CIFAR100DataModule( + root=root, + batch_size=200, + eval_ood=True, + eval_shift=True, + near_ood_datasets=[custom_dataset1], + far_ood_datasets=[custom_dataset2], +) + +# Perform testing using the CUSTOM defined ood datasets. +results = trainer.test(model=routine, datamodule=datamodule) + + +# %% +# References +# ---------- +# - **OpenOOD:** Jingyang Zhang & al. (`Neurips 2025 `_). OpenOOD v1.5: Enhanced Benchmark for Out-of-Distribution Detection. diff --git a/auto_tutorial_source/Classification/tutorial_bert.py b/auto_tutorial_source/Classification/tutorial_bert.py new file mode 100644 index 00000000..886b6d5c --- /dev/null +++ b/auto_tutorial_source/Classification/tutorial_bert.py @@ -0,0 +1,153 @@ +""" +Benchamrk bert with torch-uncertainty on SST2 +=============================================== + +This tutorial is about using torch-uncertainty to benchmark a bert model on the sst2 dataset with various robustness metricis +and apply easily a postprocess step (MC dropout) on top either of the single model or deep ensemble. + +Dataset +------- + +In this tutorial we will use sst2 dataset available directly through torch uncertainty a long with various far/near ood datasets +also handled automatically by torch-uncertainty. + + +1. Define and load the single bert model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +""" + +# %% +import torch +import torch.nn as nn +from collections import OrderedDict +from huggingface_hub import hf_hub_download +from transformers import AutoTokenizer, AutoModelForSequenceClassification + + +def load_tu_ckpt_into_hf( + backbone, repo_id: str, filename: str, strict: bool = True, map_location="cpu" +): + ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) + + sd = torch.load(ckpt_path, map_location=map_location) + sd = sd.get("state_dict", sd) + + def with_prefix(prefix): + return OrderedDict((k[len(prefix) :], v) for k, v in sd.items() if k.startswith(prefix)) + + for pref in ("model.backbone.", "model.", "backbone."): + sub = with_prefix(pref) + if sub: + return backbone.load_state_dict(sub, strict=strict) + + return backbone.load_state_dict(sd, strict=strict) + + +class HFClassifier(nn.Module): + def __init__(self, model_name: str, num_labels: int = 2, local_files_only: bool = False): + super().__init__() + self.backbone = AutoModelForSequenceClassification.from_pretrained( + model_name, num_labels=num_labels, local_files_only=local_files_only + ) + + def forward(self, *args, **kwargs): + inputs = args[0] if (len(args) == 1 and isinstance(args[0], dict)) else kwargs + return self.backbone(**inputs).logits + + +net1 = HFClassifier("bert-base-uncased", num_labels=2) + +load_tu_ckpt_into_hf( + net1.backbone, + repo_id="torch-uncertainty/bert-sst2", + filename="model1.ckpt", +) + + +# %% +# 2. Benchmark the single model +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# We define first the sst2 datamodule then run the classification routine as follows. + +from torch_uncertainty.routines import ClassificationRoutine +from torch_uncertainty import TUTrainer +from torch_uncertainty.datamodules import Sst2DataModule + + +dm = Sst2DataModule( + batch_size=64, + eval_ood=True, +) + +trainer = TUTrainer(accelerator="gpu", enable_progress_bar=True, devices=1) + +routine = ClassificationRoutine( + num_classes=2, + model=net1, + loss=nn.CrossEntropyLoss(), + eval_ood=True, +) + +res = trainer.test(routine, datamodule=dm) + + +# %% +# 3. Apply a postprocess step on top of the single model +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Here we will be applying for example montecarlo dropout on top of the single model. +# but torch-uncertainty supports many other postprocess like temperature scaling,conformal... please refer to the documentation. + +from torch_uncertainty.models import mc_dropout + +mc_net = mc_dropout( + model=net1, + num_estimators=8, + on_batch=False, +) + +routine = ClassificationRoutine( + num_classes=2, + model=mc_net, + loss=nn.CrossEntropyLoss(), + eval_ood=True, +) + +res = trainer.test(routine, datamodule=dm) + + +# %% +# 4. Load and benchmark a deep ensemble of bert models +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Let us load the remaining models of the deep ensemble and then benchmark them easily with torch unceratinty. + +net2 = HFClassifier("bert-base-uncased", num_labels=2) + +load_tu_ckpt_into_hf( + net2.backbone, + repo_id="torch-uncertainty/bert-sst2", + filename="model2.ckpt", +) + + +net3 = HFClassifier("bert-base-uncased", num_labels=2) + +load_tu_ckpt_into_hf( + net3.backbone, + repo_id="torch-uncertainty/bert-sst2", + filename="model3.ckpt", +) + + +from torch_uncertainty.models import deep_ensembles + +deep = deep_ensembles([net1, net2, net3]) + + +routine = ClassificationRoutine( + num_classes=2, + model=deep, + loss=nn.CrossEntropyLoss(), + eval_ood=True, +) +res = trainer.test(routine, datamodule=dm) diff --git a/auto_tutorial_source/Classification/tutorial_ood_detection.py b/auto_tutorial_source/Classification/tutorial_ood_detection.py index 119baa0d..05f914df 100644 --- a/auto_tutorial_source/Classification/tutorial_ood_detection.py +++ b/auto_tutorial_source/Classification/tutorial_ood_detection.py @@ -1,150 +1,224 @@ -# ruff: noqa: E402, D212, D415 """ -Out-of-distribution detection with TorchUncertainty -=================================================== +Simple Ood Evaluation +================================================ + + +In this tutorial, we’ll demonstrate how to perform out-of-distribution (OOD) evaluation using TorchUncertainty’s datamodules and routines. You’ll learn to: + +1. **Set up a CIFAR-100 datamodule** that automatically handles in-distribution, near-OOD, and far-OOD splits. +2. **Run the `ClassificationRoutine`** to compute both in-distribution accuracy and OOD metrics (AUROC, AUPR, FPR95). +3. **Plug in your own OOD datasets** for fully custom evaluation. + +Foreword on Out-of-Distribution Detection +----------------------------------------- + +Out-of-Distribution (OOD) detection measures a model’s ability to recognize inputs that differ from its training distribution. TorchUncertainty integrates common OOD metrics directly into the Lightning test loop, including: + +- **AUROC** (Area Under the ROC Curve) +- **AUPR** (Area Under the Precision-Recall Curve) +- **FPR95** (False Positive Rate at 95% True Positive Rate) + +With just a few lines of code you can compare in-distribution performance to OOD detection performance under both “near” and “far” shifts. Per default, TorchUncertainty uses the +popular OpenOOD library to define the near and far OOD datasets and splits. You can also use your own datasets by passing them to the datamodule. + +Supported Datamodules and Default OOD Splits +-------------------------------------------- + +.. list-table:: Datamodules & Default OOD Splits + :header-rows: 1 + :widths: 20 15 20 20 + + * - **Datamodule** + - **In-Domain** + - **Default Near-OOD (Hard)** + - **Default Far-OOD (Easy)** + * - ``CIFAR10DataModule`` + - CIFAR-10 + - CIFAR-100, Tiny ImageNet + - MNIST, SVHN, Textures, Places365 + * - ``CIFAR100DataModule`` + - CIFAR-100 + - CIFAR-10, Tiny ImageNet + - MNIST, SVHN, Textures, Places365 + * - ``ImageNetDataModule`` + - ImageNet-1K + - SSB-hard, NINCO + - iNaturalist, Textures, OpenImage-O + * - ``ImageNet200DataModule`` + - ImageNet200 + - SSB-hard, NINCO + - iNaturalist, Textures, OpenImage-O + +Supported OOD Criteria +---------------------- + +.. list-table:: Supported OOD Criteria + :header-rows: 1 + :widths: 15 50 + + * - **Criterion** + - **Original Reference (Year, Venue)** + * - ``msp`` + - Hendrycks & Gimpel, A Baseline for Detecting Misclassified and Out-of-Distribution Examples in Neural Networks `ICLR Workshop 2017 `_. + * - ``Maxlogit`` + - / + * - ``energy`` + - Liu et al., Energy-based Out-of-Distribution Detection `NeurIPS 2020 `_. + * - ``odin`` + - Liang, Li & Srikant, Enhancing The Reliability of Out-of-Distribution Image Detection in Neural Networks `ICML 2018 `_. + * - ``entropy`` + - / + * - ``mutual_information`` + - / + * - ``variation_ratio`` + - / + * - ``scale`` + - Scaling Out-of-Distribution Detection for Real-World Settings Hendrycks et al. `ICML 2022 `_. + * - ``ash`` + - AASH: Extremely Simple Activation Shaping for OOD Detection, Djurisic et al. `ICLR 2023 `_. + * - ``react`` + - ReAct: Out-of-distribution Detection with Rectified Activations, Sun et al. `NeurIPS 2021 `_. + * - ``adascale_a`` + - AdaSCALE: Adaptive Scaling for OOD Detection `Regmi et al. `_. + * - ``vim`` + - ViM: Out-of-Distribution with Virtual-Logit Matching, Wang et al. `CVPR 2022 `_. + * - ``knn`` + - Out-of-Distribution Detection with Deep Nearest Neighbors, Sun et al. `ICML 2022 `_. + * - ``gen`` + - GEN: Generalized ENtropy Score for OOD Detection, Liu et al. `CVPR 2023 `_. + * - ``nnguide`` + - NNGuide: Nearest-Neighbor Guidance for OOD Detection, Park et al. `ICCV 2023 `_. + +.. note:: + + - All of these criteria can be passed as the `ood_criterion` argument to + `ClassificationRoutine`. + - Methods marked “ensemble-only” will require multiple stochastic passes. + + + +.. note:: + + - **Near-OOD** splits are semantically similar to the in-domain data. + - **Far-OOD** splits come from more distant distributions (e.g., ImageNet variants). + - Override defaults by passing your own ``near_ood_datasets`` / ``far_ood_datasets``. + + +1. Loading the utilities +~~~~~~~~~~~~~~~~~~~~~~~~ + +To eval ood using TorchUncertainty, we have to load the following: + +- the model:ResNet18_32x32 trained on in-domain data cifar100 +- the classification routine from torch_uncertainty.routines +- the datamodule that handles dataloaders: CIFAR100DataModule from torch_uncertainty.datamodules. +""" -This tutorial demonstrates how to perform OOD detection using -TorchUncertainty's ClassificationRoutine with a ResNet18 model trained on CIFAR-10, -evaluating its performance with SVHN as the OOD dataset. +# %% +from pathlib import Path -We will: +# %% +# 2. Load the trained model +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# In this tutorial we will be loading a pretrained model, but you can also train your own using the same classification routine and still get ood related metrics at test phase. -- Set up the CIFAR-10 datamodule. -- Initialize and shortly train a ResNet18 model using the ClassificationRoutine. -- Evaluate the model's performance on both in-distribution and out-of-distribution data. -- Analyze uncertainty metrics for OOD detection. -""" -# %% -# Imports and Setup -# ------------------ -# -# First, we need to import the necessary libraries and set up our environment. -# This includes importing PyTorch, TorchUncertainty components, and TorchUncertainty's Trainer (built on top of Lightning's), -# as well as two criteria for OOD detection, the maximum softmax probability [1] and the Max Logit [2]. -from torch import nn, optim +import torch +from torch_uncertainty.models.resnet import resnet +from huggingface_hub import hf_hub_download -from torch_uncertainty import TUTrainer -from torch_uncertainty.datamodules import CIFAR10DataModule -from torch_uncertainty.models.classification.resnet import resnet -from torch_uncertainty.ood_criteria import MaxLogitCriterion, MaxSoftmaxCriterion -from torch_uncertainty.routines.classification import ClassificationRoutine +net = resnet(in_channels=3, arch=18, num_classes=100, style="cifar", conv_bias=False) -# %% -# DataModule Setup -# ---------------- -# -# TorchUncertainty provides convenient DataModules for standard datasets like CIFAR-10. -# DataModules handle data loading, preprocessing, and batching, simplifying the data pipeline. Each datamodule -# also include the corresponding out-of-distribution and distribution shift datasets, which are then used by the routine. -# For CIFAR-10, the corresponding OOD-detection dataset is SVHN as used in the community. -# To enable OOD evaluation, activate the `eval_ood` flag as done below. +# load the model +path = hf_hub_download(repo_id="torch-uncertainty/resnet18_c100", filename="resnet18_c100.ckpt") +net.load_state_dict(torch.load(path)) + +net.cuda() +net.eval() -datamodule = CIFAR10DataModule(root="./data", batch_size=512, num_workers=8, eval_ood=True) # %% -# Model Initialization -# -------------------- +# 3. Defining the necessary datamodules +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# We use the ResNet18 architecture, a widely adopted convolutional neural network known for its deep residual learning capabilities. -# The model is initialized with 10 output classes corresponding to the CIFAR-10 dataset categories. When training on CIFAR, do not forget to -# set the style of the resnet to CIFAR, otherwise it will lose more information in the first convolution. +# In the following, we instantiate our trainer, define the root of the datasets and the logs. +# We also create the datamodule that handles the cifar100 dataset, dataloaders and transforms. +# Datamodules can also handle OOD detection by setting the eval_ood parameter to True. + +from torch_uncertainty.datamodules import CIFAR100DataModule +from torch_uncertainty.routines import ClassificationRoutine +import torch.nn as nn +from pathlib import Path +from torch_uncertainty import TUTrainer -# Initialize the ResNet18 model -model = resnet(arch=18, in_channels=3, num_classes=10, style="cifar", conv_bias=False) -# %% -# Define the Classification Routine -# --------------------------------- -# -# The `ClassificationRoutine` is one of the most crucial building blocks in TorchUncertainty. -# It streamlines the training and evaluation processes. -# It integrates the model, loss function, and optimizer into a cohesive routine compatible with PyTorch Lightning's Trainer. -# This abstraction simplifies the implementation of standard training loops and evaluation protocols. -# To come back to what matters in this tutorial, the routine also handles OOD detection. To enable it, -# just activate the `eval_ood` flag. Note that you can also evaluate the distribution-shift performance -# of the model at the same time by also setting `eval_shift` to True. +root = Path("data1") +datamodule = CIFAR100DataModule(root=root, batch_size=200, eval_ood=True, eval_shift=True) +trainer = TUTrainer(accelerator="gpu", enable_progress_bar=True) -# Loss function -criterion = nn.CrossEntropyLoss() -# Optimizer -optimizer = optim.Adam(model.parameters(), lr=0.001) +# %% +# 4. Define the classification routine and launch the test +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Define the classification routine for evaluation. We use the CrossEntropyLoss +# as the loss function since we are working on a classification task. +# The routine is configured to handle OOD detection and distributional shifts using the specified model, loss function, and evaluation criteria. -# Initialize the ClassificationRoutine, you could replace MaxSoftmaxCriterion by "msp" routine = ClassificationRoutine( - model=model, - num_classes=10, - loss=criterion, - optim_recipe=optimizer, + num_classes=datamodule.num_classes, eval_ood=True, - ood_criterion=MaxSoftmaxCriterion, + model=net, + loss=nn.CrossEntropyLoss(), + eval_shift=True, + ood_criterion="ash", ) -# %% -# Test the Training of the Model -# ------------------------------ -# -# With the routine defined, we can now set up the Trainer and commence training. -# The Trainer handles the training loop, including epoch management, logging, and checkpointing. -# We specify the maximum number of epochs, the precision and the device to be used. To reduce the tutorial building time, -# we will train for a single epoch and load a model from `TorchUncertainty's HuggingFace `_. - -# Initialize the TUTrainer -trainer = TUTrainer( - max_epochs=1, precision="16-mixed", accelerator="cuda", devices=1, enable_progress_bar=False -) +# Perform testing using the defined routine and datamodule. +results = trainer.test(model=routine, datamodule=datamodule) -# Train the model for 1 epoch using the CIFAR-10 DataModule -trainer.fit(routine, datamodule=datamodule) # %% -# Load the model from HuggingFace -# ------------------------------- -# -# We simply download a ResNet-18 trained on CIFAR-10 from `TorchUncertainty's HuggingFace `_ and load it with -# the `load_from_checkpoint` method. +# Here, we show the various test metrics along with the ood eval metrics, auroc,aupr and fpr95 on Near and far ood datasets defined per defualt according to OpenOOD splits (link to library) -import torch -from huggingface_hub import hf_hub_download -path = hf_hub_download( - repo_id="torch-uncertainty/resnet18_c10", - filename="resnet18_c10.ckpt", +# %% +# 5. Defining custom ood datasets +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# If you don't want to use the open ood datasets or dataset splits, you can pass your own datasets in a list to near_ood_datasets or far_ood_datasets datamodule arguments +# and use them for ood evaluation but make sure they inherit from the +# Dataset class from torch.utils.data, below is an example of such a case. + +from torchvision.datasets import CIFAR10, MNIST +from torchvision.transforms import v2 + + +test_transform = v2.Compose( + [ + v2.ToImage(), + v2.Resize(32), + v2.CenterCrop(32), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.5071, 0.4867, 0.4408)), + ] ) -state_dict = torch.load(path, map_location="cpu", weights_only=True) -routine.model.load_state_dict(state_dict) -# %% -# Evaluating on In-Distribution and Out-of-distribution Data -# ---------------------------------------------------------- -# -# Now that the model is trained, we can evaluate its performance on the original in-distribution test set, -# as well as the OOD set. Typing the next line will automatically compute the in-distribution and OOD detection metrics. +custom_dataset1 = CIFAR10(root=root, train=False, download=True, transform=test_transform) +custom_dataset2 = MNIST(root=root, train=False, download=True, transform=test_transform) -# Evaluate the model on the CIFAR-10 (IID) and SVHN (OOD) test sets -results = trainer.test(routine, datamodule=datamodule) +datamodule = CIFAR100DataModule( + root=root, + batch_size=200, + eval_ood=True, + eval_shift=True, + near_ood_datasets=[custom_dataset1], + far_ood_datasets=[custom_dataset2], +) -# %% -# Changing the OOD Criterion -# -------------------------- -# -# The previous metrics for Out-of-distribution detection have been computed using the maximum softmax probability score [1], -# which corresponds to the likelihood of the prediction. We could use other scores such as the maximum logit [2]. To do this, -# just change the routine's `ood_criterion` and perform a second test. -routine.ood_criterion = MaxLogitCriterion() +# Perform testing using the CUSTOM defined ood datasets. +results = trainer.test(model=routine, datamodule=datamodule) -results = trainer.test(routine, datamodule=datamodule) # %% -# Note that you could create your own class if you want to implement a custom OOD detection score. When changing the -# Out-of-distribution criterion, all the In-distribution metrics remain the same. The only values that change -# are those of the regrouped in the OOD Detection category. Here we see that the AUPR, AUROC and FPR95 are worse using the maximum -# logit score compared to the maximum softmax probability but it could depend on the model you are using. -# # References # ---------- -# -# [1] Hendrycks, D., & Gimpel, K. (2016). A baseline for detecting misclassified and out-of-distribution examples in neural networks. In ICLR 2017. -# -# [2] Hendrycks, D., Basart, S., Mazeika, M., Zou, A., Kwon, J., Mostajabi, M., ... & Song, D. (2019). Scaling out-of-distribution detection for real-world settings. In ICML 2022. +# - **OpenOOD:** Jingyang Zhang & al. (`Neurips 2025 `_). OpenOOD v1.5: Enhanced Benchmark for Out-of-Distribution Detection. diff --git a/auto_tutorial_source/Classification/tutorial_vit.py b/auto_tutorial_source/Classification/tutorial_vit.py new file mode 100644 index 00000000..e42c9e46 --- /dev/null +++ b/auto_tutorial_source/Classification/tutorial_vit.py @@ -0,0 +1,245 @@ +""" +ViT baslines with torch-uncertainty on imagenet1k +=============================================== + +This tutorial is about using torch-uncertainty to benchmark a ViT model on imagenet1k with various robustness metricis +and apply easily a postprocess step on top either of the single model or deep ensemble. + +Dataset +------- + +In this tutorial we will use imagenet1k dataset available directly through torch uncertainty a long with various far/near ood datasets +also handled automatically by torch-uncertainty. + + +1. Load the a single vit model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +""" + +# %% +import torch +import torch.nn as nn +from torchvision.models import vit_b_16 +from huggingface_hub import hf_hub_download + + +def load_model_from_hf(repo_id: str, filename: str, device: str = "cpu", revision: str = "main"): + ckpt_path = hf_hub_download( + repo_id=repo_id, filename=filename, repo_type="model", revision=revision + ) + ckpt = torch.load(ckpt_path, map_location="cpu") + state = ckpt.get("state_dict", ckpt) + + new_state = {} + for k, v in state.items(): + name = k[len("model.") :] if k.startswith("model.") else k + new_state[name] = v + + renamed = {} + for k, v in new_state.items(): + if k == "heads.weight": + renamed["heads.head.weight"] = v + elif k == "heads.bias": + renamed["heads.head.bias"] = v + else: + renamed[k] = v + + model = vit_b_16(weights=None, num_classes=1000, image_size=224) + model.load_state_dict(renamed, strict=True) + + model.eval().to(device) + return model + + +model1 = load_model_from_hf( + repo_id="torch-uncertainty/vit-b-16-im1k", + filename="model1.ckpt", + device="cpu", +) + + +# %% +# 2. Benchmark the single model +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# We define first the imagnet1k datamodule then run the classification routine as follows. + +from torch_uncertainty.routines import ClassificationRoutine +from torch_uncertainty import TUTrainer +from torch_uncertainty.datamodules import ImageNetDataModule + + +path = "./data" + +dm = ImageNetDataModule( + root=path, + batch_size=512, + num_workers=4, + pin_memory=True, + interpolation="bicubic", + eval_ood=True, +) + +trainer = TUTrainer(accelerator="gpu", enable_progress_bar=True, devices=1) + +routine = ClassificationRoutine( + num_classes=1000, + model=model1, + loss=nn.CrossEntropyLoss(), + eval_ood=True, +) +res = trainer.test(routine, datamodule=dm) + + +# %% +# 3. Apply a postprocess step on top of the single model +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Here we will be applying for example a temperature scaling postprocess step +# on top of the single model but torch-uncertainty supports many other postprocess please refer to the documentation. + +from torch_uncertainty.post_processing import TemperatureScaler + +dm.setup("fit") + +scaler1 = TemperatureScaler(model=model1, device="cuda") +scaler1.cuda() +scaler1.fit(dataloader=dm.postprocess_dataloader()) +print(scaler1.temperature[0]) + + +routine = ClassificationRoutine( + num_classes=1000, + model=scaler1, + loss=nn.CrossEntropyLoss(), + eval_ood=True, +) +res = trainer.test(routine, datamodule=dm) + + +# %% +# 4. Load and benchmark a deep ensemble of ViT models +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Let us load the remaining models of the deep ensemble and then benchmark them easily with torch unceratinty. + +model2 = load_model_from_hf( + repo_id="torch-uncertainty/vit-b-16-im1k", + filename="model2.ckpt", + device="cpu", +) + +model3 = load_model_from_hf( + repo_id="torch-uncertainty/vit-b-16-im1k", + filename="model3.ckpt", + device="cpu", +) + + +from torch_uncertainty.models import deep_ensembles + +deep = deep_ensembles([model1, model2, model3]) + + +routine = ClassificationRoutine( + num_classes=1000, + model=deep, + loss=nn.CrossEntropyLoss(), + eval_ood=True, +) +res = trainer.test(routine, datamodule=dm) + +# %% +# Next, let us also apply a temperature scaling postprocess step on top of the deep ensemble. + +dm.setup("fit") + +scaler2 = TemperatureScaler(model=deep, device="cuda") +scaler2.cuda() +scaler2.fit(dataloader=dm.postprocess_dataloader()) +print(scaler2.temperature[0]) + +routine = ClassificationRoutine( + num_classes=1000, + model=scaler2, + loss=nn.CrossEntropyLoss(), + eval_ood=True, +) +res = trainer.test(routine, datamodule=dm) + +# %% +# 5. Load and benchmark packed ensemble of ViT model +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Let us load the packed ensemble vit and benchmark it with torch unceratinty. + +from torch_uncertainty.models.classification.vit import PackedVit + + +def load_packedvit_from_hf( + repo_id: str, + filename: str, + device: str = "cpu", + revision: str = "main", +): + ckpt_path = hf_hub_download( + repo_id=repo_id, + filename=filename, + repo_type="model", + revision=revision, + ) + + ckpt = torch.load(ckpt_path, map_location="cpu") + state = ckpt.get("state_dict", ckpt) + + clean_state = { + k.replace("model.", "").replace("routine.model.", ""): v for k, v in state.items() + } + + model = PackedVit( + image_size=224, + patch_size=16, + num_layers=12, + num_heads=12, + hidden_dim=774, + mlp_dim=3072, + num_classes=1000, + num_estimators=3, + alpha=2, + ) + + model.load_state_dict(clean_state, strict=True) + model.eval().to(device) + + return model + + +packed = load_packedvit_from_hf( + repo_id="torch-uncertainty/vit-b-16-im1k", + filename="packed.ckpt", + device="cpu", +) + +routine = ClassificationRoutine( + num_classes=1000, + model=packed, + loss=nn.CrossEntropyLoss(), + eval_ood=True, +) +res = trainer.test(routine, datamodule=dm) + + +# %% +# Next, let us also apply a temperature scaling postprocess step on top of the packedvit. + +dm.setup("fit") + +scaler3 = TemperatureScaler(model=packed, device="cuda") +scaler3.cuda() +scaler3.fit(dataloader=dm.postprocess_dataloader()) +print(scaler3.temperature[0]) + +routine = ClassificationRoutine( + num_classes=1000, + model=scaler3, + loss=nn.CrossEntropyLoss(), + eval_ood=True, +) +res = trainer.test(routine, datamodule=dm) diff --git a/docs/source/api.rst b/docs/source/api.rst index fcb61256..ec12799a 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -430,7 +430,7 @@ Scaling Methods OOD Scores ----------------------- -.. currentmodule:: torch_uncertainty.ood_criteria +.. currentmodule:: torch_uncertainty.ood.ood_criteria .. autosummary:: :toctree: generated/ diff --git a/pyproject.toml b/pyproject.toml index 696df0f5..b0303fa9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,13 @@ dependencies = [ "torchvision>=0.16", "einops", "seaborn", + "scikit-learn", + "scipy", + "faiss-cpu", + "statsmodels", + "transformers", + "pyarrow==20.0.0", + "datasets", ] [project.optional-dependencies] diff --git a/tests/_dummies/__init__.py b/tests/_dummies/__init__.py index 4f2df70d..b06ac0e8 100644 --- a/tests/_dummies/__init__.py +++ b/tests/_dummies/__init__.py @@ -17,5 +17,5 @@ DummyRegressionDataset, DummySegmentationDataset, ) -from .model import dummy_model +from .model import dummy_model, dummy_ood_model from .transform import DummyTransform diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 8e801274..e235d043 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -3,7 +3,7 @@ from torch import nn from torch_uncertainty.models import EMA, SWA, deep_ensembles -from torch_uncertainty.ood_criteria import TUOODCriterion +from torch_uncertainty.ood.ood_criteria import TUOODCriterion from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.post_processing import TemperatureScaler from torch_uncertainty.routines import ( diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index 3dad261a..05618037 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -18,9 +18,9 @@ class DummyClassificationDataModule(TUDataModule): - num_channels = 1 + num_channels: int = 1 image_size: int = 4 - training_task = "classification" + training_task: str = "classification" def __init__( self, @@ -34,6 +34,8 @@ def __init__( pin_memory: bool = True, persistent_workers: bool = True, num_images: int = 2, + near_ood_datasets: list | None = None, + far_ood_datasets: list | None = None, ) -> None: super().__init__( root=root, @@ -51,17 +53,25 @@ def __init__( self.num_classes = num_classes self.num_images = num_images + # Dataset classes self.dataset = DummyClassificationDataset self.ood_dataset = DummyClassificationDataset self.shift_dataset = DummyClassificationDataset + # Custom near/far OOD dataset classes + self.near_ood_datasets = near_ood_datasets or [] + self.far_ood_datasets = far_ood_datasets or [] + + # Simple tensor transforms self.train_transform = v2.ToTensor() self.test_transform = v2.ToTensor() def prepare_data(self) -> None: + # No external data to download for dummy pass def setup(self, stage: str | None = None) -> None: + # Training / validation setup if stage == "fit" or stage is None: self.train = self.dataset( self.root, @@ -79,7 +89,10 @@ def setup(self, stage: str | None = None) -> None: transform=self.test_transform, num_images=self.num_images, ) - elif stage == "test": + + # Test / OOD / shift setup + if stage == "test" or stage is None: + # Main test set self.test = self.dataset( self.root, num_channels=self.num_channels, @@ -88,33 +101,100 @@ def setup(self, stage: str | None = None) -> None: transform=self.test_transform, num_images=self.num_images, ) - self.ood = self.ood_dataset( - self.root, - num_channels=self.num_channels, - num_classes=self.num_classes, - image_size=self.image_size, - transform=self.test_transform, - num_images=self.num_images, - ) - self.shift = self.shift_dataset( - self.root, - num_channels=self.num_channels, - num_classes=self.num_classes, - image_size=self.image_size, - transform=self.test_transform, - num_images=self.num_images, - ) - self.shift.shift_severity = 1 - def test_dataloader(self) -> DataLoader | list[DataLoader]: - dataloader = [self._data_loader(self.test, training=False, shuffle=False)] + if self.eval_ood: + # Validation OOD (equivalent to val_ood) + self.val_ood = self.ood_dataset( + self.root, + num_channels=self.num_channels, + num_classes=self.num_classes, + image_size=self.image_size, + transform=self.test_transform, + num_images=self.num_images, + ) + # Near OOD + if self.near_ood_datasets: + self.near_oods = [ + ds( + self.root, + num_channels=self.num_channels, + num_classes=self.num_classes, + image_size=self.image_size, + transform=self.test_transform, + num_images=self.num_images, + ) + for ds in self.near_ood_datasets + ] + else: + # default single near OOD + self.near_oods = [ + self.ood_dataset( + self.root, + num_channels=self.num_channels, + num_classes=self.num_classes, + image_size=self.image_size, + transform=self.test_transform, + num_images=self.num_images, + ) + ] + + # Far OOD + if self.far_ood_datasets: + self.far_oods = [ + ds( + self.root, + num_channels=self.num_channels, + num_classes=self.num_classes, + image_size=self.image_size, + transform=self.test_transform, + num_images=self.num_images, + ) + for ds in self.far_ood_datasets + ] + else: + # default single far OOD + self.far_oods = [ + self.ood_dataset( + self.root, + num_channels=self.num_channels, + num_classes=self.num_classes, + image_size=self.image_size, + transform=self.test_transform, + num_images=self.num_images, + ) + ] + + # Shifted dataset + if self.eval_shift: + self.shift = self.shift_dataset( + self.root, + num_channels=self.num_channels, + num_classes=self.num_classes, + image_size=self.image_size, + transform=self.test_transform, + num_images=self.num_images, + ) + self.shift.shift_severity = 1 + + def train_dataloader(self) -> DataLoader: + return self._data_loader(self.train, training=True, shuffle=True) + + def val_dataloader(self) -> DataLoader: + return self._data_loader(self.val, training=False) + + def test_dataloader(self) -> list[DataLoader]: + loaders = [self._data_loader(self.test, training=False)] if self.eval_ood: - dataloader.append(self._data_loader(self.get_ood_set(), training=False, shuffle=False)) - if self.eval_shift: - dataloader.append( - self._data_loader(self.get_shift_set(), training=False, shuffle=False) + loaders.append(self._data_loader(self.val_ood, training=False, shuffle=False)) + loaders.extend( + self._data_loader(ds, training=False, shuffle=False) for ds in self.near_oods + ) + loaders.extend( + self._data_loader(ds, training=False, shuffle=False) for ds in self.far_oods ) - return dataloader + if self.eval_shift: + loaders.append(self._data_loader(self.shift, training=False, shuffle=False)) + return loaders def _get_train_data(self) -> ArrayLike: return self.train.data @@ -122,6 +202,33 @@ def _get_train_data(self) -> ArrayLike: def _get_train_targets(self) -> ArrayLike: return np.array(self.train.targets) + def get_indices(self) -> dict[str, list[int]]: + idx = 0 + indices: dict[str, list[int]] = {} + # Main test + indices["test"] = [idx] + idx += 1 + # OOD + if self.eval_ood: + indices["val_ood"] = [idx] + idx += 1 + n_near = len(self.near_oods) + indices["near_oods"] = list(range(idx, idx + n_near)) + idx += n_near + n_far = len(self.far_oods) + indices["far_oods"] = list(range(idx, idx + n_far)) + idx += n_far + else: + indices["val_ood"] = [] + indices["near_oods"] = [] + indices["far_oods"] = [] + # Shift + if self.eval_shift: + indices["shift"] = [idx] + else: + indices["shift"] = [] + return indices + class DummyRegressionDataModule(TUDataModule): in_features = 4 diff --git a/tests/_dummies/dataset.py b/tests/_dummies/dataset.py index 0ef4ba33..628c9771 100644 --- a/tests/_dummies/dataset.py +++ b/tests/_dummies/dataset.py @@ -44,6 +44,7 @@ def __init__( self.train = train # training set or test set self.transform = transform self.target_transform = target_transform + self.dataset_name = "dummy" self.data: Any = [] self.targets = [] diff --git a/tests/_dummies/model.py b/tests/_dummies/model.py index 45bcae38..57884cd2 100644 --- a/tests/_dummies/model.py +++ b/tests/_dummies/model.py @@ -162,3 +162,44 @@ def dummy_segmentation_model( image_size=image_size, dist_family=dist_family, ) + + +class dummy_ood_model(nn.Module): # noqa: N801 + def __init__(self, in_channels=3, feat_dim=4096, num_classes=3): + super().__init__() + self.feat = nn.Sequential( + nn.Conv2d(in_channels, 8, 3, 1, 1), + nn.ReLU(), + nn.AdaptiveAvgPool2d(1), + nn.Flatten(), + nn.Linear(8, feat_dim), + nn.ReLU(), + ) + self.norm = nn.LayerNorm(feat_dim) + self.fc = nn.Linear(feat_dim, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.uniform_(m.weight, -0.02, 0.02) + nn.init.zeros_(m.bias) + if isinstance(m, nn.Conv2d): + nn.init.uniform_(m.weight, -0.02, 0.02) + nn.init.zeros_(m.bias) + + self.feature_size = feat_dim + + def forward(self, x, return_feature=False, return_feature_list=False): + f = self.feat(x) + f = torch.tanh(self.norm(f)) + logits = self.fc(f) + if return_feature: + return logits, f + return logits + + def get_fc(self): + w = self.fc.weight.detach().cpu().numpy() + b = self.fc.bias.detach().cpu().numpy() + return w, b + + def get_fc_layer(self): + return self.fc diff --git a/tests/datamodules/classification/test_cifar10.py b/tests/datamodules/classification/test_cifar10.py index 897ae91d..014c66b1 100644 --- a/tests/datamodules/classification/test_cifar10.py +++ b/tests/datamodules/classification/test_cifar10.py @@ -32,7 +32,6 @@ def test_cifar10_main(self) -> None: assert isinstance(dm.train_transform.transforms[2], Cutout) dm.dataset = DummyClassificationDataset - dm.ood_dataset = DummyClassificationDataset dm.shift_dataset = DummyClassificationDataset dm.prepare_data() @@ -50,7 +49,6 @@ def test_cifar10_main(self) -> None: dm.val_dataloader() dm.test_dataloader() - dm.eval_ood = True dm.eval_shift = True dm.prepare_data() dm.setup("test") @@ -74,7 +72,6 @@ def test_cifar10_main(self) -> None: randaugment=True, ) dm.dataset = DummyClassificationDataset - dm.ood_dataset = DummyClassificationDataset dm.setup() dm.train_dataloader() @@ -84,20 +81,19 @@ def test_cifar10_main(self) -> None: num_dataloaders=1, val_split=0.1, num_tta=64, - eval_ood=True, eval_shift=True, ) dm.dataset = DummyClassificationDataset - dm.ood_dataset = DummyClassificationDataset dm.shift_dataset = DummyClassificationDataset dm.setup() dm.get_val_set() dm.get_test_set() - dm.get_ood_set() dm.get_shift_set() + idx = dm.get_indices() + assert idx["shift"] == [1] with pytest.raises(ValueError): - dm = CIFAR10DataModule( + CIFAR10DataModule( root="./data/", batch_size=128, cutout=8, @@ -115,13 +111,13 @@ def test_cifar10_main(self) -> None: dm.setup("fit") with pytest.raises(ValueError, match="Test set "): - dm = CIFAR10DataModule( + CIFAR10DataModule( root="./data/", batch_size=128, test_alt="x", ) - dm = CIFAR10DataModule( + CIFAR10DataModule( root="./data/", batch_size=128, num_dataloaders=2, @@ -150,3 +146,162 @@ def test_cifar10_cv(self) -> None: num_images=20, ) dm.make_cross_val_splits(2, 1) + + def test_ood_defaults_and_get_indices(self, monkeypatch) -> None: + dm = CIFAR10DataModule( + root="./data/", + batch_size=16, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + eval_ood=True, + ) + + dm.dataset = DummyClassificationDataset + dm.shift_dataset = DummyClassificationDataset + + def _mock_get_ood(**_): + test_ood = DummyClassificationDataset( + root="./data/", train=False, download=False, transform=nn.Identity(), num_images=5 + ) + val_ood = DummyClassificationDataset( + root="./data/", train=False, download=False, transform=nn.Identity(), num_images=5 + ) + near_default = { + "example1": DummyClassificationDataset( + root="./data/", + train=False, + download=False, + transform=nn.Identity(), + num_images=5, + ), + "example2": DummyClassificationDataset( + root="./data/", + train=False, + download=False, + transform=nn.Identity(), + num_images=5, + ), + } + far_default = { + "example3": DummyClassificationDataset( + root="./data/", + train=False, + download=False, + transform=nn.Identity(), + num_images=5, + ), + "example4": DummyClassificationDataset( + root="./data/", + train=False, + download=False, + transform=nn.Identity(), + num_images=5, + ), + "example5": DummyClassificationDataset( + root="./data/", + train=False, + download=False, + transform=nn.Identity(), + num_images=5, + ), + } + return test_ood, val_ood, near_default, far_default + + monkeypatch.setattr( + "torch_uncertainty.datamodules.classification.cifar10.get_ood_datasets", + _mock_get_ood, + ) + dm.setup("test") + + assert hasattr(dm, "near_oods") + assert len(dm.near_oods) == 2 + assert hasattr(dm, "far_oods") + assert len(dm.far_oods) == 3 + + for ds in [dm.val_ood, *dm.near_oods, *dm.far_oods]: + assert hasattr(ds, "dataset_name") + assert ds.dataset_name in {"dummy", ds.__class__.__name__.lower()} + + assert dm.near_ood_names == [ds.dataset_name for ds in dm.near_oods] + assert dm.far_ood_names == [ds.dataset_name for ds in dm.far_oods] + + loaders = dm.test_dataloader() + expected = 1 + 1 + 1 + len(dm.near_oods) + len(dm.far_oods) + assert len(loaders) == expected + + idx = dm.get_indices() + assert idx["test"] == [0] + assert idx["test_ood"] == [1] + assert idx["val_ood"] == [2] + assert idx["near_oods"] == list(range(3, 3 + len(dm.near_oods))) + assert idx["far_oods"] == list( + range(3 + len(dm.near_oods), 3 + len(dm.near_oods) + len(dm.far_oods)) + ) + assert idx["shift"] == [] + + def test_user_supplied_near_far_ood_typecheck_and_override(self, monkeypatch, tmp_path): + """Covers user-provided OOD lists (type checks and overrides).""" + # invalid near + dm_bad_near = CIFAR10DataModule( + root=tmp_path, + batch_size=16, + eval_ood=True, + near_ood_datasets=[123], + ) + dm_bad_near.dataset = DummyClassificationDataset + dm_bad_near.shift_dataset = DummyClassificationDataset + with pytest.raises(TypeError, match="near_ood_datasets must be Dataset objects"): + dm_bad_near.setup("test") + + # invalid far + dm_bad_far = CIFAR10DataModule( + root=tmp_path, + batch_size=16, + eval_ood=True, + far_ood_datasets=[object()], + ) + dm_bad_far.dataset = DummyClassificationDataset + dm_bad_far.shift_dataset = DummyClassificationDataset + with pytest.raises(TypeError, match="far_ood_datasets must be Dataset objects"): + dm_bad_far.setup("test") + + # valid override + near_custom = [DummyClassificationDataset(root="./data/", num_images=2)] + far_custom = [DummyClassificationDataset(root="./data/", num_images=1)] + + dm = CIFAR10DataModule( + root=tmp_path, + batch_size=16, + eval_ood=True, + near_ood_datasets=near_custom, + far_ood_datasets=far_custom, + ) + dm.dataset = DummyClassificationDataset + dm.shift_dataset = DummyClassificationDataset + + def _fake_get_ood(**_): + return ( + DummyClassificationDataset(root="./data/", num_images=1), + DummyClassificationDataset(root="./data/", num_images=1), + {"near_default": DummyClassificationDataset(root="./data/", num_images=1)}, + {"far_default": DummyClassificationDataset(root="./data/", num_images=1)}, + ) + + monkeypatch.setattr( + "torch_uncertainty.datamodules.classification.cifar10.get_ood_datasets", + _fake_get_ood, + ) + dm.setup("test") + + assert dm.near_oods is near_custom + assert dm.far_oods is far_custom + for ds in [dm.val_ood, *dm.near_oods, *dm.far_oods]: + assert hasattr(ds, "dataset_name") + + idx = dm.get_indices() + assert idx["test"] == [0] + assert idx["test_ood"] == [1] + assert idx["val_ood"] == [2] + assert idx["near_oods"] == [3] + assert idx["far_oods"] == [4] + assert idx["shift"] == [] diff --git a/tests/datamodules/classification/test_cifar100.py b/tests/datamodules/classification/test_cifar100.py index dcfe1f99..56ed0a02 100644 --- a/tests/datamodules/classification/test_cifar100.py +++ b/tests/datamodules/classification/test_cifar100.py @@ -33,7 +33,6 @@ def test_cifar100(self) -> None: assert isinstance(dm.train_transform.transforms[2], Cutout) dm.dataset = DummyClassificationDataset - dm.ood_dataset = DummyClassificationDataset dm.shift_dataset = DummyClassificationDataset dm.prepare_data() @@ -43,7 +42,6 @@ def test_cifar100(self) -> None: dm.val_dataloader() dm.test_dataloader() - dm.eval_ood = True dm.eval_shift = True dm.prepare_data() dm.setup("test") @@ -58,7 +56,6 @@ def test_cifar100(self) -> None: basic_augment=False, ) dm.dataset = DummyClassificationDataset - dm.ood_dataset = DummyClassificationDataset dm.shift_dataset = DummyClassificationDataset dm.setup() dm.setup("test") @@ -67,7 +64,7 @@ def test_cifar100(self) -> None: dm.setup("other") with pytest.raises(ValueError): - dm = CIFAR100DataModule( + CIFAR100DataModule( root="./data/", batch_size=128, num_dataloaders=1, @@ -75,9 +72,8 @@ def test_cifar100(self) -> None: randaugment=True, ) - dm = CIFAR100DataModule(root="./data/", batch_size=128, randaugment=True) - - dm = CIFAR100DataModule(root="./data/", batch_size=128, auto_augment="rand-m9-n2-mstd0.5") + CIFAR100DataModule(root="./data/", batch_size=128, randaugment=True) + CIFAR100DataModule(root="./data/", batch_size=128, auto_augment="rand-m9-n2-mstd0.5") def test_cifar100_cv(self) -> None: dm = CIFAR100DataModule(root="./data/", batch_size=128) @@ -99,3 +95,232 @@ def test_cifar100_cv(self) -> None: num_images=20, ) dm.make_cross_val_splits(2, 1) + + def test_ood_defaults_and_get_indices(self, monkeypatch) -> None: + dm = CIFAR100DataModule( + root="./data/", + batch_size=16, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + eval_ood=True, + eval_shift=True, + ) + + dm.dataset = DummyClassificationDataset + dm.shift_dataset = DummyClassificationDataset + + def _mock_get_ood(**_): + test_ood = DummyClassificationDataset( + root="./data/", train=False, download=False, transform=nn.Identity(), num_images=5 + ) + val_ood = DummyClassificationDataset( + root="./data/", train=False, download=False, transform=nn.Identity(), num_images=5 + ) + near_default = { + "example1": DummyClassificationDataset( + root="./data/", + train=False, + download=False, + transform=nn.Identity(), + num_images=5, + ), + "example2": DummyClassificationDataset( + root="./data/", + train=False, + download=False, + transform=nn.Identity(), + num_images=5, + ), + } + far_default = { + "example3": DummyClassificationDataset( + root="./data/", + train=False, + download=False, + transform=nn.Identity(), + num_images=5, + ), + "example4": DummyClassificationDataset( + root="./data/", + train=False, + download=False, + transform=nn.Identity(), + num_images=5, + ), + "example5": DummyClassificationDataset( + root="./data/", + train=False, + download=False, + transform=nn.Identity(), + num_images=5, + ), + } + return test_ood, val_ood, near_default, far_default + + monkeypatch.setattr( + "torch_uncertainty.datamodules.classification.cifar100.get_ood_datasets", + _mock_get_ood, + ) + dm.setup("test") + + assert hasattr(dm, "near_oods") + assert len(dm.near_oods) == 2 + + assert hasattr(dm, "far_oods") + assert len(dm.far_oods) == 3 + + for ds in [dm.val_ood, *dm.near_oods, *dm.far_oods]: + assert hasattr(ds, "dataset_name") + assert ds.dataset_name in {"dummy", ds.__class__.__name__.lower()} + + assert dm.near_ood_names == [ds.dataset_name for ds in dm.near_oods] + assert dm.far_ood_names == [ds.dataset_name for ds in dm.far_oods] + + loaders = dm.test_dataloader() + expected = 1 + 1 + 1 + len(dm.near_oods) + len(dm.far_oods) + 1 + assert len(loaders) == expected + + idx = dm.get_indices() + assert idx["test"] == [0] + assert idx["test_ood"] == [1] + assert idx["val_ood"] == [2] + assert idx["near_oods"] == list(range(3, 3 + len(dm.near_oods))) + assert idx["far_oods"] == list( + range(3 + len(dm.near_oods), 3 + len(dm.near_oods) + len(dm.far_oods)) + ) + assert idx["shift"] == [3 + len(dm.near_oods) + len(dm.far_oods)] + + def test_user_supplied_near_far_ood_typecheck_and_override(self, monkeypatch, tmp_path): + """Check that custom OOD datasets override defaults and type errors are raised.""" + dm_bad_near = CIFAR100DataModule( + root=tmp_path, + batch_size=16, + eval_ood=True, + near_ood_datasets=[123, "bad"], # invalid + ) + dm_bad_near.dataset = DummyClassificationDataset + dm_bad_near.shift_dataset = DummyClassificationDataset + with pytest.raises(TypeError, match="near_ood_datasets must be Dataset objects"): + dm_bad_near.setup("test") + + dm_bad_far = CIFAR100DataModule( + root=tmp_path, + batch_size=16, + eval_ood=True, + far_ood_datasets=[object()], # invalid + ) + dm_bad_far.dataset = DummyClassificationDataset + dm_bad_far.shift_dataset = DummyClassificationDataset + with pytest.raises(TypeError, match="far_ood_datasets must be Dataset objects"): + dm_bad_far.setup("test") + + near_custom = [DummyClassificationDataset(root="./data/", num_images=2)] + far_custom = [DummyClassificationDataset(root="./data/", num_images=1)] + + dm = CIFAR100DataModule( + root=tmp_path, + batch_size=16, + eval_ood=True, + near_ood_datasets=near_custom, + far_ood_datasets=far_custom, + ) + dm.dataset = DummyClassificationDataset + dm.shift_dataset = DummyClassificationDataset + + def _fake_get_ood(**_): + return ( + DummyClassificationDataset(root="./data/", num_images=1), + DummyClassificationDataset(root="./data/", num_images=1), + {"near_default": DummyClassificationDataset(root="./data/", num_images=1)}, + {"far_default": DummyClassificationDataset(root="./data/", num_images=1)}, + ) + + monkeypatch.setattr( + "torch_uncertainty.datamodules.classification.cifar100.get_ood_datasets", + _fake_get_ood, + ) + + dm.setup("test") + + assert dm.near_oods is near_custom + assert dm.far_oods is far_custom + for ds in [dm.val_ood, *dm.near_oods, *dm.far_oods]: + assert hasattr(ds, "dataset_name") + + idx = dm.get_indices() + assert idx["test"] == [0] + assert idx["test_ood"] == [1] + assert idx["val_ood"] == [2] + assert idx["near_oods"] == [3] + assert idx["far_oods"] == [4] + assert idx["shift"] == [] + + def test_assigns_dataset_name_when_missing(self, monkeypatch): + """If OOD datasets lack `dataset_name`, setup() should assign class-name.lower().""" + dm = CIFAR100DataModule( + root="./data/", + batch_size=16, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + eval_ood=True, + ) + dm.dataset = DummyClassificationDataset + dm.shift_dataset = DummyClassificationDataset + + class _NoNameDS: + def __init__( + self, root="./data/", train=False, download=False, transform=None, num_images=3 + ): + self.data = list(range(num_images)) + self.transform = transform + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + x = self.data[i] + return (x if self.transform is None else self.transform(x)), 0 + + def _mock_get_ood(**_): + test_ood = _NoNameDS(num_images=3) + val_ood = _NoNameDS(num_images=4) + near_default = {"nearA": _NoNameDS(num_images=5)} + far_default = {"farB": _NoNameDS(num_images=6)} + return test_ood, val_ood, near_default, far_default + + monkeypatch.setattr( + "torch_uncertainty.datamodules.classification.cifar100.get_ood_datasets", + _mock_get_ood, + ) + + dm.setup("test") + + assert hasattr(dm.val_ood, "dataset_name") + assert dm.val_ood.dataset_name == "_nonameds" + for ds in dm.near_oods + dm.far_oods: + assert hasattr(ds, "dataset_name") + assert ds.dataset_name == "_nonameds" + + assert dm.near_ood_names == [ds.dataset_name for ds in dm.near_oods] + assert dm.far_ood_names == [ds.dataset_name for ds in dm.far_oods] + + def test_get_indices_empty_when_eval_ood_false(self): + dm = CIFAR100DataModule( + root="./data/", + batch_size=16, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + eval_ood=False, + eval_shift=False, + ) + dm.dataset = DummyClassificationDataset + dm.shift_dataset = DummyClassificationDataset + dm.setup("test") + + idx = dm.get_indices() + assert idx["test"] == [0] + assert idx["test_ood"] == [] + assert idx["val_ood"] == [] + assert idx["near_oods"] == [] + assert idx["far_oods"] == [] + assert idx["shift"] == [] diff --git a/tests/datamodules/classification/test_imagenet.py b/tests/datamodules/classification/test_imagenet.py index 0cf86e65..6ea935da 100644 --- a/tests/datamodules/classification/test_imagenet.py +++ b/tests/datamodules/classification/test_imagenet.py @@ -1,8 +1,9 @@ from pathlib import Path import pytest +import torch from torch import nn -from torchvision.datasets import ImageNet +from torch.utils.data import Dataset from tests._dummies.dataset import DummyClassificationDataset from torch_uncertainty.datamodules import ImageNetDataModule @@ -11,7 +12,62 @@ class TestImageNetDataModule: """Testing the ImageNetDataModule datamodule class.""" - def test_imagenet(self) -> None: + class _TinyImgDataset(Dataset): + def __init__(self, n=3): + self.n = n + + def __len__(self): + return self.n + + def __getitem__(self, idx): + return torch.zeros(3, 224, 224), 0 + + class _DummyFileListDataset(Dataset): + def __init__(self, root, list_file, transform): + self.ds = TestImageNetDataModule._TinyImgDataset() + + def __len__(self): + return len(self.ds) + + def __getitem__(self, i): + return self.ds[i] + + class _DummyImageFolder(Dataset): + def __init__(self, root, transform=None): + self.ds = TestImageNetDataModule._TinyImgDataset() + self.transform = transform + + def __len__(self): + return len(self.ds) + + def __getitem__(self, i): + x, y = self.ds[i] + return (self.transform(x) if self.transform is not None else x), y + + @staticmethod + def _fake_download_and_extract(_name, dest_root): + return str(dest_root) + + @staticmethod + def _fake_download_and_extract_splits_from_hf(root): + return Path(root) + + @staticmethod + def _fake_get_ood_datasets(**_): + test_ood = DummyClassificationDataset(root="./data/", num_images=2) + val_ood = DummyClassificationDataset(root="./data/", num_images=2) + near_default = { + "near1": DummyClassificationDataset(root="./data/", num_images=1), + "near2": DummyClassificationDataset(root="./data/", num_images=1), + } + far_default = { + "far1": DummyClassificationDataset(root="./data/", num_images=1), + "far2": DummyClassificationDataset(root="./data/", num_images=1), + "far3": DummyClassificationDataset(root="./data/", num_images=1), + } + return test_ood, val_ood, near_default, far_default + + def test_imagenet(self, monkeypatch) -> None: dm = ImageNetDataModule( root="./data/", batch_size=128, @@ -29,30 +85,36 @@ def test_imagenet(self) -> None: ) assert isinstance(dm.train_transform, nn.Identity) assert isinstance(dm.test_transform, nn.Identity) - assert dm.dataset == ImageNet dm.dataset = DummyClassificationDataset - dm.ood_dataset = DummyClassificationDataset + + fake_root = "./data/" + mod_name = ImageNetDataModule.__module__ + + def _fake_download_and_extract(name, dest_root): # noqa: ARG001 + return str(fake_root) + + monkeypatch.setattr( + f"{mod_name}.download_and_extract_hf_dataset", + _fake_download_and_extract, + raising=True, + ) + dm.prepare_data() dm.setup() path = Path(__file__).parent.resolve() / "../../assets/dummy_indices.yaml" dm = ImageNetDataModule(root="./data/", batch_size=128, val_split=path) - dm.dataset = DummyClassificationDataset - dm.ood_dataset = DummyClassificationDataset dm.shift_dataset = DummyClassificationDataset dm.setup("fit") dm.setup("test") - dm.train_dataloader() dm.val_dataloader() dm.test_dataloader() dm.val_split = None dm.setup("fit") - dm.train_dataloader() dm.val_dataloader() dm.test_dataloader() - dm.eval_ood = True dm.eval_shift = True dm.prepare_data() dm.setup("test") @@ -72,30 +134,13 @@ def test_imagenet(self) -> None: for test_alt in ["r", "o", "a"]: dm = ImageNetDataModule(root="./data/", batch_size=128, test_alt=test_alt) - with pytest.raises(ValueError): - dm.setup() - with pytest.raises(ValueError): dm = ImageNetDataModule(root="./data/", batch_size=128, test_alt="x") - for ood_ds in ["inaturalist", "imagenet-o", "textures", "openimage-o"]: - dm = ImageNetDataModule(root="./data/", batch_size=128, ood_ds=ood_ds) - if ood_ds == "inaturalist": - dm.eval_ood = True - dm.dataset = DummyClassificationDataset - dm.ood_dataset = DummyClassificationDataset - dm.prepare_data() - dm.setup("test") - dm.test_dataloader() - - with pytest.raises(ValueError): - dm = ImageNetDataModule(root="./data/", batch_size=128, ood_ds="other") - for procedure in ["ViT", "A3"]: dm = ImageNetDataModule( root="./data/", batch_size=128, - ood_ds="svhn", procedure=procedure, ) @@ -108,3 +153,346 @@ def test_imagenet(self) -> None: dm.root = Path("./tests/testlog") with pytest.raises(FileNotFoundError): dm._verify_splits(split="test") + + def test_ood_defaults_and_get_indices(self, monkeypatch) -> None: + dm = ImageNetDataModule( + root="./data/", + batch_size=16, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + eval_ood=True, + eval_shift=True, + ) + + fake_root = "./data/" + mod_name = ImageNetDataModule.__module__ + + def _fake_download_and_extract(name, dest_root): # noqa: ARG001 + return str(fake_root) + + monkeypatch.setattr( + f"{mod_name}.download_and_extract_hf_dataset", + _fake_download_and_extract, + raising=True, + ) + + dm.dataset = DummyClassificationDataset + dm.shift_dataset = DummyClassificationDataset + + monkeypatch.setattr( + "torch_uncertainty.datamodules.classification.imagenet.get_ood_datasets", + self._fake_get_ood_datasets, + ) + dm.setup("test") + + assert hasattr(dm, "near_oods") + assert len(dm.near_oods) == 2 + + assert hasattr(dm, "far_oods") + assert len(dm.far_oods) == 3 + + for ds in [dm.val_ood, *dm.near_oods, *dm.far_oods]: + assert hasattr(ds, "dataset_name") + + assert dm.near_ood_names == [ds.dataset_name for ds in dm.near_oods] + assert dm.far_ood_names == [ds.dataset_name for ds in dm.far_oods] + + loaders = dm.test_dataloader() + expected = 1 + 1 + 1 + len(dm.near_oods) + len(dm.far_oods) + 1 + assert len(loaders) == expected + + idx = dm.get_indices() + assert idx["test"] == [0] + assert idx["test_ood"] == [1] + assert idx["val_ood"] == [2] + assert idx["near_oods"] == list(range(3, 3 + len(dm.near_oods))) + assert idx["far_oods"] == list( + range(3 + len(dm.near_oods), 3 + len(dm.near_oods) + len(dm.far_oods)) + ) + assert idx["shift"] == [3 + len(dm.near_oods) + len(dm.far_oods)] + + def test_setup_fit_rejects_test_alt(self, monkeypatch, tmp_path): + """setup('fit') must raise when test_alt is provided.""" + mod_name = ImageNetDataModule.__module__ + monkeypatch.setattr( + f"{mod_name}.download_and_extract_hf_dataset", + self._fake_download_and_extract, + raising=True, + ) + monkeypatch.setattr( + f"{mod_name}.download_and_extract_splits_from_hf", + self._fake_download_and_extract_splits_from_hf, + raising=True, + ) + monkeypatch.setattr(f"{mod_name}.ImageNetR", DummyClassificationDataset, raising=True) + + dm = ImageNetDataModule( + root=tmp_path, + batch_size=8, + test_alt="r", + basic_augment=False, + num_workers=0, + persistent_workers=False, + pin_memory=False, + ) + with pytest.raises(ValueError, match="test_alt.*not supported for training"): + dm.setup("fit") + + def test_setup_test_with_test_alt(self, monkeypatch, tmp_path): + """setup('test') with test_alt uses alt dataset constructor.""" + mod_name = ImageNetDataModule.__module__ + monkeypatch.setattr(f"{mod_name}.ImageNetR", DummyClassificationDataset, raising=True) + monkeypatch.setattr( + f"{mod_name}.download_and_extract_hf_dataset", + self._fake_download_and_extract, + raising=True, + ) + monkeypatch.setattr( + f"{mod_name}.download_and_extract_splits_from_hf", + self._fake_download_and_extract_splits_from_hf, + raising=True, + ) + + dm = ImageNetDataModule( + root=tmp_path, + batch_size=8, + test_alt="r", + basic_augment=False, + num_workers=0, + persistent_workers=False, + pin_memory=False, + ) + dm.setup("test") + assert isinstance(dm.test, DummyClassificationDataset) + + def test_guard_test_alt_with_eval_ood(self): + with pytest.raises(ValueError, match="test_alt.*not supported.*ood_eval"): + ImageNetDataModule( + root="./data/", + batch_size=8, + test_alt="r", + eval_ood=True, + ) + + def test_near_far_instances_used_and_named(self, monkeypatch, tmp_path): + mod_name = ImageNetDataModule.__module__ + monkeypatch.setattr( + f"{mod_name}.download_and_extract_hf_dataset", + self._fake_download_and_extract, + raising=True, + ) + monkeypatch.setattr( + f"{mod_name}.download_and_extract_splits_from_hf", + self._fake_download_and_extract_splits_from_hf, + raising=True, + ) + monkeypatch.setattr( + f"{mod_name}.FileListDataset", + self._DummyFileListDataset, + raising=True, + ) + monkeypatch.setattr( + "torch_uncertainty.datamodules.classification.imagenet.get_ood_datasets", + self._fake_get_ood_datasets, + ) + + class NearDS(Dataset): + def __len__(self): + return 1 + + def __getitem__(self, i): + return torch.zeros(3, 224, 224), 0 + + class FarDS(Dataset): + def __len__(self): + return 1 + + def __getitem__(self, i): + return torch.zeros(3, 224, 224), 0 + + near_list = [NearDS(), NearDS()] + far_list = [FarDS()] + + dm = ImageNetDataModule( + root=tmp_path, + batch_size=8, + eval_ood=True, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + near_ood_datasets=near_list, + far_ood_datasets=far_list, + ) + dm.setup("test") + + assert dm.near_oods is near_list + assert dm.far_oods is far_list + assert all(hasattr(ds, "dataset_name") for ds in dm.near_oods) + assert all(hasattr(ds, "dataset_name") for ds in dm.far_oods) + assert {ds.dataset_name for ds in dm.near_oods} == {"neards"} + assert {ds.dataset_name for ds in dm.far_oods} == {"fards"} + + def test_near_far_type_errors(self, monkeypatch, tmp_path): + # Avoid split file access so we can reach the TypeError branches + mod_name = ImageNetDataModule.__module__ + monkeypatch.setattr( + f"{mod_name}.download_and_extract_hf_dataset", + self._fake_download_and_extract, + raising=True, + ) + monkeypatch.setattr( + f"{mod_name}.download_and_extract_splits_from_hf", + self._fake_download_and_extract_splits_from_hf, + raising=True, + ) + monkeypatch.setattr( + f"{mod_name}.FileListDataset", + self._DummyFileListDataset, + raising=True, + ) + monkeypatch.setattr( + "torch_uncertainty.datamodules.classification.imagenet.get_ood_datasets", + self._fake_get_ood_datasets, + ) + + dm_bad_near = ImageNetDataModule( + root=tmp_path, + batch_size=8, + eval_ood=True, + near_ood_datasets=[123], + test_transform=nn.Identity(), + ) + with pytest.raises(TypeError, match="near_ood_datasets.*Dataset"): + dm_bad_near.setup("test") + + dm_bad_far = ImageNetDataModule( + root=tmp_path, + batch_size=8, + eval_ood=True, + far_ood_datasets=["nope"], + test_transform=nn.Identity(), + ) + with pytest.raises(TypeError, match="far_ood_datasets.*Dataset"): + dm_bad_far.setup("test") + + def test_train_dataloader_success_and_missing(self, monkeypatch, tmp_path): + mod_name = ImageNetDataModule.__module__ + monkeypatch.setattr( + f"{mod_name}.download_and_extract_hf_dataset", + self._fake_download_and_extract, + raising=True, + ) + monkeypatch.setattr( + f"{mod_name}.download_and_extract_splits_from_hf", + self._fake_download_and_extract_splits_from_hf, + raising=True, + ) + monkeypatch.setattr(f"{mod_name}.ImageFolder", self._DummyImageFolder, raising=True) + + dm = ImageNetDataModule( + root=tmp_path, + batch_size=4, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + ) + dm.data_dir = str(tmp_path) + (tmp_path / "train").mkdir(parents=True, exist_ok=True) + loader = dm.train_dataloader() + batch = next(iter(loader)) + assert isinstance(batch, list | tuple) + assert len(batch) == 2 + + dm2 = ImageNetDataModule( + root=tmp_path / "other", + batch_size=4, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + ) + dm2.data_dir = str(tmp_path / "no_train_here") + with pytest.raises(RuntimeError, match="ImageNet training data not found"): + dm2.train_dataloader() + + def test_get_indices_without_ood_or_shift(self, tmp_path): + dm = ImageNetDataModule( + root=tmp_path, + batch_size=8, + eval_ood=False, + eval_shift=False, + ) + idx = dm.get_indices() + assert idx["test"] == [0] + assert idx["test_ood"] == [] + assert idx["val_ood"] == [] + assert idx["near_oods"] == [] + assert idx["far_oods"] == [] + assert idx["shift"] == [] + + def test_tta_wraps_ood_sets(self, monkeypatch, tmp_path): + mod_name = ImageNetDataModule.__module__ + + monkeypatch.setattr( + f"{mod_name}.download_and_extract_hf_dataset", + self._fake_download_and_extract, + raising=True, + ) + monkeypatch.setattr( + f"{mod_name}.download_and_extract_splits_from_hf", + self._fake_download_and_extract_splits_from_hf, + raising=True, + ) + monkeypatch.setattr( + f"{mod_name}.FileListDataset", + self._DummyFileListDataset, + raising=True, + ) + monkeypatch.setattr( + "torch_uncertainty.datamodules.classification.imagenet.get_ood_datasets", + self._fake_get_ood_datasets, + ) + + dm = ImageNetDataModule( + root=tmp_path, + batch_size=8, + eval_ood=True, + num_tta=2, + train_transform=nn.Identity(), + test_transform=nn.Identity(), + num_workers=0, + persistent_workers=False, + pin_memory=False, + ) + + dm.setup("test") + + val_ood_wrapped = dm.get_val_ood_set() + test_ood_wrapped = dm.get_test_ood_set() + near_wrapped = dm.get_near_ood_set() + far_wrapped = dm.get_far_ood_set() + + assert len(val_ood_wrapped) == len(dm.val_ood) * dm.num_tta + assert len(test_ood_wrapped) == len(dm.test_ood) * dm.num_tta + assert all( + len(w) == len(b) * dm.num_tta for w, b in zip(near_wrapped, dm.near_oods, strict=False) + ) + assert all( + len(w) == len(b) * dm.num_tta for w, b in zip(far_wrapped, dm.far_oods, strict=False) + ) + + def _assert_first_block_repeat(wrapped_ds, num_tta: int): + if len(wrapped_ds) == 0 or num_tta < 2: + return + x0, y0 = wrapped_ds[0] + x1, y1 = wrapped_ds[1] + assert y0 == y1 + if torch.is_tensor(x0) and torch.is_tensor(x1): + assert x0.shape == x1.shape + else: + assert type(x0) is type(x1) + if hasattr(x0, "size") and hasattr(x1, "size"): + assert x0.size == x1.size + + _assert_first_block_repeat(val_ood_wrapped, dm.num_tta) + _assert_first_block_repeat(test_ood_wrapped, dm.num_tta) + for w in near_wrapped: + _assert_first_block_repeat(w, dm.num_tta) + for w in far_wrapped: + _assert_first_block_repeat(w, dm.num_tta) diff --git a/tests/datamodules/classification/test_imagenet200.py b/tests/datamodules/classification/test_imagenet200.py new file mode 100644 index 00000000..96d61c62 --- /dev/null +++ b/tests/datamodules/classification/test_imagenet200.py @@ -0,0 +1,259 @@ +from pathlib import Path + +import pytest +import torch +from torch.utils.data import Dataset + +from torch_uncertainty.datamodules import ImageNet200DataModule + + +class TinyImgDataset(Dataset): + """A tiny image dataset returning zeros.""" + + def __init__(self, n=3): + """Initialize dataset with n samples.""" + self.n = n + + def __len__(self): + """Return the number of samples.""" + return self.n + + def __getitem__(self, idx): + """Return a (C,H,W) tensor and a label.""" + return torch.zeros(3, 224, 224), 0 + + +class DummyFileListDataset(Dataset): + """Stand-in for FileListDataset.""" + + def __init__(self, root, list_file, transform): + """Initialize with a fixed TinyImgDataset backend.""" + self.ds = TinyImgDataset() + + def __len__(self): + """Return the number of samples.""" + return len(self.ds) + + def __getitem__(self, i): + """Return a sample from the inner dataset.""" + return self.ds[i] + + +class DummyImageFolder(Dataset): + """Stand-in for torchvision.datasets.ImageFolder.""" + + def __init__(self, root, transform=None): + """Initialize with a fixed TinyImgDataset backend and optional transform.""" + self.ds = TinyImgDataset() + self.transform = transform + + def __len__(self): + """Return the number of samples.""" + return len(self.ds) + + def __getitem__(self, i): + """Return (image, label), applying transform if provided.""" + x, y = self.ds[i] + return (self.transform(x) if self.transform is not None else x), y + + +def _fake_download_and_extract(name, dest_root): # noqa: ARG001 + """Return the destination root unchanged (avoid I/O).""" + return str(dest_root) + + +def _fake_download_and_extract_splits_from_hf(root): + """Return a Path to the provided root (avoid I/O).""" + return Path(root) + + +def _fake_get_ood_datasets(**_): + """Return tiny OOD datasets and minimal defaults (no I/O).""" + test_ood = TinyImgDataset(2) + val_ood = TinyImgDataset(2) + near_default = { + "near1": TinyImgDataset(1), + "near2": TinyImgDataset(1), + } + far_default = { + "far1": TinyImgDataset(1), + "far2": TinyImgDataset(1), + "far3": TinyImgDataset(1), + } + return test_ood, val_ood, near_default, far_default + + +def test_get_indices_no_ood_and_test_dataloader(monkeypatch, tmp_path): + """Covers get_indices() mapping when eval_ood=False and basic test loader count.""" + mod = ImageNet200DataModule.__module__ + + monkeypatch.setattr( + f"{mod}.download_and_extract_hf_dataset", _fake_download_and_extract, raising=True + ) + monkeypatch.setattr( + f"{mod}.download_and_extract_splits_from_hf", + _fake_download_and_extract_splits_from_hf, + raising=True, + ) + monkeypatch.setattr(f"{mod}.FileListDataset", DummyFileListDataset, raising=True) + + dm = ImageNet200DataModule( + root=tmp_path, + batch_size=32, + eval_ood=False, + eval_shift=False, + basic_augment=False, + num_workers=0, + persistent_workers=False, + pin_memory=False, + ) + + dm.prepare_data() + dm.setup("fit") + dm.setup("test") + + loaders = dm.test_dataloader() + assert isinstance(loaders, list) + assert len(loaders) == 1 + + idx = dm.get_indices() + assert "test" in idx + assert idx["test"] == [0] + assert idx["test_ood"] == [] + assert idx["val_ood"] == [] + assert idx["near_oods"] == [] + assert idx["far_oods"] == [] + assert idx["shift"] == [] + + +def test_train_dataloader_success_and_failure(monkeypatch, tmp_path): + """Cover train_dataloader() success with a fake train/ dir and failure when missing.""" + mod = ImageNet200DataModule.__module__ + + monkeypatch.setattr( + f"{mod}.download_and_extract_hf_dataset", _fake_download_and_extract, raising=True + ) + monkeypatch.setattr( + f"{mod}.download_and_extract_splits_from_hf", + _fake_download_and_extract_splits_from_hf, + raising=True, + ) + monkeypatch.setattr(f"{mod}.FileListDataset", DummyFileListDataset, raising=True) + monkeypatch.setattr(f"{mod}.ImageFolder", DummyImageFolder, raising=True) + + dm = ImageNet200DataModule( + root=tmp_path, + batch_size=8, + basic_augment=False, + num_workers=0, + persistent_workers=False, + pin_memory=False, + ) + + dm.prepare_data() + dm.setup("fit") + dm.setup("test") + + data_dir = tmp_path / "imagenet_fake" + train_dir = data_dir / "train" + train_dir.mkdir(parents=True, exist_ok=True) + + dm.data_dir = str(data_dir) + + loader = dm.train_dataloader() + batch = next(iter(loader)) + + assert isinstance(batch, list | tuple) + assert len(batch) == 2 + + x, y = batch[0], batch[1] + + assert torch.is_tensor(x) + assert x.ndim == 4 + assert torch.is_tensor(y) + assert y.ndim in (0, 1) + + dm.data_dir = str(tmp_path / "no_train_here") + with pytest.raises(RuntimeError, match="ImageNet training data not found"): + dm.train_dataloader() + + +def test_user_supplied_near_far_ood_instances_and_typecheck(monkeypatch, tmp_path): + """Exercise user-provided OOD lists: type checks and successful overrides.""" + mod = ImageNet200DataModule.__module__ + + monkeypatch.setattr( + f"{mod}.download_and_extract_hf_dataset", _fake_download_and_extract, raising=True + ) + monkeypatch.setattr( + f"{mod}.download_and_extract_splits_from_hf", + _fake_download_and_extract_splits_from_hf, + raising=True, + ) + monkeypatch.setattr(f"{mod}.FileListDataset", DummyFileListDataset, raising=True) + monkeypatch.setattr(f"{mod}.get_ood_datasets", _fake_get_ood_datasets, raising=True) + + dm_bad = ImageNet200DataModule( + root=tmp_path, + batch_size=16, + eval_ood=True, + basic_augment=False, + near_ood_datasets=[123, "not a dataset"], # invalid + num_workers=0, + persistent_workers=False, + pin_memory=False, + ) + with pytest.raises(TypeError, match="near_ood_datasets must be Dataset objects"): + dm_bad.setup("test") + + dm_bad2 = ImageNet200DataModule( + root=tmp_path, + batch_size=16, + eval_ood=True, + basic_augment=False, + far_ood_datasets=[object()], + num_workers=0, + persistent_workers=False, + pin_memory=False, + ) + with pytest.raises(TypeError, match="far_ood_datasets must be Dataset objects"): + dm_bad2.setup("test") + + near_custom = [TinyImgDataset(2), TinyImgDataset(3)] + far_custom = [TinyImgDataset(1)] + + dm = ImageNet200DataModule( + root=tmp_path, + batch_size=16, + eval_ood=True, + basic_augment=False, + near_ood_datasets=near_custom, + far_ood_datasets=far_custom, + num_workers=0, + persistent_workers=False, + pin_memory=False, + ) + + dm.setup("test") + + assert hasattr(dm, "near_oods") + assert dm.near_oods is near_custom + assert hasattr(dm, "far_oods") + assert dm.far_oods is far_custom + + loaders = dm.test_dataloader() + expected = ( + 1 + 1 + 1 + len(near_custom) + len(far_custom) + ) # ID + test_ood + val_ood + near + far + assert isinstance(loaders, list) + assert len(loaders) == expected + + idx = dm.get_indices() + assert idx["test"] == [0] + assert idx["test_ood"] == [1] + assert idx["val_ood"] == [2] + start_near = 3 + assert idx["near_oods"] == list(range(start_near, start_near + len(near_custom))) + start_far = start_near + len(near_custom) + assert idx["far_oods"] == list(range(start_far, start_far + len(far_custom))) + assert idx["shift"] == [] diff --git a/tests/datamodules/classification/test_sst2.py b/tests/datamodules/classification/test_sst2.py new file mode 100644 index 00000000..c524f637 --- /dev/null +++ b/tests/datamodules/classification/test_sst2.py @@ -0,0 +1,253 @@ +import pytest +import torch +from torch.utils.data import Dataset + +from torch_uncertainty.datamodules.classification.sst2 import Sst2DataModule + + +class DummySplit: + def __init__(self, rows): + """HF-like split that supports map()/set_format() and indexing.""" + self._rows = list(rows) + self._keep_columns = None + + @property + def column_names(self): + """Return all column names present in the split.""" + names = set() + for r in self._rows: + names |= set(r.keys()) + return list(names) + + def map(self, fn, batched=True, remove_columns=()): + """Apply a batched mapping function and drop columns.""" + if not batched: + raise NotImplementedError("Only batched=True supported in stub") + + batch = {name: [r.get(name) for r in self._rows] for name in self.column_names} + + new_cols = fn(batch) # dict of lists + + keep = {k: batch[k] for k in batch if k not in set(remove_columns)} + keep.update(new_cols) + + n = len(next(iter(keep.values()))) if keep else len(self._rows) + new_rows = [{k: keep[k][i] for k in keep} for i in range(n)] + return DummySplit(new_rows) + + def set_format(self, fmt="torch", columns=None, **kwargs): + """Mimic HF set_format: keep only requested columns on __getitem__. + + Accepts both fmt=... and type=... (HF uses type=). + """ + if "type" in kwargs and fmt == "torch": + fmt = kwargs["type"] + self._keep_columns = list(columns) if columns is not None else None + return self + + def __len__(self): + """Number of rows in the split.""" + return len(self._rows) + + def __getitem__(self, idx): + """Get a row (optionally filtered to requested columns).""" + r = self._rows[idx] + if self._keep_columns is not None: + r = {k: r[k] for k in self._keep_columns if k in r} + return r + + +class DummyDatasetDict(dict): + """HF-like container with split keys.""" + + def map(self, fn, batched=True, remove_columns=()): + """Apply map to each contained split.""" + return DummyDatasetDict( + {k: v.map(fn, batched=batched, remove_columns=remove_columns) for k, v in self.items()} + ) + + def set_format(self, fmt="torch", columns=None, **kwargs): + """Apply set_format to each contained split. + + Accepts both fmt=... and type=... (HF uses type=). + """ + if "type" in kwargs and fmt == "torch": + fmt = kwargs["type"] + for v in self.values(): + v.set_format(fmt=fmt, columns=columns) + return self + + +class DummyTokenizer: + """Minimal tokenizer stub returning fixed-length ids and masks.""" + + def __init__(self, max_id=1000): + self.max_id = max_id + + def __call__(self, *args, max_length=128, truncation=True, padding="max_length"): + if len(args) == 1: + texts = args[0] + elif len(args) == 2: + t1, t2 = args + texts = [f"{a} {b}" for a, b in zip(t1, t2, strict=False)] + else: + raise ValueError("Unexpected tokenizer inputs") + + n = len(texts) + out_ids, out_mask = [], [] + for i in range(n): + seq_len = min(16, max_length) + ids = [(i + j) % self.max_id for j in range(seq_len)] + mask = [1] * seq_len + if seq_len < max_length: + pad = max_length - seq_len + ids += [0] * pad + mask += [0] * pad + out_ids.append(ids) + out_mask.append(mask) + return {"input_ids": out_ids, "attention_mask": out_mask} + + @classmethod + def from_pretrained(cls, *_, **__): + return cls() + + +@pytest.fixture +def patch_hf(monkeypatch): + """Patch tokenizer and load_dataset in the module where Sst2DataModule lives.""" + monkeypatch.setattr( + f"{Sst2DataModule.__module__}.AutoTokenizer", + DummyTokenizer, + raising=True, + ) + + def _fake_load_dataset(path, name=None, split=None, download_config=None): # noqa: ARG001 + if path == "glue" and name == "sst2": + train = DummySplit( + [ + {"sentence": "great movie", "label": 1, "idx": 0}, + {"sentence": "bad film", "label": 0, "idx": 1}, + {"sentence": "okay", "label": 1, "idx": 2}, + ] + ) + val = DummySplit( + [ + {"sentence": "not good", "label": 0, "idx": 3}, + {"sentence": "wonderful", "label": 1, "idx": 4}, + ] + ) + if split is None: + return DummyDatasetDict(train=train, validation=val) + return {"train": train, "validation": val}[split] + + # Near OOD + if path == "yelp_polarity": + return DummySplit([{"text": "food was amazing", "label": 1}]) + if path == "amazon_polarity": + return DummySplit([{"content": "terrible product", "label": 0}]) + + # Far OOD + if path == "ag_news": + return DummySplit([{"text": "stocks fell today", "label": 2}]) + if path == "SetFit/20_newsgroups": + return DummySplit([{"text": "comp.graphics topic", "label": 5}]) + if path == "SetFit/TREC-QC": + return DummySplit([{"text": "what is AI?", "label": 0}]) + if path == "glue" and name == "mnli": + return DummySplit([{"premise": "cats sleep", "hypothesis": "animals rest", "label": 1}]) + if path == "glue" and name == "rte": + return DummySplit( + [{"sentence1": "A man runs", "sentence2": "A person jogs", "label": 1}] + ) + if path == "wmt16" and name == "ro-en": + return DummySplit([{"translation": {"ro": "salut", "en": "hello"}, "label": 0}]) + + raise KeyError(f"Unhandled dataset: {(path, name, split)}") + + monkeypatch.setattr( + f"{Sst2DataModule.__module__}.load_dataset", + _fake_load_dataset, + raising=True, + ) + + +class TestSst2DataModule: + def test_id_only(self, patch_hf): + dm = Sst2DataModule( + model_name="bert-base-uncased", + max_len=32, + batch_size=4, + local_files_only=True, + eval_ood=False, + num_workers=0, + persistent_workers=False, + ) + + dm.prepare_data() + dm.setup("fit") + dm.setup("test") + + train_loader = dm.train_dataloader() + val_loader = dm.val_dataloader() + test_loaders = dm.test_dataloader() + + assert isinstance(dm.test, Dataset) + assert len(test_loaders) == 1 # only ID + + xb, yb = next(iter(train_loader)) + assert {"input_ids", "attention_mask"} <= set(xb.keys()) + assert yb.dtype == torch.long + + idx = dm.get_indices() + assert idx["test"] == [0] + assert idx["test_ood"] == [] + assert idx["val_ood"] == [] + assert idx["near_oods"] == [] + assert idx["far_oods"] == [] + assert idx["shift"] == [] + + _ = next(iter(val_loader)) + _ = next(iter(test_loaders[0])) + + def test_with_ood(self, patch_hf): + dm = Sst2DataModule( + model_name="bert-base-uncased", + max_len=32, + batch_size=4, + local_files_only=True, + eval_ood=True, + num_workers=0, + persistent_workers=False, + ) + + dm.prepare_data() + dm.setup("fit") + dm.setup("test") + + assert isinstance(dm.test, Dataset) + assert len(dm.near_oods) == 2 # yelp + amazon + assert len(dm.far_oods) == 6 # ag_news, 20newsg, trec_qc, mnli_mm, rte, wmt16_en + + for ds in [*dm.near_oods, *dm.far_oods]: + assert isinstance(ds, Dataset) + assert hasattr(ds, "dataset_name") + assert isinstance(ds.dataset_name, str) + assert ds.dataset_name + + loaders = dm.test_dataloader() + expected = 1 + 1 + len(dm.near_oods) + len(dm.far_oods) + assert len(loaders) == expected + + idx = dm.get_indices() + assert idx["test"] == [0] + assert idx["test_ood"] == [1] + assert idx["val_ood"] == [] + assert idx["near_oods"] == list(range(2, 2 + len(dm.near_oods))) + assert idx["far_oods"] == list( + range(2 + len(dm.near_oods), 2 + len(dm.near_oods) + len(dm.far_oods)) + ) + assert idx["shift"] == [] + + (xb, yb) = next(iter(loaders[2])) + assert {"input_ids", "attention_mask"} <= set(xb.keys()) + assert yb.dtype == torch.long diff --git a/tests/metrics/classification/test_calibration.py b/tests/metrics/classification/test_calibration.py index 99af5428..512f6b66 100644 --- a/tests/metrics/classification/test_calibration.py +++ b/tests/metrics/classification/test_calibration.py @@ -50,7 +50,7 @@ def test_errors(self) -> None: with pytest.raises(TypeError, match="is expected to be `int`"): CalibrationError(task="multiclass", num_classes=None) with pytest.raises( - ValueError, match="`n_bins` does not exist in TorchUncertainty, use `num_bins`." + ValueError, match=r"`n_bins` does not exist(?: in TorchUncertainty)?, use `num_bins`\." ): CalibrationError(task="multiclass", num_classes=2, n_bins=1) diff --git a/tests/ood/__init__.py b/tests/ood/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/ood/test_nets.py b/tests/ood/test_nets.py new file mode 100644 index 00000000..5050feeb --- /dev/null +++ b/tests/ood/test_nets.py @@ -0,0 +1,99 @@ +import importlib + +import numpy as np +import torch +from torch import nn + +from torch_uncertainty.ood.nets import ASHNet, ReactNet, ScaleNet + + +class _TinyBackbone(nn.Module): + """Minimal CNN backbone for testing.""" + + def __init__(self, in_ch=3, num_classes=3): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_ch, 8, kernel_size=3, padding=1), + nn.ReLU(), + nn.AdaptiveAvgPool2d(1), # -> (B, 8, 1, 1) + ) + self.feature_size = 8 + self.fc = nn.Linear(self.feature_size, num_classes) + + def forward(self, x, return_feature: bool = False, return_feature_list: bool = False): + feat = self.conv(x).flatten(1) # (B, 8) + logits = self.fc(feat) # (B, C) + if return_feature: + return logits, feat + return logits + + def get_fc_layer(self): + return self.fc + + +def _rand_batch(b=4, c=3, h=4, w=4, seed=0, eps=1e-3): + g = torch.Generator().manual_seed(seed) + x = torch.rand((b, c, h, w), generator=g) + eps + return x.float() + + +def test_ash_helpers_and_nets_cpu_exhaustive(): + device = torch.device("cpu") + x = _rand_batch().to(device) + + ash_module = importlib.import_module(ASHNet.__module__) + ash_b = getattr(ash_module, "ash_b", None) + ash_p = getattr(ash_module, "ash_p", None) + ash_s = getattr(ash_module, "ash_s", None) + ash_rand = getattr(ash_module, "ash_rand", None) + + assert callable(ash_b), "ash_b not found next to ASHNet" + assert callable(ash_p), "ash_p not found next to ASHNet" + assert callable(ash_s), "ash_s not found next to ASHNet" + assert callable(ash_rand), "ash_rand not found next to ASHNet" + + for pct in (0, 50): + t = x.clone() + outs = ( + ash_b(t.clone(), percentile=pct), + ash_p(t.clone(), percentile=pct), + ash_s(t.clone(), percentile=pct), + ash_rand(t.clone(), percentile=pct, r1=0.0, r2=1.0), + ) + for out in outs: + assert out.shape == t.shape + assert torch.isfinite(out).all(), f"Non-finite values at percentile {pct}" + + pct = 100 + t = x.clone() + outs_100 = ( + ash_b(t.clone(), percentile=pct), + ash_p(t.clone(), percentile=pct), + ash_s(t.clone(), percentile=pct), + ash_rand(t.clone(), percentile=pct, r1=0.0, r2=1.0), + ) + for out in outs_100: + assert out.shape == t.shape + + backbone = _TinyBackbone(in_ch=3, num_classes=3).to(device) + + ash_net = ASHNet(backbone) + y_ash = ash_net.forward_threshold(x, percentile=70) + assert y_ash.shape == (x.size(0), 3) + assert torch.isfinite(y_ash).all() + + react_net = ReactNet(backbone) + y_react = react_net.forward_threshold(x, threshold=1.0) + assert y_react.shape == (x.size(0), 3) + assert torch.isfinite(y_react).all() + + scale_net = ScaleNet(backbone) + y_scale = scale_net.forward_threshold(x, percentile=65) + assert y_scale.shape == (x.size(0), 3) + assert torch.isfinite(y_scale).all() + + w, b = scale_net.get_fc() + assert isinstance(w, np.ndarray) + assert isinstance(b, np.ndarray) + assert w.shape == (3, backbone.feature_size) + assert b.shape == (3,) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 74d9123d..e6a0bb8d 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -1,22 +1,28 @@ +import logging +import types from pathlib import Path import pytest +import torch from torch import nn +from torch.utils.data import Dataset from tests._dummies import ( DummyClassificationBaseline, DummyClassificationDataModule, dummy_model, + dummy_ood_model, ) from torch_uncertainty import TUTrainer from torch_uncertainty.losses import DECLoss, ELBOLoss -from torch_uncertainty.ood_criteria import ( +from torch_uncertainty.ood.ood_criteria import ( EntropyCriterion, - PostProcessingCriterion, + MaxSoftmaxCriterion, ) from torch_uncertainty.post_processing import ConformalClsTHR from torch_uncertainty.routines import ClassificationRoutine from torch_uncertainty.transforms import RepeatTarget +from torch_uncertainty.utils.evaluation_loop import TUEvaluationLoop class TestClassification: @@ -355,7 +361,6 @@ def test_one_estimator_conformal(self) -> None: loss=None, num_classes=3, post_processing=ConformalClsTHR(alpha=0.1), - ood_criterion=PostProcessingCriterion(), eval_ood=True, ) trainer.test(routine, dm) @@ -377,108 +382,476 @@ def test_one_estimator_conformal(self) -> None: ) trainer.test(routine, dm) - def test_classification_failures(self) -> None: - # num_classes - with pytest.raises(ValueError): - ClassificationRoutine(num_classes=0, model=nn.Module(), loss=None) - # single & MI - with pytest.raises(ValueError): - ClassificationRoutine( - num_classes=10, - model=nn.Module(), - loss=None, - is_ensemble=False, - ood_criterion="mutual_information", + OOD_CRITS = [ + "scale", + "ash", + "react", + "adascale_a", + "vim", + "odin", + "knn", + "gen", + "nnguide", + ] + + @pytest.mark.parametrize("crit", OOD_CRITS) + def test_all_other_ood_criteria_with_dummy_ood_model(self, crit, monkeypatch): + device = torch.device("cpu") + + monkeypatch.setattr( + torch.Tensor, + "cuda", + lambda self, *a, **k: self.to(device), # noqa: ARG005 + raising=False, + ) + monkeypatch.setattr(nn.Module, "cuda", lambda self, *a, **k: self.to(device), raising=False) # noqa: ARG005 + + trainer = TUTrainer( + accelerator="cpu", + devices=1, + inference_mode=False, + num_sanity_val_steps=0, + max_epochs=1, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + enable_checkpointing=False, + logger=False, + ) + + class _TensorDS(Dataset): + def __init__(self, x, y, name="tensor_ds"): + self.x, self.y, self.dataset_name = x, y, name + + def __len__(self): + return self.x.shape[0] + + def __getitem__(self, i): + return self.x[i], self.y[i] + + def _mk_split(n, c, h, w, num_classes, shift=0.0, seed=0): + g = torch.Generator().manual_seed(seed) + x = (torch.rand((n, c, h, w), generator=g) + shift).clamp(0, 1) + y = (torch.arange(n) % num_classes).long() + return x.float(), y + + dm = DummyClassificationDataModule( + root=Path(), + batch_size=8, + num_classes=3, + num_images=64, + num_workers=0, + eval_ood=True, + persistent_workers=False, + ) + + def patched_setup(self, stage=None): + self.num_channels = 3 + h = w = self.image_size + n = self.num_images + + if stage in (None, "fit"): + x_tr, y_tr = _mk_split(n, 3, h, w, self.num_classes, 0.0, 123) + x_va, y_va = _mk_split(n, 3, h, w, self.num_classes, 0.0, 456) + self.train = _TensorDS(x_tr, y_tr, "train") + self.val = _TensorDS(x_va, y_va, "val") + + if stage in (None, "test"): + x_te, y_te = _mk_split(n, 3, h, w, self.num_classes, 0.0, 789) + self.test = _TensorDS(x_te, y_te, "test") + if self.eval_ood: + x_vo, y_vo = _mk_split(n, 3, h, w, self.num_classes, 0.10, 321) + self.val_ood = _TensorDS(x_vo, y_vo, "val_ood") + x_near, y_near = _mk_split(n, 3, h, w, self.num_classes, 0.15, 654) + self.near_oods = [_TensorDS(x_near, y_near, "near")] + x_far, y_far = _mk_split(n, 3, h, w, self.num_classes, 0.35, 987) + self.far_oods = [_TensorDS(x_far, y_far, "far")] + + if self.eval_shift: + x_sh, y_sh = _mk_split(n, 3, h, w, self.num_classes, 0.20, 111) + self.shift = _TensorDS(x_sh, y_sh, "shift") + self.shift_severity = 1 + + monkeypatch.setattr(dm, "setup", patched_setup.__get__(dm, type(dm)), raising=True) + + model = dummy_ood_model(in_channels=3, feat_dim=4096, num_classes=dm.num_classes).to(device) + + routine = ClassificationRoutine( + model=model, + loss=None, + num_classes=dm.num_classes, + eval_ood=True, + ood_criterion=crit, + log_plots=False, + ) + + c = routine.ood_criterion + if hasattr(c, "args_dict"): + if crit in ("scale", "ash", "react"): + c.args_dict = {"percentile": [70]} + elif crit == "adascale_a": + c.args_dict = { + "percentile": [(40, 60)], + "k1": [1], + "k2": [1], + "lmbda": [0.1], + "o": [0.05], + } + elif crit == "vim": + safe_dim = min(64, getattr(model, "feature_size", 256) - 1) + c.args_dict = {"dim": [safe_dim]} + c.dim = safe_dim + elif crit == "odin": + c.args_dict = {"temperature": [1.0], "noise": [0.0014]} + elif crit == "knn": + c.args_dict = {"K": [5]} + elif crit == "gen": + # ensure m ≤ num_classes to avoid degenerate slices + c.gamma = getattr(c, "gamma", 0.1) + c.m = min(getattr(c, "m", 10), dm.num_classes) + c.args_dict = {"gamma": [c.gamma], "m": [c.m]} + elif crit == "nnguide": + c.args_dict = {"K": [5], "alpha": [0.5]} + c.hyperparam_search_done = False + + trainer.test(routine, dm) + + if hasattr(c, "args_dict"): + assert getattr(c, "hyperparam_search_done", False), ( + f"Hyperparam search did not complete for '{crit}'." ) + for needs_setup in {"react", "adascale_a", "vim", "knn", "nnguide"}: + if crit == needs_setup: + assert getattr(c, "setup_flag", False), f"Setup not executed for '{crit}'." - with pytest.raises(ValueError): - ClassificationRoutine( - num_classes=10, - model=nn.Module(), - loss=None, - is_ensemble=False, - ood_criterion=32, + def test_setup_logs_when_no_train_loader(self, caplog, monkeypatch): + dm = DummyClassificationDataModule( + root=Path(), + batch_size=4, + num_classes=3, + num_images=16, + eval_ood=True, + ) + + def _raise_train_loader(*_a, **_k): + raise RuntimeError("no train loader") + + monkeypatch.setattr( + ClassificationRoutine, "_hyperparam_search_ood", lambda _self: None, raising=True + ) + monkeypatch.setattr(dm, "train_dataloader", _raise_train_loader, raising=True) + + model = dummy_ood_model(in_channels=3, feat_dim=64, num_classes=3) + routine = ClassificationRoutine( + model=model, + loss=None, + num_classes=3, + eval_ood=True, + ) + routine.ood_criterion = MaxSoftmaxCriterion() # no setup() side-effects + + routine.trainer = types.SimpleNamespace(datamodule=dm) + + with caplog.at_level(logging.INFO): + routine.setup("test") + assert any("No train loader detected" in r.message for r in caplog.records) + + def test_create_near_far_metric_dicts_non_ensemble(self, capsys): + model = dummy_ood_model(in_channels=3, feat_dim=64, num_classes=3) + routine = ClassificationRoutine( + model=model, loss=None, num_classes=3, eval_ood=True, is_ensemble=False + ) + routine.ood_criterion = MaxSoftmaxCriterion() + + x = torch.rand(4, 3, 8, 8) + y = torch.tensor([0, 1, 2, 0]) + + class _DS: + def __init__(self, name): + self.dataset_name = name + + routine.trainer = types.SimpleNamespace( + datamodule=types.SimpleNamespace( + get_indices=lambda: {"val_ood": 9, "near_oods": [2], "far_oods": [3], "shift": []}, + near_oods=[_DS("nearX")], + far_oods=[_DS("farY")], ) + ) - with pytest.raises(ValueError): + routine.test_step((x, y), batch_idx=0, dataloader_idx=2) # near + assert "nearX" in routine.test_ood_metrics_near + + routine.test_step((x, y), batch_idx=0, dataloader_idx=3) # far + assert "farY" in routine.test_ood_metrics_far + + fake_results = [ + { + "ood_near_nearX_auroc": torch.tensor(0.91), + "ood_near_nearX_fpr95": torch.tensor(0.09), + "ood_near_nearX_aupr": torch.tensor(0.88), + "ood_near_dsB_auroc": torch.tensor(0.92), + "ood_near_dsB_fpr95": torch.tensor(0.08), + "ood_near_dsB_aupr": torch.tensor(0.86), + "ood_far_farY_auroc": torch.tensor(0.81), + "ood_far_farY_fpr95": torch.tensor(0.19), + "ood_far_farY_aupr": torch.tensor(0.72), + "ood_far_dsD_auroc": torch.tensor(0.79), + "ood_far_dsD_fpr95": torch.tensor(0.21), + "ood_far_dsD_aupr": torch.tensor(0.70), + } + ] + TUEvaluationLoop._print_results(fake_results, stage="test") + out = capsys.readouterr().out + assert "OOD Results" in out + assert "NearOOD Average" in out + assert "FarOOD Average" in out + + def test_create_near_far_metric_dicts_ensemble_and_aggregator(self): + model = dummy_ood_model(in_channels=3, feat_dim=64, num_classes=3) + routine = ClassificationRoutine( + model=model, loss=None, num_classes=3, eval_ood=True, is_ensemble=True + ) + routine.ood_criterion = MaxSoftmaxCriterion() + + x = torch.rand(4, 3, 8, 8) + y = torch.tensor([0, 1, 2, 0]) + + class _DS: + def __init__(self, name): + self.dataset_name = name + + routine.trainer = types.SimpleNamespace( + datamodule=types.SimpleNamespace( + get_indices=lambda: { + "val_ood": 9, + "near_oods": [5], + "far_oods": [6], + "shift": [7], + }, + near_oods=[_DS("n1")], + far_oods=[_DS("f1")], + ) + ) + + routine.test_step((x, y), batch_idx=0, dataloader_idx=1) # aggregator + assert "n1" in routine.test_ood_ens_metrics_near + assert "f1" in routine.test_ood_ens_metrics_far + + routine.test_step((x, y), batch_idx=0, dataloader_idx=5) # near + routine.test_step((x, y), batch_idx=0, dataloader_idx=6) # far + assert "n1" in routine.test_ood_ens_metrics_near + assert "f1" in routine.test_ood_ens_metrics_far + + def test_skip_when_val_ood_loader(self): + model = dummy_ood_model(in_channels=3, feat_dim=64, num_classes=3) + routine = ClassificationRoutine(model=model, loss=None, num_classes=3, eval_ood=True) + routine.ood_criterion = MaxSoftmaxCriterion() + + routine.trainer = types.SimpleNamespace( + datamodule=types.SimpleNamespace( + get_indices=lambda: {"val_ood": 4, "near_oods": [], "far_oods": [], "shift": []} + ) + ) + x = torch.rand(2, 3, 8, 8) + y = torch.tensor([0, 1]) + routine.test_step((x, y), batch_idx=0, dataloader_idx=4) + + def test_init_metrics_creates_shift_ens_metrics_when_ensemble_and_eval_shift(self): + model = dummy_ood_model(in_channels=3, feat_dim=64, num_classes=3) + routine = ClassificationRoutine( + model=model, loss=None, num_classes=3, eval_shift=True, is_ensemble=True + ) + assert hasattr(routine, "test_shift_ens_metrics") + + def test_shift_ens_update_path(self): + model = dummy_ood_model(in_channels=3, feat_dim=64, num_classes=3) + routine = ClassificationRoutine( + model=model, loss=None, num_classes=3, eval_shift=True, is_ensemble=True + ) + routine.ood_criterion = MaxSoftmaxCriterion() + + x = torch.rand(4, 3, 8, 8) + y = torch.tensor([0, 1, 2, 0]) + + routine.trainer = types.SimpleNamespace( + datamodule=types.SimpleNamespace( + get_indices=lambda: {"val_ood": 99, "near_oods": [], "far_oods": [], "shift": [7]} + ) + ) + routine.test_step((x, y), batch_idx=0, dataloader_idx=7) + + def test_logs_when_eval_flags_mismatch_datamodule(self, caplog): + model = dummy_ood_model(in_channels=3, feat_dim=64, num_classes=3) + routine = ClassificationRoutine( + model=model, loss=None, num_classes=3, eval_ood=False, eval_shift=False + ) + routine.ood_criterion = MaxSoftmaxCriterion() + + class _DM: + def get_indices(self): + return {"val_ood": 9, "near_oods": [2], "far_oods": [3], "shift": [4]} + + routine._trainer = types.SimpleNamespace(barebones=True, datamodule=_DM()) + + x = torch.rand(2, 3, 8, 8) + y = torch.tensor([0, 1]) + + with caplog.at_level(logging.INFO): + routine.test_step((x, y), batch_idx=0, dataloader_idx=0) + + assert any( + "`eval_ood` to `True` in the datamodule and not in the routine" in r.message + for r in caplog.records + ) + assert any( + "`eval_shift` to `True` in the datamodule and not in the routine" in r.message + for r in caplog.records + ) + + def test_guardrails__classification_routine_checks(self): + # ---- Local dummy models (minimal & fast) ---- + + class BareModel(nn.Module): + """No feats_forward, no classifier attrs.""" + + def __init__(self, num_classes=3): + super().__init__() + self.fc = nn.Linear(4, num_classes) + + def forward(self, x): + b = x.size(0) if hasattr(x, "size") else 2 + return self.fc(torch.zeros((b, 4))) + + class FeatsWithHead(nn.Module): + """Has feats_forward + classification_head.""" + + def __init__(self, num_classes=3): + super().__init__() + self.backbone = nn.Linear(4, 8) + self.classification_head = nn.Linear(8, num_classes) + + def feats_forward(self, x): + b = x.size(0) if hasattr(x, "size") else 2 + return self.backbone(torch.zeros((b, 4))) + + def forward(self, x): + return self.classification_head(self.feats_forward(x)) + + class FeatsWithLinear(nn.Module): + """Has feats_forward + linear.""" + + def __init__(self, num_classes=3): + super().__init__() + self.backbone = nn.Linear(4, 8) + self.linear = nn.Linear(8, num_classes) + + def feats_forward(self, x): + b = x.size(0) if hasattr(x, "size") else 2 + return self.backbone(torch.zeros((b, 4))) + + def forward(self, x): + return self.linear(self.feats_forward(x)) + + class FeatsNoHeadNoLinear(nn.Module): + """Has feats_forward but neither classification_head nor linear.""" + + def feats_forward(self, x): + b = x.size(0) if hasattr(x, "size") else 2 + return torch.zeros((b, 8)) + + def forward(self, x): + b = x.size(0) if hasattr(x, "size") else 2 + return torch.zeros((b, 3)) + + for crit in ("mutual_information", "variation_ratio"): + with pytest.raises(ValueError, match="mutual information|variation ratio"): + ClassificationRoutine( + model=BareModel(), + num_classes=3, + is_ensemble=False, + ood_criterion=crit, + ) + + with pytest.raises(NotImplementedError, match="Logit-based criteria"): ClassificationRoutine( - num_classes=10, - model=nn.Module(), - loss=None, - is_ensemble=False, - ood_criterion="other", + model=BareModel(), + num_classes=3, + is_ensemble=True, + ood_criterion="logit", ) - mixup_params = {"cutmix_alpha": -1} - with pytest.raises(ValueError): + with pytest.raises(NotImplementedError, match="Grouping loss for ensembles"): ClassificationRoutine( - num_classes=10, - model=nn.Module(), - loss=None, - mixup_params=mixup_params, + model=FeatsWithHead(), + num_classes=3, + is_ensemble=True, + eval_grouping_loss=True, ) - with pytest.raises(ValueError, match="num_bins_cal_err must be at least 2, got"): + with pytest.raises(ValueError, match="positive integer"): ClassificationRoutine( - model=nn.Identity(), - num_classes=2, - loss=nn.CrossEntropyLoss(), - num_bins_cal_err=0, + model=BareModel(), + num_classes=0, ) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="feats_forward"): ClassificationRoutine( - num_classes=10, - model=nn.Module(), - loss=None, + model=BareModel(), + num_classes=3, eval_grouping_loss=True, ) - with pytest.raises(NotImplementedError): + with pytest.raises(ValueError, match="classification_head|linear"): ClassificationRoutine( - num_classes=10, - model=nn.Module(), - loss=None, - is_ensemble=True, + model=FeatsNoHeadNoLinear(), + num_classes=3, eval_grouping_loss=True, ) - model = dummy_model(1, 1, 0, with_feats=False) - with pytest.raises(ValueError): - ClassificationRoutine(num_classes=10, model=model, loss=None, eval_grouping_loss=True) - - with pytest.raises( - ValueError, - match="Mixup is not supported for ensembles at training time", - ): + with pytest.raises(ValueError, match="at least 2"): ClassificationRoutine( - num_classes=10, - model=nn.Module(), - loss=None, - mixup_params={"mixtype": "mixup"}, - format_batch_fn=RepeatTarget(2), + model=BareModel(), + num_classes=3, + num_bins_cal_err=1, ) - with pytest.raises( - ValueError, - match="Ensembles and post-processing methods cannot be used together. Raise an issue if needed.", - ): + with pytest.raises(ValueError, match="Mixup is not supported for ensembles"): ClassificationRoutine( - num_classes=10, - model=nn.Module(), - loss=None, + model=BareModel(), + num_classes=3, is_ensemble=True, - post_processing=nn.Module(), + mixup_params={"mixup_alpha": 1.0}, + format_batch_fn=RepeatTarget(num_repeats=2), ) - with pytest.raises( - ValueError, - match="You cannot set ood_criterion=PostProcessingCriterion when post_processing is None.", - ): + with pytest.raises(ValueError, match="Ensembles and post-processing"): ClassificationRoutine( - num_classes=10, - model=nn.Module(), - loss=None, - post_processing=None, - ood_criterion=PostProcessingCriterion(), + model=BareModel(), + num_classes=3, + is_ensemble=True, + post_processing=ConformalClsTHR(alpha=0.1), ) + + ClassificationRoutine( + model=BareModel(), + num_classes=3, + is_ensemble=False, + ood_criterion="msp", + eval_grouping_loss=False, + num_bins_cal_err=15, + mixup_params=None, + post_processing=None, + format_batch_fn=None, + ) + + ClassificationRoutine( + model=FeatsWithHead(), + num_classes=3, + eval_grouping_loss=True, + ) + + ClassificationRoutine( + model=FeatsWithLinear(), + num_classes=3, + eval_grouping_loss=True, + ) diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py index 4471dff9..472ff5d3 100644 --- a/torch_uncertainty/baselines/classification/deep_ensembles.py +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -2,7 +2,7 @@ from typing import Literal from torch_uncertainty.models import deep_ensembles -from torch_uncertainty.ood_criteria import TUOODCriterion +from torch_uncertainty.ood.ood_criteria import TUOODCriterion from torch_uncertainty.routines.classification import ClassificationRoutine from torch_uncertainty.utils import get_version diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index cb3c254b..a1edecd5 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -12,7 +12,7 @@ packed_resnet, resnet, ) -from torch_uncertainty.ood_criteria import TUOODCriterion +from torch_uncertainty.ood.ood_criteria import TUOODCriterion from torch_uncertainty.routines.classification import ClassificationRoutine from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 40c40db5..a192069a 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -8,7 +8,7 @@ packed_vgg, vgg, ) -from torch_uncertainty.ood_criteria import TUOODCriterion +from torch_uncertainty.ood.ood_criteria import TUOODCriterion from torch_uncertainty.routines.classification import ClassificationRoutine from torch_uncertainty.transforms import RepeatTarget diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index c37ea80f..bb5c4585 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -11,7 +11,7 @@ packed_wideresnet28x10, wideresnet28x10, ) -from torch_uncertainty.ood_criteria import TUOODCriterion +from torch_uncertainty.ood.ood_criteria import TUOODCriterion from torch_uncertainty.routines.classification import ( ClassificationRoutine, ) diff --git a/torch_uncertainty/datamodules/__init__.py b/torch_uncertainty/datamodules/__init__.py index ca56913a..d0d1da34 100644 --- a/torch_uncertainty/datamodules/__init__.py +++ b/torch_uncertainty/datamodules/__init__.py @@ -6,10 +6,12 @@ CIFAR100DataModule, DOTA2GamesDataModule, HTRU2DataModule, + ImageNet200DataModule, ImageNetDataModule, MNISTDataModule, OnlineShoppersDataModule, SpamBaseDataModule, + Sst2DataModule, TinyImageNetDataModule, UCIClassificationDataModule, UCRUEADataModule, diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index 00c0e66d..5cb8045e 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -79,9 +79,10 @@ def __init__( self.pin_memory = pin_memory self.persistent_workers = persistent_workers - if batch_size % num_tta: + if batch_size % num_tta != 0: raise ValueError( - f"The number of Test-time augmentations num_tta should divide batch_size. Got {num_tta} and {batch_size}." + f"The number of Test-time augmentations num_tta should divide batch_size. " + f"Got num_tta={num_tta} and batch_size={batch_size}." ) self.num_tta = num_tta if postprocess_set == "test": @@ -109,11 +110,35 @@ def get_test_set(self) -> Dataset: return self.test def get_ood_set(self) -> Dataset: - """Get the shifted set.""" + """Get the ood set // legacy.""" if self.num_tta > 1: return TTADataset(self.ood, self.num_tta) return self.ood + def get_val_ood_set(self) -> Dataset: + """Get the shifted set.""" + if self.num_tta > 1: + return TTADataset(self.val_ood, self.num_tta) + return self.val_ood + + def get_test_ood_set(self) -> Dataset: + """Get the shifted set.""" + if self.num_tta > 1: + return TTADataset(self.test_ood, self.num_tta) + return self.test_ood + + def get_near_ood_set(self) -> Dataset: + """Get the near_ood sets.""" + if self.num_tta > 1: + return [TTADataset(ds, self.num_tta) for ds in self.near_oods] + return self.near_oods + + def get_far_ood_set(self) -> Dataset: + """Get the far_ood sets.""" + if self.num_tta > 1: + return [TTADataset(ds, self.num_tta) for ds in self.far_oods] + return self.far_oods + def get_shift_set(self) -> Dataset: """Get the shifted set.""" if self.num_tta > 1: @@ -153,7 +178,9 @@ def postprocess_dataloader(self) -> DataLoader: """ return self.val_dataloader() if self.postprocess_set == "val" else self.test_dataloader()[0] - def _data_loader(self, dataset: Dataset, training: bool, shuffle: bool = False) -> DataLoader: + def _data_loader( + self, dataset: Dataset, training: bool, shuffle: bool = False, drop_last=False + ) -> DataLoader: """Create a dataloader for a given dataset. Args: @@ -161,6 +188,8 @@ def _data_loader(self, dataset: Dataset, training: bool, shuffle: bool = False) training (bool): Whether it is a training or evaluation dataloader. shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + drop_last (bool, optional): Whether to drop the last incomplete batch + if the dataset size is not divisible by the batch size. Defaults to False. Return: DataLoader: Dataloader for the given dataset. @@ -172,6 +201,7 @@ def _data_loader(self, dataset: Dataset, training: bool, shuffle: bool = False) num_workers=self.num_workers, pin_memory=self.pin_memory, persistent_workers=self.persistent_workers, + drop_last=drop_last, ) # These two functions have to be defined in each datamodule diff --git a/torch_uncertainty/datamodules/classification/__init__.py b/torch_uncertainty/datamodules/classification/__init__.py index 7eb86bd8..226a2228 100644 --- a/torch_uncertainty/datamodules/classification/__init__.py +++ b/torch_uncertainty/datamodules/classification/__init__.py @@ -2,7 +2,9 @@ from .cifar10 import CIFAR10DataModule from .cifar100 import CIFAR100DataModule from .imagenet import ImageNetDataModule +from .imagenet200 import ImageNet200DataModule from .mnist import MNISTDataModule +from .sst2 import Sst2DataModule from .tiny_imagenet import TinyImageNetDataModule from .uci import ( BankMarketingDataModule, diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 7446721e..100a9a07 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -1,3 +1,4 @@ +import logging from pathlib import Path from typing import Literal @@ -6,16 +7,23 @@ from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn -from torch.utils.data import DataLoader -from torchvision.datasets import CIFAR10, SVHN +from torch.utils.data import DataLoader, Dataset +from torchvision.datasets import CIFAR10 from torchvision.transforms import v2 from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets import AggregatedDataset from torch_uncertainty.datasets.classification import CIFAR10C, CIFAR10H +from torch_uncertainty.datasets.ood.utils import get_ood_datasets from torch_uncertainty.datasets.utils import create_train_val_split from torch_uncertainty.transforms import Cutout +logging.basicConfig( + level=logging.INFO, + format="%(message)s", +) +logging.getLogger("faiss").setLevel(logging.WARNING) + class CIFAR10DataModule(TUDataModule): num_classes = 10 @@ -47,6 +55,8 @@ def __init__( num_dataloaders: int = 1, pin_memory: bool = True, persistent_workers: bool = True, + near_ood_datasets: list | None = None, + far_ood_datasets: list | None = None, ) -> None: """DataModule for CIFAR10. @@ -56,6 +66,8 @@ def __init__( eval_batch_size (int | None) : Number of samples per batch during evaluation (val and test). Set to batch_size if ``None``. Defaults to ``None``. eval_ood (bool): Whether to evaluate on out-of-distribution data. Defaults to ``False``. + near_ood_datasets (list, optional): list of near OOD dataset classes must be subclass of torch.utils.data.Dataset. Defaults to CIFAR-100, Tiny ImageNet (OpenOOD splits) + far_ood_datasets (list, optional): list of far OOD dataset classes must be subclass of torch.utils.data.Dataset. Defaults to MNIST, SVHN, Textures, Places365 (OpenOOD splits) eval_shift (bool): Whether to evaluate on shifted data. Defaults to ``False``. val_split (float): Share of samples to use for validation. Defaults to ``0.0``. @@ -108,9 +120,11 @@ def __init__( self.test_alt = test_alt self.shift_severity = shift_severity - self.ood_dataset = SVHN self.shift_dataset = CIFAR10C + self.near_ood_datasets = near_ood_datasets or [] # List of near OOD dataset classes + self.far_ood_datasets = far_ood_datasets or [] # List of far OOD dataset classes + if (cutout is not None) + randaugment + int(auto_augment is not None) > 1: raise ValueError( "Only one data augmentation can be chosen at a time. Raise a " @@ -156,13 +170,16 @@ def __init__( ) if num_tta != 1: - self.test_transform = train_transform + self.test_transform = self.train_transform elif test_transform is not None: self.test_transform = test_transform else: self.test_transform = v2.Compose( [ v2.ToImage(), + v2.Resize(32), + v2.CenterCrop(32), + v2.CenterCrop(32), v2.ToDtype(dtype=torch.float32, scale=True), v2.Normalize(mean=self.mean, std=self.std), ] @@ -178,8 +195,6 @@ def prepare_data(self) -> None: # coverage: ignore download=True, ) - if self.eval_ood: - self.ood_dataset(self.root, split="test", download=True) if self.eval_shift: self.shift_dataset( self.root, @@ -225,19 +240,42 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.test_transform, shift_severity=self.shift_severity, ) + if self.eval_ood: - self.ood = self.ood_dataset( - self.root, - split="test", - download=False, + self.test_ood, self.val_ood, near_default, far_default = get_ood_datasets( + root=self.root, + dataset_id="CIFAR10", transform=self.test_transform, ) + + if self.near_ood_datasets: + if not all(isinstance(ds, Dataset) for ds in self.near_ood_datasets): + raise TypeError("All entries in near_ood_datasets must be Dataset objects") + self.near_oods = self.near_ood_datasets + else: + self.near_oods = list(near_default.values()) + + if self.far_ood_datasets: + if not all(isinstance(ds, Dataset) for ds in self.far_ood_datasets): + raise TypeError("All entries in far_ood_datasets must be Dataset objects") + self.far_oods = self.far_ood_datasets + else: + self.far_oods = list(far_default.values()) + + for ds in [self.val_ood, *self.near_oods, *self.far_oods]: + if not hasattr(ds, "dataset_name"): + ds.dataset_name = ds.__class__.__name__.lower() + + self.near_ood_names = [ds.dataset_name for ds in self.near_oods] + self.far_ood_names = [ds.dataset_name for ds in self.far_oods] + if self.eval_shift: self.shift = self.shift_dataset( self.root, download=False, transform=self.test_transform, ) + if stage not in ["fit", "test", None]: raise ValueError(f"Stage {stage} is not supported.") @@ -255,20 +293,19 @@ def train_dataloader(self) -> DataLoader: ) return self._data_loader(self.train, training=True, shuffle=True) - def test_dataloader(self) -> list[DataLoader]: - r"""Get test dataloaders. - - Return: - list[DataLoader]: test set for in distribution data, SVHN data, and/or CIFAR-10C data. - """ - dataloader = [self._data_loader(self.get_test_set(), training=False, shuffle=False)] + def test_dataloader(self): + loaders = [self._data_loader(self.get_test_set(), training=False)] if self.eval_ood: - dataloader.append(self._data_loader(self.get_ood_set(), training=False, shuffle=False)) + loaders.append(self._data_loader(self.get_test_ood_set(), training=False)) + + loaders.append(self._data_loader(self.get_val_ood_set(), training=False)) + + loaders.extend(self._data_loader(ds, training=False) for ds in self.get_near_ood_set()) + + loaders.extend(self._data_loader(ds, training=False) for ds in self.get_far_ood_set()) if self.eval_shift: - dataloader.append( - self._data_loader(self.get_shift_set(), training=False, shuffle=False) - ) - return dataloader + loaders.append(self._data_loader(self.get_shift_set(), training=False)) + return loaders def _get_train_data(self) -> ArrayLike: if self.val_split: @@ -279,3 +316,30 @@ def _get_train_targets(self) -> ArrayLike: if self.val_split: return np.array(self.train.dataset.targets)[self.train.indices] return np.array(self.train.targets) + + def get_indices(self): + idx = 0 + indices = {} + indices["test"] = [idx] + idx += 1 + if self.eval_ood: + indices["test_ood"] = [idx] + idx += 1 + indices["val_ood"] = [idx] + idx += 1 + n_near = len(self.near_oods) + indices["near_oods"] = list(range(idx, idx + n_near)) + idx += n_near + n_far = len(self.far_oods) + indices["far_oods"] = list(range(idx, idx + n_far)) + idx += n_far + else: + indices["test_ood"] = [] + indices["val_ood"] = [] + indices["near_oods"] = [] + indices["far_oods"] = [] + if self.eval_shift: + indices["shift"] = [idx] + else: + indices["shift"] = [] + return indices diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index 58a47dbe..44a20663 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -1,3 +1,4 @@ +import logging from pathlib import Path from typing import Literal @@ -6,16 +7,23 @@ from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn -from torch.utils.data import DataLoader -from torchvision.datasets import CIFAR100, SVHN +from torch.utils.data import DataLoader, Dataset +from torchvision.datasets import CIFAR100 from torchvision.transforms import v2 from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets import AggregatedDataset from torch_uncertainty.datasets.classification import CIFAR100C +from torch_uncertainty.datasets.ood.utils import get_ood_datasets from torch_uncertainty.datasets.utils import create_train_val_split from torch_uncertainty.transforms import Cutout +logging.basicConfig( + level=logging.INFO, + format="%(message)s", +) +logging.getLogger("faiss").setLevel(logging.WARNING) + class CIFAR100DataModule(TUDataModule): num_classes = 100 @@ -46,13 +54,24 @@ def __init__( num_workers: int = 1, pin_memory: bool = True, persistent_workers: bool = True, + near_ood_datasets: list | None = None, + far_ood_datasets: list | None = None, ) -> None: """DataModule for CIFAR100. Args: root (str): Root directory of the datasets. + <<<<<<< HEAD + eval_ood (bool): Whether to evaluate out-of-distribution + performance. + near_ood_datasets (list, optional): list of near OOD dataset classes must be subclass of torch.utils.data.Dataset. Defaults to CIFAR-10, Tiny ImageNet (OpenOOD splits) + far_ood_datasets (list, optional): list of far OOD dataset classes must be subclass of torch.utils.data.Dataset. Defaults to MNIST, SVHN, Textures, Places365 (OpenOOD splits) + eval_shift (bool): Whether to evaluate on shifted data. Defaults to + ``False``. + ======= eval_ood (bool): Whether to evaluate out-of-distribution performance. eval_shift (bool): Whether to evaluate on shifted data. Defaults to ``False``. + >>>>>>> origin/dev batch_size (int): Number of samples per batch during training. eval_batch_size (int | None) : Number of samples per batch during evaluation (val and test). Set to batch_size if ``None``. Defaults to ``None``. @@ -95,11 +114,12 @@ def __init__( self.num_dataloaders = num_dataloaders self.dataset = CIFAR100 - self.ood_dataset = SVHN self.shift_dataset = CIFAR100C self.shift_severity = shift_severity + self.near_ood_datasets = near_ood_datasets or [] # List of near OOD dataset classes + self.far_ood_datasets = far_ood_datasets or [] # List of far OOD dataset classes if train_transform is not None: self.train_transform = train_transform else: @@ -145,13 +165,15 @@ def __init__( ) if num_tta != 1: - self.test_transform = train_transform + self.test_transform = self.train_transform elif test_transform is not None: self.test_transform = test_transform else: self.test_transform = v2.Compose( [ v2.ToImage(), + v2.Resize(32), + v2.CenterCrop(32), v2.ToDtype(dtype=torch.float32, scale=True), v2.Normalize(mean=self.mean, std=self.std), ] @@ -161,13 +183,6 @@ def prepare_data(self) -> None: # coverage: ignore self.dataset(self.root, train=True, download=True) self.dataset(self.root, train=False, download=True) - if self.eval_ood: - self.ood_dataset( - self.root, - split="test", - download=True, - transform=self.test_transform, - ) if self.eval_shift: self.shift_dataset( self.root, @@ -206,12 +221,33 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: transform=self.test_transform, ) if self.eval_ood: - self.ood = self.ood_dataset( - self.root, - split="test", - download=False, + self.test_ood, self.val_ood, near_default, far_default = get_ood_datasets( + root=self.root, + dataset_id="CIFAR100", transform=self.test_transform, ) + + if self.near_ood_datasets: + if not all(isinstance(ds, Dataset) for ds in self.near_ood_datasets): + raise TypeError("All entries in near_ood_datasets must be Dataset objects") + self.near_oods = self.near_ood_datasets + else: + self.near_oods = list(near_default.values()) + + if self.far_ood_datasets: + if not all(isinstance(ds, Dataset) for ds in self.far_ood_datasets): + raise TypeError("All entries in far_ood_datasets must be Dataset objects") + self.far_oods = self.far_ood_datasets + else: + self.far_oods = list(far_default.values()) + + for ds in [self.val_ood, *self.near_oods, *self.far_oods]: + if not hasattr(ds, "dataset_name"): + ds.dataset_name = ds.__class__.__name__.lower() + + self.near_ood_names = [ds.dataset_name for ds in self.near_oods] + self.far_ood_names = [ds.dataset_name for ds in self.far_oods] + if self.eval_shift: self.shift = self.shift_dataset( self.root, @@ -223,33 +259,19 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: raise ValueError(f"Stage {stage} is not supported.") def train_dataloader(self) -> DataLoader: - """Get the training dataloader for CIFAR100. + r"""Get the training dataloader for CIFAR10. Return: DataLoader: CIFAR100 training dataloader. """ if self.num_dataloaders > 1: return self._data_loader( - AggregatedDataset(self.train, self.num_dataloaders), shuffle=True, training=True + AggregatedDataset(self.train, self.num_dataloaders), + shuffle=True, + training=True, ) return self._data_loader(self.train, training=True, shuffle=True) - def test_dataloader(self) -> list[DataLoader]: - r"""Get test dataloaders. - - Return: - list[DataLoader]: test set for in distribution data, SVHN data, and/or - CIFAR-100C data. - """ - dataloader = [self._data_loader(self.get_test_set(), training=False, shuffle=False)] - if self.eval_ood: - dataloader.append(self._data_loader(self.get_ood_set(), training=False, shuffle=False)) - if self.eval_shift: - dataloader.append( - self._data_loader(self.get_shift_set(), training=False, shuffle=False) - ) - return dataloader - def _get_train_data(self) -> ArrayLike: if self.val_split: return self.train.dataset.data[self.train.indices] @@ -259,3 +281,44 @@ def _get_train_targets(self) -> ArrayLike: if self.val_split: return np.array(self.train.dataset.targets)[self.train.indices] return np.array(self.train.targets) + + def test_dataloader(self): + loaders = [self._data_loader(self.get_test_set(), training=False)] + if self.eval_ood: + loaders.append(self._data_loader(self.get_test_ood_set(), training=False)) + + loaders.append(self._data_loader(self.get_val_ood_set(), training=False)) + + loaders.extend(self._data_loader(ds, training=False) for ds in self.get_near_ood_set()) + + loaders.extend(self._data_loader(ds, training=False) for ds in self.get_far_ood_set()) + if self.eval_shift: + loaders.append(self._data_loader(self.get_shift_set(), training=False)) + return loaders + + def get_indices(self): + idx = 0 + indices = {} + indices["test"] = [idx] + idx += 1 + if self.eval_ood: + indices["test_ood"] = [idx] + idx += 1 + indices["val_ood"] = [idx] + idx += 1 + n_near = len(self.near_oods) + indices["near_oods"] = list(range(idx, idx + n_near)) + idx += n_near + n_far = len(self.far_oods) + indices["far_oods"] = list(range(idx, idx + n_far)) + idx += n_far + else: + indices["test_ood"] = [] + indices["val_ood"] = [] + indices["near_oods"] = [] + indices["far_oods"] = [] + if self.eval_shift: + indices["shift"] = [idx] + else: + indices["shift"] = [] + return indices diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index b92dae23..7ba44e33 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -1,4 +1,4 @@ -import copy +import logging from pathlib import Path from typing import Literal @@ -7,8 +7,8 @@ from timm.data.auto_augment import rand_augment_transform from timm.data.mixup import Mixup from torch import nn -from torch.utils.data import DataLoader, Subset -from torchvision.datasets import DTD, SVHN, ImageNet, INaturalist +from torch.utils.data import DataLoader, Dataset +from torchvision.datasets import ImageFolder from torchvision.transforms import v2 from torch_uncertainty.datamodules import TUDataModule @@ -17,25 +17,28 @@ ImageNetC, ImageNetO, ImageNetR, - OpenImageO, ) -from torch_uncertainty.datasets.utils import create_train_val_split +from torch_uncertainty.datasets.ood.utils import ( + FileListDataset, + download_and_extract_hf_dataset, + download_and_extract_splits_from_hf, + get_ood_datasets, +) from torch_uncertainty.utils import ( interpolation_modes_from_str, ) +logging.basicConfig( + level=logging.INFO, + format="%(message)s", +) +logging.getLogger("faiss").setLevel(logging.WARNING) + class ImageNetDataModule(TUDataModule): num_classes = 1000 num_channels = 3 test_datasets = ["r", "o", "a"] - ood_datasets = [ - "inaturalist", - "imagenet-o", - "svhn", - "textures", - "openimage-o", - ] training_task = "classification" mean = (0.485, 0.456, 0.406) std = (0.229, 0.224, 0.225) @@ -55,7 +58,6 @@ def __init__( postprocess_set: Literal["val", "test"] = "val", train_transform: nn.Module | None = None, test_transform: nn.Module | None = None, - ood_ds: str = "openimage-o", test_alt: str | None = None, procedure: str | None = None, train_size: int = 224, @@ -65,6 +67,8 @@ def __init__( num_workers: int = 1, pin_memory: bool = True, persistent_workers: bool = True, + near_ood_datasets: list | None = None, + far_ood_datasets: list | None = None, ) -> None: """DataModule for the ImageNet dataset. @@ -77,6 +81,8 @@ def __init__( eval_batch_size (int | None) : Number of samples per batch during evaluation (val and test). Set to batch_size if ``None``. Defaults to ``None``. eval_ood (bool): Whether to evaluate out-of-distribution performance. Defaults to ``False``. + near_ood_datasets (list, optional): list of near OOD dataset classes must be subclass of torch.utils.data.Dataset. Defaults to SSB-hard, NINCO (OpenOOD splits) + far_ood_datasets (list, optional): list of far OOD dataset classes must be subclass of torch.utils.data.Dataset. Defaults to iNaturalist, Textures, OpenImage-O (OpenOOD splits) eval_shift (bool): Whether to evaluate on shifted data. Defaults to ``False``. num_tta (int): Number of test-time augmentations (TTA). Defaults to ``1`` (no TTA). shift_severity (int): Severity of the shift. Defaults to ``1``. @@ -120,18 +126,21 @@ def __init__( ) self.eval_ood = eval_ood + self.num_tta = num_tta self.eval_shift = eval_shift self.shift_severity = shift_severity if val_split and not isinstance(val_split, float): val_split = Path(val_split) self.train_indices, self.val_indices = read_indices(val_split) self.val_split = val_split - self.ood_ds = ood_ds self.test_alt = test_alt self.interpolation = interpolation_modes_from_str(interpolation) + if self.test_alt is not None and eval_ood: + raise ValueError("For now test_alt argument is not supported when ood_eval=True.") + if test_alt is None: - self.dataset = ImageNet + self.dataset = None elif test_alt == "r": self.dataset = ImageNetR elif test_alt == "o": @@ -141,18 +150,9 @@ def __init__( else: raise ValueError(f"The alternative {test_alt} is not known.") - if ood_ds == "inaturalist": - self.ood_dataset = INaturalist - elif ood_ds == "imagenet-o": - self.ood_dataset = ImageNetO - elif ood_ds == "svhn": - self.ood_dataset = SVHN - elif ood_ds == "textures": - self.ood_dataset = DTD - elif ood_ds == "openimage-o": - self.ood_dataset = OpenImageO - else: - raise ValueError(f"The dataset {ood_ds} is not supported.") + self.near_ood_datasets = near_ood_datasets or [] + self.far_ood_datasets = far_ood_datasets or [] + self.shift_dataset = ImageNetC self.procedure = procedure @@ -213,8 +213,8 @@ def __init__( ] ) - if num_tta != 1: - self.test_transform = train_transform + if self.num_tta != 1: + self.test_transform = self.train_transform elif test_transform is not None: self.test_transform = test_transform else: @@ -229,122 +229,158 @@ def __init__( ) def _verify_splits(self, split: str) -> None: - if split not in list(self.root.iterdir()): - raise FileNotFoundError( - f"a {split} Imagenet split was not found in {self.root}," - f" make sure the folder contains a subfolder named {split}" - ) + split_dir = self.root / split + if not split_dir.is_dir(): + raise FileNotFoundError(f"a {split} Imagenet split was not found in {split_dir}") def prepare_data(self) -> None: # coverage: ignore if self.test_alt is not None: - self.data = self.dataset( - self.root, + self.test = self.dataset( + root=self.root, split="val", download=True, ) - if self.eval_ood: - if self.ood_ds == "inaturalist": - self.ood = self.ood_dataset( - self.root, - version="2021_valid", - download=True, - transform=self.test_transform, - ) - elif self.ood_ds != "textures": - self.ood = self.ood_dataset( - self.root, - split="test", - download=True, - transform=self.test_transform, - ) - else: - self.ood = self.ood_dataset( - self.root, - split="train", - download=True, - transform=self.test_transform, - ) if self.eval_shift: self.shift_dataset( - self.root, + root=self.root, download=True, transform=self.test_transform, shift_severity=self.shift_severity, ) def setup(self, stage: Literal["fit", "test"] | None = None) -> None: - if stage == "fit" or stage is None: + if stage not in (None, "fit", "test"): + raise ValueError(f"Stage {stage} is not supported.") + splits_base = download_and_extract_splits_from_hf(root=self.root) + + if stage == "fit": if self.test_alt is not None: raise ValueError("The test_alt argument is not supported for training.") - full = self.dataset( - self.root, - split="train", - transform=self.train_transform, + + # To change for more flexible splits later + self.data_dir = download_and_extract_hf_dataset("imagenet1k", self.root) + imagenet1k_splits = splits_base / "imagenet1k" + val_txt = imagenet1k_splits / "val_imagenet.txt" + self.val = FileListDataset( + root=self.data_dir, + list_file=val_txt, + transform=self.test_transform, ) - if self.val_split and isinstance(self.val_split, float): - self.train, self.val = create_train_val_split( - full, - self.val_split, - self.test_transform, + self.train = None + + if stage == "test": + if self.test_alt is not None: + self.test = self.dataset( + root=self.root, + split="val", + transform=self.test_transform, + download=False, ) - elif isinstance(self.val_split, Path): - self.train = Subset(full, self.train_indices) - # TODO: improve the performance - self.val = copy.deepcopy(Subset(full, self.val_indices)) - self.val.dataset.transform = self.test_transform else: - self.train = full - self.val = self.dataset( - self.root, - split="val", + self.data_dir = getattr( + self, "data_dir", download_and_extract_hf_dataset("imagenet1k", self.root) + ) + imagenet1k_splits = splits_base / "imagenet1k" + test_txt = imagenet1k_splits / "test_imagenet.txt" + + self.test = FileListDataset( + root=self.data_dir, + list_file=test_txt, transform=self.test_transform, ) - if stage == "test" or stage is None: - self.test = self.dataset( - self.root, - split="val", - transform=self.test_transform, - ) - if stage not in ["fit", "test", None]: - raise ValueError(f"Stage {stage} is not supported.") - if self.eval_ood: - if self.ood_ds == "inaturalist": - self.ood = self.ood_dataset( - self.root, - version="2021_valid", + if self.eval_ood: + self.test_ood, self.val_ood, near_default, far_default = get_ood_datasets( + root=self.root, + dataset_id="imagenet1k", transform=self.test_transform, ) - else: - self.ood = self.ood_dataset( - self.root, + + if self.near_ood_datasets: + if not all(isinstance(ds, Dataset) for ds in self.near_ood_datasets): + raise TypeError("All entries in near_ood_datasets must be Dataset objects") + self.near_oods = self.near_ood_datasets + else: + self.near_oods = list(near_default.values()) + + if self.far_ood_datasets: + if not all(isinstance(ds, Dataset) for ds in self.far_ood_datasets): + raise TypeError("All entries in far_ood_datasets must be Dataset objects") + self.far_oods = self.far_ood_datasets + else: + self.far_oods = list(far_default.values()) + + for ds in [self.val_ood, *self.near_oods, *self.far_oods]: + if not hasattr(ds, "dataset_name"): + ds.dataset_name = ds.__class__.__name__.lower() + + self.near_ood_names = [ds.dataset_name for ds in self.near_oods] + self.far_ood_names = [ds.dataset_name for ds in self.far_oods] + + if self.eval_shift: + self.shift = self.shift_dataset( + root=self.root, + download=False, transform=self.test_transform, - download=True, + shift_severity=self.shift_severity, ) - if self.eval_shift: - self.shift = self.shift_dataset( - self.root, - download=False, - transform=self.test_transform, - shift_severity=self.shift_severity, - ) + def train_dataloader(self) -> DataLoader: + # look for a train/ folder under the HF extraction root + train_dir = Path(self.data_dir) / "train" + if train_dir.is_dir(): + ds_train = ImageFolder(train_dir, transform=self.train_transform) + return self._data_loader(ds_train, training=True, shuffle=True) + raise RuntimeError( + "ImageNet training data not found under:\n" + f" {train_dir}\n" + "Please download the ILSVRC2012 train split manually from\n" + "https://www.image-net.org/download/ and unpack it under that folder." + ) - def test_dataloader(self) -> list[DataLoader]: - """Get the test dataloaders for ImageNet. + def val_dataloader(self) -> DataLoader: + return self._data_loader(self.val, training=False) - Return: - list[DataLoader]: ImageNet test set (in distribution data), OOD dataset test split - (out-of-distribution data), and/or ImageNetC data. - """ - dataloader = [self._data_loader(self.get_test_set(), training=False, shuffle=False)] + def test_dataloader(self): + loaders = [self._data_loader(self.get_test_set(), training=False)] + if self.eval_ood: + loaders.append(self._data_loader(self.get_test_ood_set(), training=False)) + + loaders.append(self._data_loader(self.get_val_ood_set(), training=False)) + + loaders.extend(self._data_loader(ds, training=False) for ds in self.get_near_ood_set()) + + loaders.extend(self._data_loader(ds, training=False) for ds in self.get_far_ood_set()) + if self.eval_shift: + loaders.append(self._data_loader(self.get_shift_set(), training=False)) + return loaders + + def get_indices(self): + idx = 0 + indices = {} + indices["test"] = [idx] + idx += 1 if self.eval_ood: - dataloader.append(self._data_loader(self.get_ood_set(), training=False, shuffle=False)) + indices["test_ood"] = [idx] + idx += 1 + indices["val_ood"] = [idx] + idx += 1 + n_near = len(self.near_oods) + indices["near_oods"] = list(range(idx, idx + n_near)) + idx += n_near + n_far = len(self.far_oods) + indices["far_oods"] = list(range(idx, idx + n_far)) + idx += n_far + else: + indices["test_ood"] = [] + indices["val_ood"] = [] + indices["near_oods"] = [] + indices["far_oods"] = [] if self.eval_shift: - dataloader.append( - self._data_loader(self.get_shift_set(), training=False, shuffle=False) - ) - return dataloader + indices["shift"] = [idx] + else: + indices["shift"] = [] + return indices def read_indices(path: Path) -> list[str]: # coverage: ignore diff --git a/torch_uncertainty/datamodules/classification/imagenet200.py b/torch_uncertainty/datamodules/classification/imagenet200.py new file mode 100644 index 00000000..8e0b748c --- /dev/null +++ b/torch_uncertainty/datamodules/classification/imagenet200.py @@ -0,0 +1,368 @@ +import logging +from pathlib import Path +from typing import Literal + +import torch +import yaml +from timm.data.auto_augment import rand_augment_transform +from timm.data.mixup import Mixup +from torch import nn +from torch.utils.data import DataLoader, Dataset +from torchvision.datasets import ImageFolder +from torchvision.transforms import v2 + +from torch_uncertainty.datamodules import TUDataModule +from torch_uncertainty.datasets.classification import ( + ImageNetA, + ImageNetC, + ImageNetO, + ImageNetR, +) +from torch_uncertainty.datasets.ood.utils import ( + FileListDataset, + download_and_extract_hf_dataset, + download_and_extract_splits_from_hf, + get_ood_datasets, +) +from torch_uncertainty.utils import ( + interpolation_modes_from_str, +) + +logging.basicConfig( + level=logging.INFO, + format="%(message)s", +) +logging.getLogger("faiss").setLevel(logging.WARNING) + + +class ImageNet200DataModule(TUDataModule): + num_classes = 200 + num_channels = 3 + test_datasets = ["r", "o", "a"] + training_task = "classification" + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + train_indices = None + val_indices = None + + def __init__( + self, + root: str | Path, + batch_size: int, + eval_batch_size: int | None = None, + eval_ood: bool = False, + eval_shift: bool = False, + shift_severity: int = 1, + val_split: float | Path | None = None, + postprocess_set: Literal["val", "test"] = "val", + test_alt: str | None = None, + procedure: str | None = None, + train_size: int = 224, + interpolation: str = "bilinear", + basic_augment: bool = True, + rand_augment_opt: str | None = None, + num_workers: int = 1, + pin_memory: bool = True, + persistent_workers: bool = True, + near_ood_datasets: list | None = None, + far_ood_datasets: list | None = None, + ) -> None: + """DataModule for the ImageNet200 dataset. + + This datamodule uses ImageNet200 as In-distribution dataset, OpenImage-O, INaturalist, + ImageNet-0, SVHN or DTD as Out-of-distribution dataset and ImageNet-C as shifted dataset. + + Args: + root (str): Root directory of the datasets. + batch_size (int): Number of samples per batch during training. + eval_batch_size (int | None) : Number of samples per batch during evaluation (val + and test). Set to batch_size if None. Defaults to None. + eval_ood (bool): Whether to evaluate out-of-distribution performance. Defaults to ``False``. + near_ood_datasets (list, optional): list of near OOD dataset classes must be subclass of torch.utils.data.Dataset. Defaults to SSB-hard, NINCO (OpenOOD splits) + far_ood_datasets (list, optional): list of far OOD dataset classes must be subclass of torch.utils.data.Dataset. Defaults to iNaturalist, Textures, OpenImage-O (OpenOOD splits) + eval_shift (bool): Whether to evaluate on shifted data. Defaults to ``False``. + shift_severity (int): Severity of the shift. Defaults to ``1``. + val_split (float or Path): Share of samples to use for validation + or path to a yaml file containing a list of validation images + ids. Defaults to ``0.0``. + postprocess_set (str, optional): The post-hoc calibration dataset to + use for the post-processing method. Defaults to ``val``. + ood_ds (str): Which out-of-distribution dataset to use. Defaults to + ``"openimage-o"``. + test_alt (str): Which test set to use. Defaults to ``None``. + procedure (str): Which procedure to use. Defaults to ``None``. + train_size (int): Size of training images. Defaults to ``224``. + interpolation (str): Interpolation method for the Resize Crops. + Defaults to ``"bilinear"``. + basic_augment (bool): Whether to apply base augmentations. Defaults to + ``True``. + rand_augment_opt (str): Which RandAugment to use. Defaults to ``None``. + num_workers (int): Number of workers to use for data loading. Defaults + to ``1``. + pin_memory (bool): Whether to pin memory. Defaults to ``True``. + persistent_workers (bool): Whether to use persistent workers. Defaults + to ``True``. + """ + super().__init__( + root=Path(root), + batch_size=batch_size, + eval_batch_size=eval_batch_size, + val_split=val_split, + postprocess_set=postprocess_set, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + + self.eval_ood = eval_ood + self.eval_shift = eval_shift + self.shift_severity = shift_severity + if val_split and not isinstance(val_split, float): + val_split = Path(val_split) + self.train_indices, self.val_indices = read_indices(val_split) + self.val_split = val_split + self.test_alt = test_alt + self.interpolation = interpolation_modes_from_str(interpolation) + + if self.test_alt is not None and eval_ood: + raise ValueError("For now test_alt argument is not supported when ood_eval=True.") + + if test_alt is None: + self.dataset = None + elif test_alt == "r": + self.dataset = ImageNetR + elif test_alt == "o": + self.dataset = ImageNetO + elif test_alt == "a": + self.dataset = ImageNetA + else: + raise ValueError(f"The alternative {test_alt} is not known.") + + self.near_ood_datasets = near_ood_datasets or [] + self.far_ood_datasets = far_ood_datasets or [] + + self.shift_dataset = ImageNetC + + self.procedure = procedure + + if basic_augment: + basic_transform = v2.Compose( + [ + v2.RandomResizedCrop(train_size, interpolation=self.interpolation), + v2.RandomHorizontalFlip(), + ] + ) + else: + basic_transform = nn.Identity() + + if self.procedure is None: + if rand_augment_opt is not None: + main_transform = rand_augment_transform(rand_augment_opt, {}) + else: + main_transform = nn.Identity() + elif self.procedure == "ViT": + train_size = 224 + main_transform = v2.Compose( + [ + Mixup(mixup_alpha=0.2, cutmix_alpha=1.0), + rand_augment_transform("rand-m9-n2-mstd0.5", {}), + ] + ) + elif self.procedure == "A3": + train_size = 160 + main_transform = rand_augment_transform("rand-m6-mstd0.5-inc1", {}) + else: + raise ValueError("The procedure is unknown") + + self.train_transform = v2.Compose( + [ + v2.ToImage(), + basic_transform, + main_transform, + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), + ] + ) + + self.test_transform = v2.Compose( + [ + v2.ToImage(), + v2.Resize(256, interpolation=self.interpolation), + v2.CenterCrop(224), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), + ] + ) + + def _verify_splits(self, split: str) -> None: + if split not in list(self.root.iterdir()): + raise FileNotFoundError( + f"a {split} Imagenet split was not found in {self.root}," + f" make sure the folder contains a subfolder named {split}" + ) + + def prepare_data(self) -> None: # coverage: ignore + if self.test_alt is not None: + self.test = self.dataset( + self.root, + split="val", + download=True, + ) + if self.eval_shift: + self.shift_dataset( + self.root, + download=True, + transform=self.test_transform, + shift_severity=self.shift_severity, + ) + + def setup(self, stage: Literal["fit", "test"] | None = None) -> None: + if stage not in (None, "fit", "test"): + raise ValueError(f"Stage {stage} is not supported.") + splits_base = download_and_extract_splits_from_hf(root=self.root) + if stage == "fit": + if self.test_alt is not None: + raise ValueError("The test_alt argument is not supported for training.") + + # To change for more flexible splits later + self.data_dir = download_and_extract_hf_dataset("imagenet1k", self.root) + imagenet1k_splits = splits_base / "imagenet200" + val_txt = imagenet1k_splits / "val_imagenet200.txt" + self.val = FileListDataset( + root=self.data_dir, + list_file=val_txt, + transform=self.test_transform, + ) + self.train = None + + if stage == "test": + if self.test_alt is not None: + self.test = self.dataset( + self.root, + split="val", + transform=self.test_transform, + download=False, + ) + else: + self.data_dir = getattr( + self, "data_dir", download_and_extract_hf_dataset("imagenet1k", self.root) + ) + imagenet1k_splits = splits_base / "imagenet200" + test_txt = imagenet1k_splits / "test_imagenet200.txt" + self.test = FileListDataset( + root=self.data_dir, + list_file=test_txt, + transform=self.test_transform, + ) + + if self.eval_ood: + self.test_ood, self.val_ood, near_default, far_default = get_ood_datasets( + root=self.root, + dataset_id="imagenet200", + transform=self.test_transform, + ) + + if self.near_ood_datasets: + if not all(isinstance(ds, Dataset) for ds in self.near_ood_datasets): + raise TypeError("All entries in near_ood_datasets must be Dataset objects") + self.near_oods = self.near_ood_datasets + else: + self.near_oods = list(near_default.values()) + + if self.far_ood_datasets: + if not all(isinstance(ds, Dataset) for ds in self.far_ood_datasets): + raise TypeError("All entries in far_ood_datasets must be Dataset objects") + self.far_oods = self.far_ood_datasets + else: + self.far_oods = list(far_default.values()) + + for ds in [self.val_ood, *self.near_oods, *self.far_oods]: + if not hasattr(ds, "dataset_name"): + ds.dataset_name = ds.__class__.__name__.lower() + + self.near_ood_names = [ds.dataset_name for ds in self.near_oods] + self.far_ood_names = [ds.dataset_name for ds in self.far_oods] + + if stage not in ["fit", "test", None]: + raise ValueError(f"Stage {stage} is not supported.") + + if self.eval_shift: + self.shift = self.shift_dataset( + self.root, + download=False, + transform=self.test_transform, + shift_severity=self.shift_severity, + ) + + def train_dataloader(self) -> DataLoader: + # look for a train/ folder under the HF extraction root + train_dir = Path(self.data_dir) / "train" + if train_dir.is_dir(): + ds_train = ImageFolder(train_dir, transform=self.train_transform) + return self._data_loader(ds_train, training=True, shuffle=True) + raise RuntimeError( + "ImageNet training data not found under:\n" + f" {train_dir}\n" + "Please download the ILSVRC2012 train split manually from\n" + "https://www.image-net.org/download/ and unpack it under that folder." + ) + + def val_dataloader(self) -> DataLoader: + return self._data_loader(self.val, training=False) + + def test_dataloader(self): + loaders = [self._data_loader(self.get_test_set(), training=False)] + if self.eval_ood: + loaders.append(self._data_loader(self.get_test_ood_set(), training=False)) + + loaders.append(self._data_loader(self.get_val_ood_set(), training=False)) + + loaders.extend(self._data_loader(ds, training=False) for ds in self.get_near_ood_set()) + + loaders.extend(self._data_loader(ds, training=False) for ds in self.get_far_ood_set()) + if self.eval_shift: + loaders.append(self._data_loader(self.get_shift_set(), training=False)) + return loaders + + def get_indices(self): + idx = 0 + indices = {} + indices["test"] = [idx] + idx += 1 + if self.eval_ood: + indices["test_ood"] = [idx] + idx += 1 + indices["val_ood"] = [idx] + idx += 1 + n_near = len(self.near_oods) + indices["near_oods"] = list(range(idx, idx + n_near)) + idx += n_near + n_far = len(self.far_oods) + indices["far_oods"] = list(range(idx, idx + n_far)) + idx += n_far + else: + indices["test_ood"] = [] + indices["val_ood"] = [] + indices["near_oods"] = [] + indices["far_oods"] = [] + if self.eval_shift: + indices["shift"] = [idx] + else: + indices["shift"] = [] + return indices + + +def read_indices(path: Path) -> list[str]: # coverage: ignore + """Read a file and return its lines as a list. + + Args: + path (Path): Path to the file. + + Returns: + list[str]: list of filenames. + """ + if not path.is_file(): + raise ValueError(f"{path} is not a file.") + with path.open("r") as f: + indices = yaml.safe_load(f) + return indices["train"], indices["val"] diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index 748ab171..29f54e28 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -118,7 +118,7 @@ def __init__( ) if num_tta != 1: - self.test_transform = train_transform + self.test_transform = self.train_transform elif test_transform is not None: self.test_transform = test_transform else: diff --git a/torch_uncertainty/datamodules/classification/sst2.py b/torch_uncertainty/datamodules/classification/sst2.py new file mode 100644 index 00000000..33d514b7 --- /dev/null +++ b/torch_uncertainty/datamodules/classification/sst2.py @@ -0,0 +1,219 @@ +import os + +import torch + +os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") + +from datasets import DownloadConfig, load_dataset +from torch.utils.data import Dataset +from transformers import AutoTokenizer + +from torch_uncertainty.datamodules.abstract import TUDataModule + + +class HFTupleDataset(Dataset): + def __init__(self, hf_ds, name=None): + """Initialize the dataset wrapper for HuggingFace datasets.""" + self.ds = hf_ds + self.dataset_name = (name or "dataset").lower().replace(" ", "_") + + def __len__(self): + """Return the length of the dataset.""" + return len(self.ds) + + def __getitem__(self, idx): + """Get an item from the dataset.""" + item = self.ds[idx] + x = {"input_ids": item["input_ids"], "attention_mask": item["attention_mask"]} + y = item.get("label", 0) + return x, torch.tensor(int(y), dtype=torch.long) + + +class Sst2DataModule(TUDataModule): + num_classes = 2 + training_task = "classification" + num_channels = 1 + input_shape = None # text + + def __init__( + self, + model_name="bert-base-uncased", + max_len=128, + batch_size=32, + eval_batch_size=None, + num_tta=1, + num_workers=4, + pin_memory=True, + persistent_workers=True, + local_files_only=False, + eval_ood: bool = True, + ): + """Initialize the SST-2 data module.""" + super().__init__( + root=".", + batch_size=batch_size, + eval_batch_size=eval_batch_size, + val_split=None, + num_tta=num_tta, + postprocess_set="val", + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + self.model_name = model_name + self.max_len = max_len + self.local_files_only = local_files_only + self.eval_ood = eval_ood + + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, use_fast=True, local_files_only=local_files_only + ) + + self.near_oods, self.far_oods = [], [] + self.test_ood, self.val_ood = None, None + + def _wrap_tok_set(self, tokds, name): + keep = [c for c in ["input_ids", "attention_mask", "label"] if c in tokds.column_names] + tokds.set_format(type="torch", columns=keep) + return HFTupleDataset(tokds, name=name) + + def _tok_single(self, raw, field, name): + def tok(b): + return self.tokenizer( + b[field], max_length=self.max_len, truncation=True, padding="max_length" + ) + + cols_rm = [c for c in raw.column_names if c not in ("label", field)] + tokds = raw.map(tok, batched=True, remove_columns=cols_rm) + return self._wrap_tok_set(tokds, name) + + def _tok_pair(self, raw, field1, field2, name): + def tok(b): + return self.tokenizer( + b[field1], b[field2], max_length=self.max_len, truncation=True, padding="max_length" + ) + + cols_rm = [c for c in raw.column_names if c not in ("label", field1, field2)] + tokds = raw.map(tok, batched=True, remove_columns=cols_rm) + return self._wrap_tok_set(tokds, name) + + def _tok_wmt16_en(self, raw, name): + def tok(b): + en = [t.get("en", "") for t in b["translation"]] + return self.tokenizer( + en, max_length=self.max_len, truncation=True, padding="max_length" + ) + + cols_rm = [c for c in raw.column_names if c not in ("label", "translation")] + tokds = raw.map(tok, batched=True, remove_columns=cols_rm) + return self._wrap_tok_set(tokds, name) + + def prepare_data(self): + dl = DownloadConfig(local_files_only=self.local_files_only) + # Cache SST-2 + load_dataset("glue", "sst2", download_config=dl) + + if not self.eval_ood: + return + + load_dataset("yelp_polarity", split="test", download_config=dl) + load_dataset("amazon_polarity", split="test", download_config=dl) + + load_dataset("ag_news", split="test", download_config=dl) + load_dataset("SetFit/20_newsgroups", split="test", download_config=dl) + load_dataset("SetFit/TREC-QC", split="test", download_config=dl) + load_dataset("glue", "mnli", split="validation_mismatched", download_config=dl) + load_dataset("glue", "rte", split="validation", download_config=dl) + load_dataset("wmt16", "ro-en", split="test", download_config=dl) + + def setup(self, stage=None): + dl = DownloadConfig(local_files_only=self.local_files_only) + + ds = load_dataset("glue", "sst2", download_config=dl) + + def tok_id(b): + return self.tokenizer( + b["sentence"], max_length=self.max_len, truncation=True, padding="max_length" + ) + + cols_rm = [c for c in ["sentence", "idx"] if c in ds["train"].column_names] + tokds = ds.map(tok_id, batched=True, remove_columns=cols_rm) + tokds.set_format( + type="torch", + columns=[ + c + for c in ["input_ids", "attention_mask", "label"] + if c in tokds["train"].column_names + ], + ) + + if stage in (None, "fit"): + self.train = HFTupleDataset(tokds["train"], name="sst2_train") + self.val = HFTupleDataset(tokds["validation"], name="sst2_val") + + if stage in (None, "test"): + self.test = HFTupleDataset(tokds["validation"], name="sst2_test") + self.near_oods, self.far_oods = [], [] + + if self.eval_ood: + # -------- Near OOD -------- + yelp_raw = load_dataset("yelp_polarity", split="test", download_config=dl) + amazon_raw = load_dataset("amazon_polarity", split="test", download_config=dl) + + self.near_oods.append(self._tok_single(yelp_raw, "text", "yelp_polarity")) + self.near_oods.append(self._tok_single(amazon_raw, "content", "amazon_polarity")) + + # -------- Far OOD -------- + ag_raw = load_dataset("ag_news", split="test", download_config=dl) + n20_raw = load_dataset("SetFit/20_newsgroups", split="test", download_config=dl) + trec_raw = load_dataset("SetFit/TREC-QC", split="test", download_config=dl) + mnli_raw = load_dataset( + "glue", "mnli", split="validation_mismatched", download_config=dl + ) + rte_raw = load_dataset("glue", "rte", split="validation", download_config=dl) + wmt_raw = load_dataset("wmt16", "ro-en", split="test", download_config=dl) + + self.far_oods.append(self._tok_single(ag_raw, "text", "ag_news")) + self.far_oods.append(self._tok_single(n20_raw, "text", "20_newsgroups")) + self.far_oods.append(self._tok_single(trec_raw, "text", "trec_qc")) + self.far_oods.append(self._tok_pair(mnli_raw, "premise", "hypothesis", "mnli_mm")) + self.far_oods.append(self._tok_pair(rte_raw, "sentence1", "sentence2", "rte")) + self.far_oods.append(self._tok_wmt16_en(wmt_raw, "wmt16_ro_en_en")) + + self.test_ood = self.test + self.val_ood = None + + def train_dataloader(self): + return self._data_loader(self.train, training=True, shuffle=True) + + def val_dataloader(self): + return self._data_loader(self.val, training=False) + + def test_dataloader(self): + loaders = [self._data_loader(self.test, training=False)] + if self.eval_ood: + loaders.append(self._data_loader(self.test_ood, training=False)) + # no val_ood loader + loaders.extend(self._data_loader(ds, training=False) for ds in self.near_oods) + loaders.extend(self._data_loader(ds, training=False) for ds in self.far_oods) + return loaders + + def get_test_set(self): + return self.test + + def get_indices(self): + idx = 0 + out = {"test": [idx]} + idx += 1 + if self.eval_ood: + out["test_ood"] = [idx] + idx += 1 + out["val_ood"] = [] # kept empty + out["near_oods"] = list(range(idx, idx + len(self.near_oods))) + idx += len(self.near_oods) + out["far_oods"] = list(range(idx, idx + len(self.far_oods))) + idx += len(self.far_oods) + else: + out |= {"test_ood": [], "val_ood": [], "near_oods": [], "far_oods": []} + out["shift"] = [] + return out diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index ecd7b278..02990029 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -153,7 +153,7 @@ def __init__( ) if num_tta != 1: - self.test_transform = train_transform + self.test_transform = self.train_transform elif test_transform is not None: self.test_transform = test_transform else: diff --git a/torch_uncertainty/datasets/classification/imagenet/base.py b/torch_uncertainty/datasets/classification/imagenet/base.py index 405b04c3..773a0756 100644 --- a/torch_uncertainty/datasets/classification/imagenet/base.py +++ b/torch_uncertainty/datasets/classification/imagenet/base.py @@ -1,3 +1,4 @@ +import hashlib import json import logging from collections.abc import Callable @@ -18,8 +19,11 @@ class ImageNetVariation(ImageFolder): dataset_name: str root_appendix: str - wnid_to_idx_url = "https://raw.githubusercontent.com/torch-uncertainty/dataset-metadata/main/classification/imagenet/classes.json" - wnid_to_idx_md5 = "1bcf467b49f735dbeb745249eae6b133" # avoid replacement attack + wnid_to_idx_url = ( + "https://raw.githubusercontent.com/torch-uncertainty/dataset-metadata/main/" + "classification/imagenet/classes.json" + ) + wnid_to_idx_md5 = "1bcf467b49f735dbeb745249eae6b133" def __init__( self, @@ -29,26 +33,12 @@ def __init__( target_transform: Callable | None = None, download: bool = False, ) -> None: - """Virtual base class for ImageNet variations. - - Args: - root (str | Path): Root directory of the datasets. - split (str, optional): For API consistency. Defaults to ``None``. - transform (callable, optional): A function/transform that takes in - a PIL image and returns a transformed version. E.g, - ``transforms.RandomCrop``. Defaults to ``None``. - target_transform (callable, optional): A function/transform that - takes in the target and transforms it. Defaults to ``None``. - download (bool, optional): If ``True``, downloads the dataset from the - internet and puts it in root directory. If dataset is already - downloaded, it is not downloaded again. Defaults to ``False``. - """ + self.root = Path(root).expanduser().resolve() + self.split = split + if download: self.download() - self.root = Path(root) - self.split = split - if not self._check_integrity(): raise RuntimeError( "Dataset not found or corrupted. You can use download=True to download it." @@ -59,32 +49,88 @@ def __init__( transform=transform, target_transform=target_transform, ) - self._repair_dataset() + # ---------------- helpers ---------------- + + def _md5_of(self, path: Path, chunk_size: int = 1 << 22) -> str: + h = hashlib.md5() # noqa: S324 + with path.open("rb") as f: + for chunk in iter(lambda: f.read(chunk_size), b""): + h.update(chunk) + return h.hexdigest() + def _check_integrity(self) -> bool: - """Check the integrity of the dataset(s).""" + data_root = self.root / self.dataset_name + if data_root.is_dir(): + logging.info("[integrity] extracted dir present: %s -> OK", data_root) + return True + if isinstance(self.filename, str): - return check_integrity( - self.root / Path(self.filename), - self.tgz_md5, - ) - if isinstance(self.filename, list): # ImageNet-C - integrity: bool = True - for filename, md5 in zip(self.filename, self.tgz_md5, strict=True): - integrity *= check_integrity( - self.root / self.root_appendix / filename, - md5, + p = (self.root / Path(self.filename)).resolve() + if p.exists(): + actual = self._md5_of(p) + logging.info( + "[integrity] %s | MD5 got: %s | expected: %s | %s", + p, + actual, + self.tgz_md5, + "OK" if actual == self.tgz_md5 else "FAIL", ) - return integrity + else: + logging.info("[integrity] missing archive: %s", p) + return check_integrity(p, self.tgz_md5) + + if isinstance(self.filename, list): + ok = True + for filename, md5 in zip(self.filename, self.tgz_md5, strict=True): + p_root = (self.root / filename).resolve() + if p_root.exists(): + actual = self._md5_of(p_root) + logging.info( + "[integrity] %s | MD5 got: %s | expected: %s | %s", + p_root, + actual, + md5, + "OK" if actual == md5 else "FAIL", + ) + else: + logging.info("[integrity] missing archive at /: %s", p_root) + ok *= check_integrity(p_root, md5) + return bool(ok) + raise ValueError("filename must be str or list") def download(self) -> None: - """Download and extract dataset.""" + data_root = self.root / self.dataset_name + if data_root.is_dir(): + logging.info("[download] extracted dir present -> skipping download/extract.") + return + if self._check_integrity(): logging.info("Files already downloaded and verified") + if isinstance(self.filename, list): + logging.info("[download] extracting existing valid archives...") + for url, filename, md5 in zip(self.url, self.filename, self.tgz_md5, strict=True): + p_root = (self.root / filename).resolve() + if p_root.exists() and check_integrity(p_root, md5): + logging.info("[download] Extracting: %s", p_root) + download_and_extract_archive( + url, + self.root, + extract_root=self.root / self.root_appendix, + filename=filename, + md5=md5, + ) return + if isinstance(self.filename, str): + p = (self.root / Path(self.filename)).resolve() + if p.exists(): + actual = self._md5_of(p) + logging.info( + "[download] existing %s | MD5 got: %s | expected: %s", p, actual, self.tgz_md5 + ) download_and_extract_archive( self.url, self.root, @@ -92,28 +138,38 @@ def download(self) -> None: filename=self.filename, md5=self.tgz_md5, ) - elif isinstance(self.filename, list): # ImageNet-C - for url, filename, md5 in zip(self.url, self.filename, self.tgz_md5, strict=True): - # Check that this particular file is not already downloaded - if not check_integrity(self.root / self.root_appendix / Path(filename), md5): - download_and_extract_archive( - url, - self.root, - extract_root=self.root / self.root_appendix, - filename=filename, - md5=md5, - ) + return + + for url, filename, md5 in zip(self.url, self.filename, self.tgz_md5, strict=True): + p_root = (self.root / filename).resolve() + if p_root.exists(): + actual = self._md5_of(p_root) + logging.info( + "[download] existing %s | MD5 got: %s | expected: %s", p_root, actual, md5 + ) + if not check_integrity(p_root, md5): + logging.info("[download] fetching -> %s -> %s", url, p_root) + download_and_extract_archive( + url, + self.root, + extract_root=self.root / self.root_appendix, + filename=filename, + md5=md5, + ) + elif not data_root.is_dir(): + logging.info("[download] extracting valid existing: %s", p_root) + download_and_extract_archive( + url, + self.root, + extract_root=self.root / self.root_appendix, + filename=filename, + md5=md5, + ) def _repair_dataset(self) -> None: - """Download the wnid_to_idx.txt file and to get the correct targets.""" path = self.root / "classes.json" if not check_integrity(path, self.wnid_to_idx_md5): - download_url( - self.wnid_to_idx_url, - self.root, - "classes.json", - self.wnid_to_idx_md5, - ) + download_url(self.wnid_to_idx_url, self.root, "classes.json", self.wnid_to_idx_md5) with (self.root / "classes.json").open() as file: self.wnid_to_idx = json.load(file) diff --git a/torch_uncertainty/datasets/classification/imagenet/imagenet_a.py b/torch_uncertainty/datasets/classification/imagenet/imagenet_a.py index 6bb8f997..aa978fb6 100644 --- a/torch_uncertainty/datasets/classification/imagenet/imagenet_a.py +++ b/torch_uncertainty/datasets/classification/imagenet/imagenet_a.py @@ -6,6 +6,7 @@ class ImageNetA(ImageNetVariation): filename = "imagenet-a.tar" tgz_md5 = "c3e55429088dc681f30d81f4726b6595" dataset_name = "imagenet-a" + root_appendix = "imagenet-a" def __init__(self, **kwargs) -> None: """Initializes the ImageNetA dataset class. diff --git a/torch_uncertainty/datasets/classification/imagenet/imagenet_c.py b/torch_uncertainty/datasets/classification/imagenet/imagenet_c.py index f3827a5c..e05fa3de 100644 --- a/torch_uncertainty/datasets/classification/imagenet/imagenet_c.py +++ b/torch_uncertainty/datasets/classification/imagenet/imagenet_c.py @@ -1,14 +1,10 @@ +from pathlib import Path + from .base import ImageNetVariation class ImageNetC(ImageNetVariation): - """The corrupted ImageNet-C dataset. - - References: - Benchmarking neural network robustness to common corruptions and - perturbations. Dan Hendrycks and Thomas Dietterich. - In ICLR, 2019. - """ + """The corrupted ImageNet-C dataset.""" url = [ "https://zenodo.org/record/2235448/files/blur.tar", @@ -17,13 +13,7 @@ class ImageNetC(ImageNetVariation): "https://zenodo.org/record/2235448/files/noise.tar", "https://zenodo.org/record/2235448/files/weather.tar", ] - filename = [ - "blur.tar", - "digital.tar", - "extra.tar", - "noise.tar", - "weather.tar", - ] + filename = ["blur.tar", "digital.tar", "extra.tar", "noise.tar", "weather.tar"] tgz_md5 = [ "2d8e81fdd8e07fef67b9334fa635e45c", "89157860d7b10d5797849337ca2e5c03", @@ -35,21 +25,25 @@ class ImageNetC(ImageNetVariation): root_appendix = "imagenet-c" def __init__(self, **kwargs) -> None: - """Initializes the ImageNetC dataset class. - - This is a subclass of ImageNetVariation that supports additional keyword arguments. - - Args: - kwargs: Additional keyword arguments passed to the superclass, including: - - - root (str): Root directory of the datasets. - - split (str, optional): For API consistency. Defaults to ``None``. - - transform (callable, optional): A function/transform that takes in a PIL image and - returns a transformed version. E.g., transforms.RandomCrop. Defaults to ``None``. - - target_transform (callable, optional): A function/transform that takes in the target - and transforms it. Defaults to ``None``. - - download (bool, optional): If ``True``, downloads the dataset from the internet - and puts it in the root directory. If the dataset is already downloaded, it is - not downloaded again. Defaults to ``False``. - """ + severity = kwargs.pop("shift_severity", 1) + try: + severity = int(severity) + except Exception as e: + raise ValueError(f"shift_severity must be an int in [1..5], got {severity!r}") from e + if severity not in (1, 2, 3, 4, 5): + raise ValueError(f"shift_severity must be in [1..5], got {severity}") + super().__init__(**kwargs) + + sev_str = str(severity) + filtered = [(p, t) for (p, t) in self.samples if Path(p).parts[-3] == sev_str] + if not filtered: + raise RuntimeError( + f"ImageNet-C: no samples matched shift_severity={severity}. " + "Check extraction under /imagenet-c///..." + ) + + self.samples = filtered + self.imgs = filtered + self.targets = [t for _, t in filtered] + self.shift_severity = severity diff --git a/torch_uncertainty/datasets/classification/imagenet/imagenet_o.py b/torch_uncertainty/datasets/classification/imagenet/imagenet_o.py index 8ff30020..7d47695e 100644 --- a/torch_uncertainty/datasets/classification/imagenet/imagenet_o.py +++ b/torch_uncertainty/datasets/classification/imagenet/imagenet_o.py @@ -6,6 +6,7 @@ class ImageNetO(ImageNetVariation): filename = "imagenet-o.tar" tgz_md5 = "86bd7a50c1c4074fb18fc5f219d6d50b" dataset_name = "imagenet-o" + root_appendix = "imagenet-o" def __init__(self, **kwargs) -> None: """Initializes the ImageNetO dataset class. diff --git a/torch_uncertainty/datasets/classification/imagenet/imagenet_r.py b/torch_uncertainty/datasets/classification/imagenet/imagenet_r.py index de7e49a2..baace19d 100644 --- a/torch_uncertainty/datasets/classification/imagenet/imagenet_r.py +++ b/torch_uncertainty/datasets/classification/imagenet/imagenet_r.py @@ -6,6 +6,7 @@ class ImageNetR(ImageNetVariation): filename = "imagenet-r.tar" tgz_md5 = "a61312130a589d0ca1a8fca1f2bd3337" dataset_name = "imagenet-r" + root_appendix = "imagenet-r" def __init__(self, **kwargs) -> None: """Initializes the ImageNetR dataset class. diff --git a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py index ed930a66..2438d239 100644 --- a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py +++ b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py @@ -1,3 +1,7 @@ +import logging +import os +import urllib.request +import zipfile from collections import defaultdict from collections.abc import Callable from pathlib import Path @@ -8,6 +12,8 @@ from PIL import Image from torch.utils.data import Dataset +logger = logging.getLogger(__name__) + class TinyImageNet(Dataset): def __init__( @@ -16,10 +22,21 @@ def __init__( split: Literal["train", "val", "test"] = "train", transform: Callable | None = None, target_transform: Callable | None = None, + download: bool = False, # added download attribute ) -> None: """Inspired by https://gist.github.com/z-a-f/b862013c0dc2b540cf96a123a6766e54.""" self.root = Path(root) / "tiny-imagenet-200" + if download and not self.root.exists(): + url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip" + zip_path = Path(root) / "tiny-imagenet-200.zip" + logger.info("Downloading tiny-imagenet-200 dataset...") + urllib.request.urlretrieve(url, zip_path) # noqa: S310 + logger.info("Extracting dataset...") + with zipfile.ZipFile(zip_path, "r") as zf: + zf.extractall(root) + zip_path.unlink() + if split not in ["train", "val", "test"]: raise ValueError(f"Split {split} is not supported.") @@ -115,6 +132,7 @@ def _make_paths(self) -> list[tuple[Path, int]]: paths.append((fname, label_id)) else: # self.split == "test": - test_path = Path(self.root / "test") - paths = [test_path / x for x in test_path.iterdir()] + test_path = self.root / "test" / "images" + paths = [(test_path / x, -1) for x in os.listdir(test_path)] # noqa: PTH208 + return paths diff --git a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet_c.py b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet_c.py index 8638ceb3..40fedf4c 100644 --- a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet_c.py +++ b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet_c.py @@ -58,8 +58,12 @@ def __init__( ``transforms.RandomCrop``. Defaults to ``None``. target_transform (callable, optional): A function/transform that takes in the target and transforms it. Defaults to ``None``. - subset (str): The subset to use, one of ``all`` or the keys in - ``cifarc_subsets``. + subset (str): The target corruption to use type should be one of + {`brightness`, `contrast`, `defocus_blur`, `elastic_transform`, + `fog`, `frost`, `gaussian_blur`, `gaussian_noise`, `glass_blur`, + `impulse_noise`, `jpeg_compression`, `motion_blur`, `pixelate`, + `saturate`, `shot_noise`, `snow`, `spatter`, `speckle_noise`, + `zoom_blur`}. Defaults to ``all`` for all corruptions. shift_severity (int): The shift_severity of the corruption, between ``1`` and ``5``. download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already diff --git a/torch_uncertainty/datasets/ood/utils.py b/torch_uncertainty/datasets/ood/utils.py new file mode 100644 index 00000000..9defcada --- /dev/null +++ b/torch_uncertainty/datasets/ood/utils.py @@ -0,0 +1,386 @@ +import logging +import os +import tarfile +import urllib.request +import zipfile +from pathlib import Path + +from huggingface_hub import hf_hub_download +from PIL import Image +from torch.utils.data import Dataset + +logger = logging.getLogger(__name__) + + +def _safe_extract(tar: tarfile.TarFile, path: Path): + """Safely extract tar members into `path`, preventing path traversal.""" + base = path.resolve() + for member in tar.getmembers(): + member_path = (path / member.name).resolve() + if not str(member_path).startswith(str(base) + os.sep): + raise RuntimeError(f"Unsafe path in tar archive: {member.name}") + tar.extract(member, path) + + +def _safe_extract_zip(zf: zipfile.ZipFile, path: Path): + """Safely extract zip members into `path`, preventing path traversal.""" + base = path.resolve() + for member in zf.namelist(): + member_path = (path / member).resolve() + if not str(member_path).startswith(str(base) + os.sep): + raise RuntimeError(f"Unsafe path in zip archive: {member}") + zf.extract(member, path) + + +class FileListDataset(Dataset): + def __init__(self, root: str | Path, list_file: str | Path, name=None, transform=None): + self.root = Path(root) + self.transform = transform + self.dataset_name = name + self.samples = [] + with Path(list_file).open() as f: + for line in f: + path_str, lbl_str = line.strip().rsplit(maxsplit=1) + self.samples.append((self.root / path_str, int(lbl_str))) + + def __len__(self): + """Return the number of samples in the dataset.""" + return len(self.samples) + + def __getitem__(self, idx): + """Load and return the (image, label) sample at index `idx`.""" + img_path, label = self.samples[idx] + img = Image.open(img_path).convert("RGB") + if self.transform: + img = self.transform(img) + return img, label + + +OOD_SPLITS = { + "CIFAR10": { + "test": { + "cifar10": "splits/cifar10/test_ood_cifar10.txt", + }, + "val": { + "tinyimagenet": "splits/cifar10/val_tin.txt", + }, + "near": { + "cifar100": "splits/cifar10/test_cifar100.txt", + "tinyimagenet": "splits/cifar10/test_tin.txt", + }, + "far": { + "mnist": "splits/cifar10/test_mnist.txt", + "svhn": "splits/cifar10/test_svhn.txt", + "texture": "splits/cifar10/test_texture.txt", + "places365": "splits/cifar10/test_places365.txt", + }, + }, + "CIFAR100": { + "test": { + "cifar100": "splits/cifar100/test_ood_cifar100.txt", + }, + "val": { + "tinyimagenet": "splits/cifar100/val_tin.txt", + }, + "near": { + "cifar10": "splits/cifar100/test_cifar10.txt", + "tinyimagenet": "splits/cifar100/test_tin.txt", + }, + "far": { + "mnist": "splits/cifar100/test_mnist.txt", + "svhn": "splits/cifar100/test_svhn.txt", + "texture": "splits/cifar100/test_texture.txt", + "places365": "splits/cifar100/test_places365.txt", + }, + }, + "imagenet200": { + "test": { + "imagenet1k": "splits/imagenet200/test_ood_imagenet200.txt", + }, + "val": { + "openimage_o": "splits/imagenet200/val_openimage_o.txt", + }, + "near": { + "ssb_hard": "splits/imagenet200/test_ssb_hard.txt", + "ninco": "splits/imagenet200/test_ninco.txt", + }, + "far": { + "inaturalist": "splits/imagenet200/test_inaturalist.txt", + "texture": "splits/imagenet200/test_textures.txt", + "openimage_o": "splits/imagenet200/test_openimage_o.txt", + }, + }, + "imagenet1k": { + "test": { + "imagenet1k": "splits/imagenet1k/test_ood_imagenet.txt", + }, + "val": { + "openimage_o": "splits/imagenet1k/val_openimage_o.txt", + }, + "near": { + "ssb_hard": "splits/imagenet1k/test_ssb_hard.txt", + "ninco": "splits/imagenet1k/test_ninco.txt", + }, + "far": { + "inaturalist": "splits/imagenet1k/test_inaturalist.txt", + "texture": "splits/imagenet1k/test_textures.txt", + "openimage_o": "splits/imagenet1k/test_openimage_o.txt", + }, + }, +} + + +ZENODO_INFO = { + "ninco": { + "url": "https://zenodo.org/record/8013288/files/NINCO_all.tar.gz", + "filename": "NINCO_all.tar.gz", + "extract_paths": ["NINCO"], + }, + "openimage_o": { + "url": "https://zenodo.org/records/10540831/files/OpenImage-O.zip", + "filename": "OpenImage-O.zip", + "extract_paths": ["openimage-o"], + }, +} + +HF_REPO_INFO: dict[str, dict[str, str]] = { + "cifar10": { + "repo_id": "torch-uncertainty/Cifar10", + "zip_filename": "cifar10.zip", + }, + "cifar100": { + "repo_id": "torch-uncertainty/Cifar100", + "zip_filename": "cifar100.zip", + }, + "mnist": { + "repo_id": "torch-uncertainty/MNIST", + "zip_filename": "mnist.zip", + }, + "texture": { + "repo_id": "torch-uncertainty/Texture", + "zip_filename": "texture.zip", + }, + "places365": { + "repo_id": "torch-uncertainty/Places365", + "zip_filename": "places365.zip", + }, + "svhn": { + "repo_id": "torch-uncertainty/SVHN", + "zip_filename": "svhn.zip", + }, + "tinyimagenet": { + "repo_id": "torch-uncertainty/tiny-imagenet-200", + "zip_filename": "tin.zip", + }, + "ssb_hard": { + "repo_id": "torch-uncertainty/SSB_hard", + "zip_filename": "ssb_hard.zip", + }, + "inaturalist": { + "repo_id": "torch-uncertainty/inaturalist", + "zip_filename": "inaturalist.zip", + }, + "imagenet1k": { + "repo_id": "torch-uncertainty/Imagenet1k", + "zip_filename": "imagenet_1k.zip", + }, +} + + +def download_and_extract_hf_dataset( + name: str, + root: Path, +) -> Path: + """- If name is 'ninco' or 'openimage_o', download from Zenodo and extract once. + - Otherwise fall back to HF_REPO_INFO + hf_hub_download. + Returns the path to the folder you should use as 'root' for FileListDataset. + """ + root = Path(root) + root.mkdir(parents=True, exist_ok=True) + + def _download_zenodo(): + logger.info("📥 Downloading '%s' from Zenodo…", name) + urllib.request.urlretrieve(info["url"], archive_path) # noqa: S310 + + def _attempt_extract_archive(): + logger.info("📂 Extracting '%s'…", archive_path.name) + if archive_path.suffix == ".zip": + with zipfile.ZipFile(archive_path, "r") as zf: + _safe_extract_zip(zf, root) + elif archive_path.suffixes[-2:] == [".tar", ".gz"]: + with tarfile.open(archive_path, "r:gz") as tf: + _safe_extract(tf, root) + else: + raise RuntimeError(f"Unknown archive format: {archive_path}") + + if name in ZENODO_INFO: + info = ZENODO_INFO[name] + archive_path = root / info["filename"] + + for rel in info["extract_paths"]: + candidate = root / rel + if candidate.exists(): + return candidate + + if not archive_path.exists(): + _download_zenodo() + + try: + _attempt_extract_archive() + except (RuntimeError, zipfile.BadZipFile, tarfile.TarError, OSError) as e: + logger.warning("Extraction failed (%s), re-downloading and retrying…", e) + archive_path.unlink(missing_ok=True) + _download_zenodo() + _attempt_extract_archive() + + for rel in info["extract_paths"]: + candidate = root / rel + if candidate.exists(): + return candidate + + raise RuntimeError( + f"Extraction succeeded but none of {info['extract_paths']} were found under {root!r}" + ) + + hf_info = HF_REPO_INFO.get(name) + if hf_info is None: + raise KeyError(f"No HF_REPO_INFO entry for {name}") + + repo_id = hf_info["repo_id"] + zip_fname = hf_info["zip_filename"] + target_dir = root / Path(zip_fname).stem + + if target_dir.exists(): + return target_dir + + target_dir.mkdir(parents=True, exist_ok=True) + logger.info( + "📥 Downloading %r from HF Hub (%s/%s)…", + name, + repo_id, + zip_fname, + ) + zip_path = hf_hub_download( + repo_id=repo_id, + filename=zip_fname, + repo_type="dataset", + ) + + def _extract_hf_zip(): + with zipfile.ZipFile(zip_path, "r") as zf: + _safe_extract_zip(zf, target_dir) + + try: + _extract_hf_zip() + except (zipfile.BadZipFile, OSError) as e: + logger.warning("HF Hub zip extract failed (%s), re-downloading and retrying…", e) + Path(zip_path).unlink(missing_ok=True) + zip_path = hf_hub_download( + repo_id=repo_id, + filename=zip_fname, + repo_type="dataset", + force_download=True, + ) + _extract_hf_zip() + + return target_dir + + +def get_ood_datasets( + root: str | Path, + dataset_id: str, + transform=None, +) -> tuple[FileListDataset, dict[str, FileListDataset], dict[str, FileListDataset]]: + """Ensure all OOD splits are downloaded and extracted via HF_REPO_INFO.""" + root = Path(root) + splits_base = download_and_extract_splits_from_hf(root=Path(root)) + + def _resolve_txt(rel_txt: str) -> Path: + rel = rel_txt.lstrip("/") + rel = rel.removeprefix("splits/") + return splits_base / rel + + cfg = OOD_SPLITS.get(dataset_id) + if cfg is None: + raise KeyError(f"No OOD_SPLITS for {dataset_id}") + + def build(name: str, rel_txt: str): + data_dir = download_and_extract_hf_dataset(name, root) + txt = _resolve_txt(rel_txt) + return FileListDataset(root=data_dir, list_file=txt, transform=transform, name=name) + + test_name, test_txt = next(iter(cfg["test"].items())) + test_ood = build(test_name, test_txt) + + val_name, val_txt = next(iter(cfg["val"].items())) + val_ood = build(val_name, val_txt) + + near_oods = {n: build(n, p) for n, p in cfg["near"].items()} + far_oods = {n: build(n, p) for n, p in cfg["far"].items()} + + return test_ood, val_ood, near_oods, far_oods + + +def download_and_extract_splits_from_hf( + root: str | Path, + repo_id="torch-uncertainty/ood-datasets-splits", + zip_filename="splits.zip", +) -> Path: + """Download a zip that contains the 'splits/' tree from HF and extract it once. + Returns the path to the extracted 'splits' directory (or the extracted root if it already is 'splits/'). + """ + root = Path(root) + root.mkdir(parents=True, exist_ok=True) + + target_dir = root / Path(zip_filename).stem # e.g. /splits + + def _is_valid_splits_dir(p: Path) -> bool: + # valid if it has a 'splits/' subdir OR known subfolders OR any .txt files inside + if (p / "splits").exists(): + return True + for sub in ("cifar10", "cifar100", "imagenet1k", "imagenet200"): + if (p / sub).exists(): + return True + return any(p.rglob("*.txt")) + + # EARLY RETURN ONLY IF VALID + if target_dir.exists() and _is_valid_splits_dir(target_dir): + return (target_dir / "splits") if (target_dir / "splits").exists() else target_dir + + # (Re)create and fetch + target_dir.mkdir(parents=True, exist_ok=True) + logger.info("📥 Downloading splits from HF Hub (%s/%s)…", repo_id, zip_filename) + zip_path = hf_hub_download( + repo_id=repo_id, + filename=zip_filename, + repo_type="dataset", # change to "model" if hosted as a model + ) + + def _extract_zip(zp: Path, out: Path): + with zipfile.ZipFile(zp, "r") as zf: + _safe_extract_zip(zf, out) + + try: + _extract_zip(Path(zip_path), target_dir) + except (zipfile.BadZipFile, OSError) as e: + logger.warning("Splits zip extract failed (%s), re-downloading and retrying…", e) + Path(zip_path).unlink(missing_ok=True) + zip_path = hf_hub_download( + repo_id=repo_id, + filename=zip_filename, + repo_type="dataset", + force_download=True, + ) + _extract_zip(Path(zip_path), target_dir) + + # Choose the actual splits dir to return + final_dir = (target_dir / "splits") if (target_dir / "splits").exists() else target_dir + + # VALIDATE POST-EXTRACT + if not _is_valid_splits_dir(final_dir): + raise FileNotFoundError( + f"No split files found under {final_dir}. " + f"Check the structure of {repo_id}:{zip_filename}." + ) + + return final_dir diff --git a/torch_uncertainty/models/classification/vit/__init__.py b/torch_uncertainty/models/classification/vit/__init__.py new file mode 100644 index 00000000..8cbedc30 --- /dev/null +++ b/torch_uncertainty/models/classification/vit/__init__.py @@ -0,0 +1,3 @@ +from .packedvit import PackedVit + +__all__ = ["PackedVit"] diff --git a/torch_uncertainty/models/classification/vit/packedvit.py b/torch_uncertainty/models/classification/vit/packedvit.py new file mode 100644 index 00000000..3a2b2b61 --- /dev/null +++ b/torch_uncertainty/models/classification/vit/packedvit.py @@ -0,0 +1,277 @@ +from collections import OrderedDict +from collections.abc import Callable +from functools import partial +from typing import NamedTuple + +import torch +import torch.nn as nn +from einops import rearrange + +from torch_uncertainty.layers import ( + PackedConv2d, + PackedLayerNorm, + PackedLinear, + PackedMultiheadAttention, +) + + +class ConvStemConfig(NamedTuple): + out_channels: int + kernel_size: int + stride: int + norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d + activation_layer: Callable[..., nn.Module] = nn.ReLU + + +class MLPBlock(nn.Module): + """Transformer MLP block.""" + + def __init__(self, in_dim: int, mlp_dim: int, dropout: float, num_estimators=1, alpha=1): + super().__init__() + self.layers = nn.Sequential( + PackedLinear( + in_dim, mlp_dim, num_estimators=num_estimators, alpha=alpha, implementation="einsum" + ), + nn.GELU(), + nn.Dropout(dropout), + PackedLinear( + mlp_dim, in_dim, num_estimators=num_estimators, alpha=alpha, implementation="einsum" + ), + nn.Dropout(dropout), + ) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, x): + return self.layers(x) + + +class EncoderBlock(nn.Module): + """Transformer encoder block.""" + + def __init__( + self, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), # noqa: B008 + num_estimators=1, + alpha=1, + ): + super().__init__() + self.num_heads = num_heads + + # Attention block + self.ln_1 = PackedLayerNorm(hidden_dim, num_estimators, alpha) + self.self_attention = PackedMultiheadAttention( + hidden_dim, + num_heads, + dropout=attention_dropout, + num_estimators=num_estimators, + alpha=alpha, + batch_first=True, + ) + self.dropout = nn.Dropout(dropout) + + # MLP block + self.ln_2 = PackedLayerNorm(hidden_dim, num_estimators, alpha) + self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout, num_estimators, alpha) + + def forward(self, input: torch.Tensor): # noqa: A002 + torch._assert( + input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}" + ) + x = self.ln_1(input) + x, _ = self.self_attention(x, x, x, need_weights=False) + x = self.dropout(x) + x = x + input + + y = self.ln_2(x) + y = self.mlp(y) + return x + y + + +class Encoder(nn.Module): + """Transformer Model Encoder for sequence to sequence translation.""" + + def __init__( + self, + seq_length: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), # noqa: B008 + num_estimators=1, + alpha=1, + ): + super().__init__() + self.pos_embedding = nn.Parameter( + torch.empty(1, seq_length, int(hidden_dim * alpha)).normal_(std=0.02) + ) + self.dropout = nn.Dropout(dropout) + layers: OrderedDict[str, nn.Module] = OrderedDict() + for i in range(num_layers): + layers[f"encoder_layer_{i}"] = EncoderBlock( + num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + norm_layer, + num_estimators, + alpha, + ) + self.layers = nn.Sequential(layers) + self.ln = PackedLayerNorm(hidden_dim, num_estimators, alpha) + + def forward(self, input: torch.Tensor): # noqa: A002 + torch._assert( + input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}" + ) + input = input + self.pos_embedding # noqa: A001 + return self.ln(self.layers(self.dropout(input))) + + +class PackedVit(nn.Module): + def __init__( + self, + image_size: int, + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float = 0.0, + attention_dropout: float = 0.0, + num_classes: int = 1000, + representation_size: int | None = None, + norm_layer: Callable[..., nn.Module] = partial(nn.LayerNorm, eps=1e-6), # noqa: B008 + conv_stem_configs: list[ConvStemConfig] | None = None, + num_estimators=1, + alpha=1, + ): + super().__init__() + torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!") + self.image_size = image_size + self.patch_size = patch_size + self.hidden_dim = hidden_dim + + if conv_stem_configs is not None: + seq_proj = nn.Sequential() + prev_channels = 3 + for i, conv_stem_layer_config in enumerate(conv_stem_configs): + seq_proj.add_module( + f"conv_{i}", + nn.Conv2d( + in_channels=prev_channels, + out_channels=conv_stem_layer_config.out_channels, + kernel_size=conv_stem_layer_config.kernel_size, + stride=conv_stem_layer_config.stride, + ), + ) + seq_proj.add_module( + f"bn_{i}", + conv_stem_layer_config.norm_layer(conv_stem_layer_config.out_channels), + ) + seq_proj.add_module(f"relu_{i}", conv_stem_layer_config.activation_layer()) + prev_channels = conv_stem_layer_config.out_channels + seq_proj.add_module( + "conv_last", + nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1), + ) + self.conv_proj = seq_proj + else: + self.conv_proj = PackedConv2d( + in_channels=3, + out_channels=hidden_dim, + kernel_size=patch_size, + stride=patch_size, + first=True, + num_estimators=num_estimators, + alpha=alpha, + ) + + seq_length = (image_size // patch_size) ** 2 + self.class_token = nn.Parameter(torch.zeros(1, 1, int(hidden_dim * alpha))) + seq_length += 1 + + self.encoder = Encoder( + seq_length, + num_layers, + num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + norm_layer, + num_estimators, + alpha, + ) + + heads_layers: OrderedDict[str, nn.Module] = OrderedDict() + if representation_size is None: + heads_layers["head"] = PackedLinear( + hidden_dim, + num_classes, + num_estimators=num_estimators, + alpha=alpha, + implementation="einsum", + last=True, + ) + else: + heads_layers["pre_logits"] = PackedLinear( + hidden_dim, + representation_size, + num_estimators=num_estimators, + alpha=alpha, + implementation="einsum", + ) + heads_layers["act"] = nn.Tanh() + heads_layers["head"] = PackedLinear( + representation_size, + num_classes, + num_estimators=num_estimators, + alpha=alpha, + implementation="einsum", + last=True, + ) + + self.heads = nn.Sequential(heads_layers) + self.alpha = alpha + self.num_estimators = num_estimators + + def _process_input(self, x: torch.Tensor) -> torch.Tensor: + n, c, h, w = x.shape + p = self.patch_size + torch._assert( + h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!" + ) + torch._assert( + w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!" + ) + n_h = h // p + n_w = w // p + x = self.conv_proj(x) + x = x.reshape(n, int(self.hidden_dim * self.alpha), n_h * n_w) + x = x.permute(0, 2, 1) + return x # noqa: RET504 + + def forward(self, x: torch.Tensor): + x = self._process_input(x) + n = x.shape[0] + batch_class_token = self.class_token.expand(n, -1, -1) + x = torch.cat([batch_class_token, x], dim=1) + x = self.encoder(x) + x = x[:, 0] + x = self.heads(x) + out = rearrange(x, "b (m c) -> (m b) c", m=self.num_estimators) + return out # noqa: RET504 diff --git a/torch_uncertainty/ood/configs/adascale_a.yml b/torch_uncertainty/ood/configs/adascale_a.yml new file mode 100644 index 00000000..29766d03 --- /dev/null +++ b/torch_uncertainty/ood/configs/adascale_a.yml @@ -0,0 +1,25 @@ +postprocessor: + name: adascale_a + APS_mode: True + postprocessor_args: + num_samples: 5000 + percentile: [75, 85] + k1: 5 + k2: 1 + lmbda: 10 + o: 0.05 + postprocessor_sweep: + percentile_list: [[60, 65], [60, 70], [60, 75], [60, 80], [60, 85], [60, 90], [60, 95], [60, 99], [65, 70], [65, 75], [65, 80], [65, 85], [65, 90], [65, 95], [65, 99], [70, 75], [70, 80], [70, 85], [70, 90], [70, 95], [70, 99], [75, 80], [75, 85], [75, 90], [75, 95], [75, 99], [80, 85], [80, 90], [80, 95], [80, 99], [85, 90], [85, 95], [85, 99], [90, 95], [90, 99], [95, 99]] + k1_list: [5] #[5, 10, 20, 50, 80, 100] + # k1 hyperparameter is not highly critical, feel free to set it + # to 5% across all new architectures for near-optimal results + k2_list: [1] + lmbda_list: [10] + o_list: [0.05] + # For ResNet-50 model, following hyperparameters works best + # percentile_list: [[80, 85]] + # k1_list: [5] + # k2_list: [1] + # lmbda_list: [10] + # o_list: [0.05] + # For hyperparameters of other models, please refer to https: TODO diff --git a/torch_uncertainty/ood/configs/ash.yml b/torch_uncertainty/ood/configs/ash.yml new file mode 100644 index 00000000..3b1e7fb4 --- /dev/null +++ b/torch_uncertainty/ood/configs/ash.yml @@ -0,0 +1,7 @@ +postprocessor: + name: ash + APS_mode: True + postprocessor_args: + percentile: 90 + postprocessor_sweep: + percentile_list: [65, 70, 75, 80, 85, 90, 95] diff --git a/torch_uncertainty/ood/configs/gen.yml b/torch_uncertainty/ood/configs/gen.yml new file mode 100644 index 00000000..11f9d285 --- /dev/null +++ b/torch_uncertainty/ood/configs/gen.yml @@ -0,0 +1,9 @@ +postprocessor: + name: gen + APS_mode: True + postprocessor_args: + gamma: 0.1 + m: 100 + postprocessor_sweep: + gamma_list: [0.01,0.1,0.5,1,2,5,10] + M_list: [10,50,100,200,500,1000] diff --git a/torch_uncertainty/ood/configs/knn.yml b/torch_uncertainty/ood/configs/knn.yml new file mode 100644 index 00000000..7b29e807 --- /dev/null +++ b/torch_uncertainty/ood/configs/knn.yml @@ -0,0 +1,7 @@ +postprocessor: + name: knn + APS_mode: True + postprocessor_args: + K: 50 + postprocessor_sweep: + K_list: [50, 100, 200, 500, 1000] diff --git a/torch_uncertainty/ood/configs/nnguide.yml b/torch_uncertainty/ood/configs/nnguide.yml new file mode 100644 index 00000000..5a6b68af --- /dev/null +++ b/torch_uncertainty/ood/configs/nnguide.yml @@ -0,0 +1,9 @@ +postprocessor: + name: nnguide + APS_mode: False + postprocessor_args: + alpha : 0.01 + K: 100 + postprocessor_sweep: + K_list: [100] + alpha_list: [0.01] diff --git a/torch_uncertainty/ood/configs/odin.yml b/torch_uncertainty/ood/configs/odin.yml new file mode 100644 index 00000000..3f192fbc --- /dev/null +++ b/torch_uncertainty/ood/configs/odin.yml @@ -0,0 +1,9 @@ +postprocessor: + name: odin + APS_mode: True + postprocessor_args: + temperature: 1000 + noise: 0.0014 + postprocessor_sweep: + temperature: [1, 10, 100, 1000] + noise: [0.0014, 0.0028] diff --git a/torch_uncertainty/ood/configs/react.yml b/torch_uncertainty/ood/configs/react.yml new file mode 100644 index 00000000..1e63a8a4 --- /dev/null +++ b/torch_uncertainty/ood/configs/react.yml @@ -0,0 +1,7 @@ +postprocessor: + name: react + APS_mode: True + postprocessor_args: + percentile: 90 + postprocessor_sweep: + percentile_list: [85, 90, 95, 99] diff --git a/torch_uncertainty/ood/configs/scale.yml b/torch_uncertainty/ood/configs/scale.yml new file mode 100644 index 00000000..5027abff --- /dev/null +++ b/torch_uncertainty/ood/configs/scale.yml @@ -0,0 +1,7 @@ +postprocessor: + name: scale + APS_mode: True + postprocessor_args: + percentile: 85 + postprocessor_sweep: + percentile_list: [65, 70, 75, 80, 85, 90, 95] diff --git a/torch_uncertainty/ood/configs/vim.yml b/torch_uncertainty/ood/configs/vim.yml new file mode 100644 index 00000000..6d7a8c51 --- /dev/null +++ b/torch_uncertainty/ood/configs/vim.yml @@ -0,0 +1,7 @@ +postprocessor: + name: vim + APS_mode: True + postprocessor_args: + dim: 256 + postprocessor_sweep: + dim_list: [256, 1000] diff --git a/torch_uncertainty/ood/nets/__init__.py b/torch_uncertainty/ood/nets/__init__.py new file mode 100644 index 00000000..b961763d --- /dev/null +++ b/torch_uncertainty/ood/nets/__init__.py @@ -0,0 +1,5 @@ +# ruff: noqa: F401 +from .adascale_net import AdaScaleANet, AdaScaleLNet +from .ash_net import ASHNet +from .react_net import ReactNet +from .scale_net import ScaleNet diff --git a/torch_uncertainty/ood/nets/adascale_net.py b/torch_uncertainty/ood/nets/adascale_net.py new file mode 100644 index 00000000..bab85335 --- /dev/null +++ b/torch_uncertainty/ood/nets/adascale_net.py @@ -0,0 +1,49 @@ +import torch +import torch.nn as nn + + +class AdaScaleANet(nn.Module): + def __init__(self, backbone): + super().__init__() + self.backbone = backbone + self.logit_scaling = False + + def forward(self, x, return_feature=False, return_feature_list=False): + try: + return self.backbone(x, return_feature, return_feature_list) + except TypeError: + return self.backbone(x, return_feature) + + def forward_threshold(self, feature, percentiles): + scale = ada_scale(torch.relu(feature), percentiles) + if self.logit_scaling: + logits_cls = self.backbone.get_fc_layer()(feature) + logits_cls *= scale**2.0 + else: + feature *= torch.exp(scale) + logits_cls = self.backbone.get_fc_layer()(feature) + return logits_cls + + +class AdaScaleLNet(AdaScaleANet): + def __init__(self, backbone): + super().__init__() + self.logit_scaling = True + + +def ada_scale(x, percentiles): + assert x.dim() == 2 + b, c = x.shape + assert percentiles.shape == (b,) + assert x.dim() == 2, "input tensor must be 2D" + assert torch.all(percentiles > 0), "percentiles must be > 0" + assert torch.all(percentiles < 100), "percentiles must be < 100" + n = x.shape[1:].numel() + ks = n - torch.round(n * percentiles.cuda() / 100.0).to(torch.int) + max_k = ks.max() + values, _ = torch.topk(x, max_k, dim=1) + mask = torch.arange(max_k, device=x.device)[None, :] < ks[:, None] + batch_sums = x.sum(dim=1, keepdim=True) + masked_values = values * mask + topk_sums = masked_values.sum(dim=1, keepdim=True) + return batch_sums / topk_sums diff --git a/torch_uncertainty/ood/nets/ash_net.py b/torch_uncertainty/ood/nets/ash_net.py new file mode 100644 index 00000000..236ba670 --- /dev/null +++ b/torch_uncertainty/ood/nets/ash_net.py @@ -0,0 +1,94 @@ +import numpy as np +import torch +import torch.nn as nn + + +class ASHNet(nn.Module): + def __init__(self, backbone): + super().__init__() + self.backbone = backbone + + def forward(self, x, return_feature=False, return_feature_list=False): + try: + return self.backbone(x, return_feature, return_feature_list) + except TypeError: + return self.backbone(x, return_feature) + + def forward_threshold(self, x, percentile): + _, feature = self.backbone(x, return_feature=True) + feature = ash_b(feature.view(feature.size(0), -1, 1, 1), percentile) + feature = feature.view(feature.size(0), -1) + return self.backbone.get_fc_layer()(feature) + + def get_fc(self): + fc = self.backbone.fc + return fc.weight.cpu().detach().numpy(), fc.bias.cpu().detach().numpy() + + +def ash_b(x, percentile=65): + assert x.dim() == 4 + assert 0 <= percentile <= 100 + b, c, h, w = x.shape + + # calculate the sum of the input per sample + s1 = x.sum(dim=[1, 2, 3]) + + n = x.shape[1:].numel() + k = n - int(np.round(n * percentile / 100.0)) + t = x.view((b, c * h * w)) + v, i = torch.topk(t, k, dim=1) + fill = s1 / k + fill = fill.unsqueeze(dim=1).expand(v.shape) + t.zero_().scatter_(dim=1, index=i, src=fill) + return x + + +def ash_p(x, percentile=65): + assert x.dim() == 4 + assert 0 <= percentile <= 100 + + b, c, h, w = x.shape + + n = x.shape[1:].numel() + k = n - int(np.round(n * percentile / 100.0)) + t = x.view((b, c * h * w)) + v, i = torch.topk(t, k, dim=1) + t.zero_().scatter_(dim=1, index=i, src=v) + + return x + + +def ash_s(x, percentile=65): + assert x.dim() == 4 + assert 0 <= percentile <= 100 + b, c, h, w = x.shape + + # calculate the sum of the input per sample + s1 = x.sum(dim=[1, 2, 3]) + n = x.shape[1:].numel() + k = n - int(np.round(n * percentile / 100.0)) + t = x.view((b, c * h * w)) + v, i = torch.topk(t, k, dim=1) + t.zero_().scatter_(dim=1, index=i, src=v) + + # calculate new sum of the input per sample after pruning + s2 = x.sum(dim=[1, 2, 3]) + + # apply sharpening + scale = s1 / s2 + + return x * torch.exp(scale[:, None, None, None]) + + +def ash_rand(x, percentile=65, r1=0, r2=10): + assert x.dim() == 4 + assert 0 <= percentile <= 100 + b, c, h, w = x.shape + + n = x.shape[1:].numel() + k = n - int(np.round(n * percentile / 100.0)) + t = x.view((b, c * h * w)) + v, i = torch.topk(t, k, dim=1) + v = v.uniform_(r1, r2) + t.zero_().scatter_(dim=1, index=i, src=v) + return x diff --git a/torch_uncertainty/ood/nets/react_net.py b/torch_uncertainty/ood/nets/react_net.py new file mode 100644 index 00000000..6037d793 --- /dev/null +++ b/torch_uncertainty/ood/nets/react_net.py @@ -0,0 +1,23 @@ +import torch.nn as nn + + +class ReactNet(nn.Module): + def __init__(self, backbone): + super().__init__() + self.backbone = backbone + + def forward(self, x, return_feature=False, return_feature_list=False): + try: + return self.backbone(x, return_feature, return_feature_list) + except TypeError: + return self.backbone(x, return_feature) + + def forward_threshold(self, x, threshold): + _, feature = self.backbone(x, return_feature=True) + feature = feature.clip(max=threshold) + feature = feature.view(feature.size(0), -1) + return self.backbone.get_fc_layer()(feature) + + def get_fc(self): + fc = self.backbone.fc + return fc.weight.cpu().detach().numpy(), fc.bias.cpu().detach().numpy() diff --git a/torch_uncertainty/ood/nets/scale_net.py b/torch_uncertainty/ood/nets/scale_net.py new file mode 100644 index 00000000..f80c3703 --- /dev/null +++ b/torch_uncertainty/ood/nets/scale_net.py @@ -0,0 +1,48 @@ +import numpy as np +import torch +import torch.nn as nn + + +class ScaleNet(nn.Module): + def __init__(self, backbone): + super().__init__() + self.backbone = backbone + + def forward(self, x, return_feature=False, return_feature_list=False): + try: + return self.backbone(x, return_feature, return_feature_list) + except TypeError: + return self.backbone(x, return_feature) + + def forward_threshold(self, x, percentile): + _, feature = self.backbone(x, return_feature=True) + feature = scale(feature.view(feature.size(0), -1, 1, 1), percentile) + feature = feature.view(feature.size(0), -1) + return self.backbone.get_fc_layer()(feature) + + def get_fc(self): + fc = self.backbone.fc + return fc.weight.cpu().detach().numpy(), fc.bias.cpu().detach().numpy() + + +def scale(x, percentile=65): + x_clone = x.clone() + assert x.dim() == 4 + assert 0 <= percentile <= 100 + b, c, h, w = x.shape + + # calculate the sum of the input per sample + s1 = x.sum(dim=[1, 2, 3]) + n = x.shape[1:].numel() + k = n - int(np.round(n * percentile / 100.0)) + t = x.view((b, c * h * w)) + v, i = torch.topk(t, k, dim=1) + t.zero_().scatter_(dim=1, index=i, src=v) + + # calculate new sum of the input per sample after pruning + s2 = x.sum(dim=[1, 2, 3]) + + # apply sharpening + scale = s1 / s2 + + return x_clone * torch.exp(scale[:, None, None, None]) diff --git a/torch_uncertainty/ood/ood_criteria.py b/torch_uncertainty/ood/ood_criteria.py new file mode 100644 index 00000000..7377907f --- /dev/null +++ b/torch_uncertainty/ood/ood_criteria.py @@ -0,0 +1,980 @@ +import logging +from abc import ABC, abstractmethod +from copy import deepcopy +from enum import Enum +from pathlib import Path +from typing import Any + +import faiss +import numpy as np +import torch +from numpy.linalg import norm, pinv +from scipy.special import logsumexp +from sklearn.covariance import EmpiricalCovariance +from statsmodels.distributions.empirical_distribution import ECDF +from torch import Tensor, nn +from tqdm import tqdm + +from torch_uncertainty.metrics import MutualInformation, VariationRatio +from torch_uncertainty.ood.nets import ( + AdaScaleANet, + ASHNet, + ReactNet, + ScaleNet, +) +from torch_uncertainty.ood.utils import load_config + +logger = logging.getLogger(__name__) + + +def normalizer(x): + return x / np.linalg.norm(x, axis=-1, keepdims=True) + 1e-10 + + +class OODCriterionInputType(Enum): + """Enum representing the type of input expected by the OOD (Out-of-Distribution) criteria. + + Attributes: + LOGIT (int): Represents that the input is in the form of logits (pre-softmax values). + PROB (int): Represents that the input is in the form of probabilities (post-softmax values). + ESTIMATOR_PROB (int): Represents that the input is in the form of estimated probabilities + from an ensemble or other probabilistic model. + """ + + LOGIT = 1 + PROB = 2 + ESTIMATOR_PROB = 3 + DATASET = 4 + + +class TUOODCriterion(ABC, nn.Module): + single_only = True + input_type: OODCriterionInputType + ensemble_only = False + + def __init__(self) -> None: + """Abstract base class for Out-of-Distribution (OOD) criteria. + + This class defines a common interface for implementing various OOD detection + criteria. Subclasses must implement the `forward` method. + + Attributes: + input_type (OODCriterionInputType): Type of input expected by the criterion. + ensemble_only (bool): Whether the criterion requires ensemble outputs. + """ + super().__init__() + self.setup_flag = False + self.hyperparam_search_done = False + + def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict): + pass + + @abstractmethod + def forward(self, inputs: Tensor) -> Tensor: + """Forward pass for the OOD criterion. + + Args: + inputs (Tensor): The input tensor representing model outputs. + + Returns: + Tensor: OOD score computed according to the criterion. + """ + + +class MaxLogitCriterion(TUOODCriterion): + input_type = OODCriterionInputType.LOGIT + + def __init__(self) -> None: + """OOD criterion based on the maximum logit value. + + This criterion computes the negative of the highest logit value across + the output dimensions. Lower maximum logits indicate greater uncertainty. + + Attributes: + input_type (OODCriterionInputType): Expected input type is logits. + """ + super().__init__() + + def forward(self, inputs: Tensor) -> Tensor: + """Compute the negative of the maximum logit value. + + Args: + inputs (Tensor): Tensor of logits with shape (batch_size, num_classes). + + Returns: + Tensor: Negative of the maximum logit value for each sample. + """ + return -inputs.mean(dim=1).max(dim=-1).values + + +class EnergyCriterion(TUOODCriterion): + input_type = OODCriterionInputType.LOGIT + + def __init__(self) -> None: + r"""OOD criterion based on the energy function. + + This criterion computes the negative log-sum-exp of the logits. + Higher energy values indicate greater uncertainty. + + .. math:: + E(\mathbf{z}) = -\log\left(\sum_{i=1}^{C} \exp(z_i)\right) + + where :math:`\mathbf{z} = [z_1, z_2, \dots, z_C]` is the logit vector. + + Attributes: + input_type (OODCriterionInputType): Expected input type is logits. + """ + super().__init__() + + def forward(self, inputs: Tensor) -> Tensor: + """Compute the negative energy score. + + Args: + inputs (Tensor): Tensor of logits with shape (batch_size, num_classes). + + Returns: + Tensor: Negative energy score for each sample. + """ + return -inputs.mean(dim=1).logsumexp(dim=-1) + + +class MaxSoftmaxCriterion(TUOODCriterion): + single_only = False + input_type = OODCriterionInputType.PROB + + def __init__(self) -> None: + r"""OOD criterion based on maximum softmax probability. + + This criterion computes the negative of the highest softmax probability. + Lower maximum probabilities indicate greater uncertainty. + + .. math:: + \text{score} = -\max_{i}(p_i) + + where :math:`\mathbf{p} = [p_1, p_2, \dots, p_C]` is the probability vector. + + Attributes: + input_type (OODCriterionInputType): Expected input type is probabilities. + """ + super().__init__() + + def forward(self, inputs: Tensor) -> Tensor: + """Compute the negative of the maximum softmax probability. + + Args: + inputs (Tensor): Tensor of probabilities with shape (batch_size, num_classes). + + Returns: + Tensor: Negative of the highest softmax probability for each sample. + """ + return -inputs.max(-1)[0] + + +class EntropyCriterion(TUOODCriterion): + single_only = False + input_type = OODCriterionInputType.ESTIMATOR_PROB + + def __init__(self) -> None: + r"""OOD criterion based on entropy. + + This criterion computes the mean entropy of the predicted probability distribution. + Higher entropy values indicate greater uncertainty. + + .. math:: + H(\mathbf{p}) = -\sum_{i=1}^{C} p_i \log(p_i) + + where :math:`\mathbf{p} = [p_1, p_2, \dots, p_C]` is the probability vector. + + Attributes: + input_type (OODCriterionInputType): Expected input type is estimated probabilities. + """ + super().__init__() + + def forward(self, inputs: Tensor) -> Tensor: + """Compute the entropy of the predicted probability distribution. + + Args: + inputs (Tensor): Tensor of estimated probabilities with shape (batch_size, num_classes). + + Returns: + Tensor: Mean entropy value for each sample. + """ + return torch.special.entr(inputs).sum(dim=-1).mean(dim=1) + + +class MutualInformationCriterion(TUOODCriterion): + single_only = False + ensemble_only = True + input_type = OODCriterionInputType.ESTIMATOR_PROB + + def __init__(self) -> None: + r"""OOD criterion based on mutual information. + + This criterion computes the mutual information between ensemble predictions. + Higher mutual information values indicate lower uncertainty. + + Given ensemble predictions :math:`\{\mathbf{p}^{(k)}\}_{k=1}^{K}`, the mutual information is computed as: + + .. math:: + I(y, \theta) = H\Big(\frac{1}{K}\sum_{k=1}^{K} \mathbf{p}^{(k)}\Big) - \frac{1}{K}\sum_{k=1}^{K} H(\mathbf{p}^{(k)}) + + Attributes: + ensemble_only (bool): Requires ensemble predictions. + input_type (OODCriterionInputType): Expected input type is estimated probabilities. + """ + super().__init__() + self.mi_metric = MutualInformation(reduction="none") + + def forward(self, inputs: Tensor) -> Tensor: + """Compute mutual information from ensemble predictions. + + Args: + inputs (Tensor): Tensor of ensemble probabilities with shape + (ensemble_size, batch_size, num_classes). + + Returns: + Tensor: Mutual information for each sample. + """ + return self.mi_metric(inputs) + + +class VariationRatioCriterion(TUOODCriterion): + single_only = False + ensemble_only = True + input_type = OODCriterionInputType.ESTIMATOR_PROB + + def __init__(self) -> None: + r"""OOD criterion based on variation ratio. + + This criterion computes the variation ratio from ensemble predictions. + Higher variation ratio values indicate greater uncertainty. + + Given ensemble predictions where :math:`n_{\text{mode}}` is the count of the most frequently + predicted class among :math:`K` predictions, the variation ratio is computed as: + + .. math:: + \text{VR} = 1 - \frac{n_{\text{mode}}}{K} + + Attributes: + ensemble_only (bool): Requires ensemble predictions. + input_type (OODCriterionInputType): Expected input type is estimated probabilities. + """ + super().__init__() + self.vr_metric = VariationRatio(reduction="none", probabilistic=False) + + def forward(self, inputs: Tensor) -> Tensor: + """Compute variation ratio from ensemble predictions. + + Args: + inputs (Tensor): Tensor of ensemble probabilities with shape + (ensemble_size, batch_size, num_classes). + + Returns: + Tensor: Variation ratio for each sample. + """ + return self.vr_metric(inputs.transpose(0, 1)) + + +class ScaleCriterion(TUOODCriterion): + input_type = OODCriterionInputType.DATASET + + def __init__(self, config) -> None: + """OOD criterion based on the Scale method. + + This criterion uses a scaling-based approach to compute OOD scores. + It applies a thresholding mechanism to the network's output and computes + the energy confidence score. The lower the energy confidence, the higher + the uncertainty. + + Heavily inspired by the OpenOOD repository: + https://github.com/Jingkang50/OpenOOD + + Attributes: + input_type (OODCriterionInputType): Expected input type is dataset. + percentile (float): Percentile value used for thresholding. + args_dict (dict): Dictionary containing hyperparameter sweep configurations. + """ + super().__init__() + self.args = config.postprocessor.postprocessor_args + self.percentile = self.args.percentile + self.args_dict = config.postprocessor.postprocessor_sweep + + def forward(self, net: nn.Module, data: Any) -> Tensor: + net = ScaleNet(net) + output = net.forward_threshold(data, self.percentile) + energyconf = torch.logsumexp(output.data, dim=1) + return -energyconf + + def set_hyperparam(self, hyperparam: list): + self.percentile = hyperparam[0] + + def get_hyperparam(self): + return self.percentile + + +class ASHCriterion(TUOODCriterion): + input_type = OODCriterionInputType.DATASET + + def __init__(self, config) -> None: + """OOD criterion based on the ASH (Activation Shift) method. + + This criterion uses a thresholding mechanism to compute OOD scores + based on the network's activations. It applies a percentile-based + threshold to the activations and computes the energy confidence score. + Lower energy confidence indicates higher uncertainty. + + Heavily inspired by the OpenOOD repository: + https://github.com/Jingkang50/OpenOOD + + Attributes: + input_type (OODCriterionInputType): Expected input type is dataset. + percentile (float): Percentile value used for thresholding. + args_dict (dict): Dictionary containing hyperparameter sweep configurations. + """ + super().__init__() + self.args = config.postprocessor.postprocessor_args + self.percentile = self.args.percentile + self.args_dict = config.postprocessor.postprocessor_sweep + + @torch.no_grad() + def forward(self, net: nn.Module, data: Any): + net = ASHNet(net) + output = net.forward_threshold(data, self.percentile) + energyconf = torch.logsumexp(output.data, dim=1) + return -energyconf + + def set_hyperparam(self, hyperparam: list): + self.percentile = hyperparam[0] + + def get_hyperparam(self): + return self.percentile + + +class ReactCriterion(TUOODCriterion): + input_type = OODCriterionInputType.DATASET + + def __init__(self, config) -> None: + """OOD criterion based on the React (Rectified Activation) method. + + This criterion uses a thresholding mechanism to compute OOD scores + based on the network's activations. It applies a percentile-based + threshold to the activations and computes the energy confidence score. + Lower energy confidence indicates higher uncertainty. + + Heavily inspired by the OpenOOD repository: + https://github.com/Jingkang50/OpenOOD + + Attributes: + input_type (OODCriterionInputType): Expected input type is dataset. + percentile (float): Percentile value used for thresholding. + args_dict (dict): Dictionary containing hyperparameter sweep configurations. + """ + super().__init__() + self.args = config.postprocessor.postprocessor_args + self.percentile = self.args.percentile + self.args_dict = config.postprocessor.postprocessor_sweep + + def setup(self, net: nn.Module, id_loader, ood_loaders): + if not self.setup_flag: + activation_log = [] + net = ReactNet(net) + net.eval() + with torch.no_grad(): + for batch in tqdm(id_loader["val"], desc="Setup: ", position=0, leave=True): + data = batch[0].cuda().float() + _, feature = net(data, return_feature=True) + activation_log.append(feature.data.cpu().numpy()) + + self.activation_log = np.concatenate(activation_log, axis=0) + self.setup_flag = True + else: + pass + + self.threshold = np.percentile(self.activation_log.flatten(), self.percentile) + + @torch.no_grad() + def forward(self, net: nn.Module, data: Any): + net = ReactNet(net) + output = net.forward_threshold(data, self.threshold) + energyconf = torch.logsumexp(output.data, dim=1) + return -energyconf + + def set_hyperparam(self, hyperparam: list): + self.percentile = hyperparam[0] + self.threshold = np.percentile(self.activation_log.flatten(), self.percentile) + logger.info( + "Threshold at percentile %d over id data is: %s", + self.percentile, + self.threshold, + ) + + def get_hyperparam(self): + return self.percentile + + +class AdaScaleCriterion(TUOODCriterion): + input_type = OODCriterionInputType.DATASET + + def __init__(self, config) -> None: + """OOD criterion based on the AdaScale method. + + This criterion uses an adaptive scaling approach to compute OOD scores. + It applies a percentile-based thresholding mechanism to the network's + features and computes the energy confidence score. Lower energy confidence + indicates higher uncertainty. + + Heavily inspired by the OpenOOD repository: + https://github.com/Jingkang50/OpenOOD + + Attributes: + input_type (OODCriterionInputType): Expected input type is dataset. + percentile (float): Percentile value used for thresholding. + k1 (int): Number of top-k features considered for correction term. + k2 (int): Number of top-k features considered for feature shift. + lmbda (float): Scaling factor for feature shift. + o (float): Fraction of pixels used for perturbation. + args_dict (dict): Dictionary containing hyperparameter sweep configurations. + """ + super().__init__() + self.args = config.postprocessor.postprocessor_args + self.percentile = self.args.percentile + self.k1 = self.args.k1 + self.k2 = self.args.k2 + self.lmbda = self.args.lmbda + self.o = self.args.o + self.args_dict = config.postprocessor.postprocessor_sweep + + def setup(self, net: nn.Module, id_loader, ood_loaders): + net = AdaScaleANet(net) + if not self.setup_flag: + feature_log = [] + feature_perturbed_log = [] + feature_shift_log = [] + net.eval() + self.feature_dim = net.backbone.feature_size + with torch.no_grad(): + for batch in tqdm(id_loader["val"], desc="Setup: ", position=0, leave=True): + data = batch[0].cuda().float() + with torch.enable_grad(): + data.requires_grad = True + output, feature = net(data, return_feature=True) + labels = output.detach().argmax(dim=1) + net.zero_grad() + score = output[torch.arange(len(labels)), labels] + score.backward(torch.ones_like(labels)) + grad = data.grad.data.detach() + feature_log.append(feature.data.cpu()) + data_perturbed = self.perturb(data, grad) + _, feature_perturbed = net(data_perturbed, return_feature=True) + feature_shift = abs(feature - feature_perturbed) + feature_shift_log.append(feature_shift.data.cpu()) + feature_perturbed_log.append(feature_perturbed.data.cpu()) + all_features = torch.cat(feature_log, axis=0) + all_perturbed = torch.cat(feature_perturbed_log, axis=0) + all_shifts = torch.cat(feature_shift_log, axis=0) + + total_samples = all_features.size(0) + num_samples = ( + self.args.num_samples if hasattr(self.args, "num_samples") else total_samples + ) + indices = torch.randperm(total_samples)[:num_samples] + + self.feature_log = all_features[indices] + self.feature_perturbed_log = all_perturbed[indices] + self.feature_shift_log = all_shifts[indices] + self.setup_flag = True + else: + pass + + @torch.no_grad() + def get_percentile(self, feature, feature_perturbed, feature_shift): + topk_indices = torch.topk(feature, dim=1, k=self.k1_)[1] + topk_feature_perturbed = torch.gather( + torch.relu(feature_perturbed), 1, topk_indices + ) # correction term C_o + topk_indices = torch.topk(feature, dim=1, k=self.k2_)[1] + topk_feature_shift = torch.gather(feature_shift, 1, topk_indices) # Q + topk_norm = topk_feature_perturbed.sum(dim=1) + self.lmbda * topk_feature_shift.sum( + dim=1 + ) # Q^{\prime} + percent = 1 - self.ecdf(topk_norm.cpu()) + percentile = self.min_percentile + percent * (self.max_percentile - self.min_percentile) + return torch.from_numpy(percentile) + + @torch.no_grad() + def forward(self, net: nn.Module, data): + net = AdaScaleANet(net) + with torch.enable_grad(): + data.requires_grad = True + output, feature = net(data, return_feature=True) + labels = output.detach().argmax(dim=1) + net.zero_grad() + score = output[torch.arange(len(labels)), labels] + score.backward(torch.ones_like(labels)) + grad = data.grad.data.detach() + data.requires_grad = False + data_perturbed = self.perturb(data, grad) + _, feature_perturbed = net(data_perturbed, return_feature=True) + feature_shift = abs(feature - feature_perturbed) + percentile = self.get_percentile(feature, feature_perturbed, feature_shift) + output = net.forward_threshold(feature, percentile) + conf = torch.logsumexp(output, dim=1) + return -conf + + @torch.no_grad() + def perturb(self, data, grad): + batch_size, channels, height, width = data.shape + n_pixels = int(channels * height * width * self.o) + abs_grad = abs(grad).view(batch_size, channels * height * width) + _, topk_indices = torch.topk(abs_grad, n_pixels, dim=1, largest=False) + mask = torch.zeros_like(abs_grad, dtype=torch.uint8) + mask.scatter_(1, topk_indices, 1) + mask = mask.view(batch_size, channels, height, width) + return data + grad.sign() * mask * 0.5 + + def set_hyperparam(self, hyperparam: list): + self.percentile = hyperparam[0] + self.min_percentile, self.max_percentile = self.percentile[0], self.percentile[1] + self.k1 = hyperparam[1] + self.k2 = hyperparam[2] + self.lmbda = hyperparam[3] + self.o = hyperparam[4] + self.k1_ = int(self.feature_dim * self.k1 / 100) + self.k2_ = int(self.feature_dim * self.k2 / 100) + topk_indices = torch.topk(self.feature_log, k=self.k1_, dim=1)[1] + topk_feature_perturbed = torch.gather( + torch.relu(self.feature_perturbed_log), 1, topk_indices + ) + topk_indices = torch.topk(self.feature_log, k=self.k2_, dim=1)[1] + topk_feature_shift_log = torch.gather(self.feature_shift_log, 1, topk_indices) + sum_log = topk_feature_perturbed.sum(dim=1) + self.lmbda * topk_feature_shift_log.sum(dim=1) + self.ecdf = ECDF(sum_log) + + def get_hyperparam(self): + return [self.percentile, self.k1, self.k2, self.lmbda, self.o] + + +class VIMCriterion(TUOODCriterion): + input_type = OODCriterionInputType.DATASET + + def __init__(self, config) -> None: + """OOD criterion based on the VIM (Variance-Informed Mahalanobis) method. + + This criterion uses a Mahalanobis distance-based approach to compute OOD scores. + It extracts features from the in-distribution training data, computes the + empirical covariance matrix, and projects the features onto a subspace + defined by the top eigenvectors. The OOD score is computed as a combination + of the variance-informed Mahalanobis distance and the energy score. + + Heavily inspired by the OpenOOD repository: + https://github.com/Jingkang50/OpenOOD + + Attributes: + input_type (OODCriterionInputType): Expected input type is dataset. + dim (int): Number of dimensions to retain in the subspace projection. + args_dict (dict): Dictionary containing hyperparameter sweep configurations. + """ + super().__init__() + self.args = config.postprocessor.postprocessor_args + self.args_dict = config.postprocessor.postprocessor_sweep + self.dim = self.args.dim + + def setup(self, net: nn.Module, id_loader, ood_loaders): + if not self.setup_flag: + net.eval() + + with torch.no_grad(): + self.w, self.b = net.get_fc() + logger.info("Extracting in-distribution training features") + feature_id_train = [] + for batch in tqdm(id_loader["train"], desc="Setup: ", position=0, leave=True): + data = batch[0].cuda().float() + _, feature = net(data, return_feature=True) + feature_id_train.append(feature.cpu().numpy()) + feature_id_train = np.concatenate(feature_id_train, axis=0) + logit_id_train = feature_id_train @ self.w.T + self.b + + self.u = -np.matmul(pinv(self.w), self.b) + ec = EmpiricalCovariance(assume_centered=True) + ec.fit(feature_id_train - self.u) + eig_vals, eigen_vectors = np.linalg.eig(ec.covariance_) + self.NS = np.ascontiguousarray( + (eigen_vectors.T[np.argsort(eig_vals * -1)[self.dim :]]).T + ) + + vlogit_id_train = norm(np.matmul(feature_id_train - self.u, self.NS), axis=-1) + self.alpha = logit_id_train.max(axis=-1).mean() / vlogit_id_train.mean() + logger.info("Computed alpha: %.4f", self.alpha) + + self.setup_flag = True + else: + pass + + @torch.no_grad() + def forward(self, net: nn.Module, data: Any): + _, feature_ood = net.forward(data, return_feature=True) + feature_ood = feature_ood.cpu() + logit_ood = feature_ood @ self.w.T + self.b + energy_ood = logsumexp(logit_ood.numpy(), axis=-1) + vlogit_ood = norm(np.matmul(feature_ood.numpy() - self.u, self.NS), axis=-1) * self.alpha + score_ood = -vlogit_ood + energy_ood + return -torch.from_numpy(score_ood) + + def set_hyperparam(self, hyperparam: list): + self.dim = hyperparam[0] + + def get_hyperparam(self): + return self.dim + + +class ODINCriterion(TUOODCriterion): + input_type = OODCriterionInputType.DATASET + + def __init__(self, config) -> None: + """OOD criterion based on the ODIN (Out-of-Distribution Detector for Neural Networks) method. + + This criterion uses temperature scaling and input perturbations to compute OOD scores. + It applies a small perturbation to the input data based on the gradient of the cross-entropy + loss with respect to the input. The confidence score is then calculated using the perturbed + input and temperature-scaled logits. Lower confidence scores indicate higher uncertainty. + + Heavily inspired by the OpenOOD repository: + https://github.com/Jingkang50/OpenOOD + + Attributes: + input_type (OODCriterionInputType): Expected input type is dataset. + temperature (float): Temperature scaling factor for logits. + noise (float): Magnitude of the input perturbation. + input_std (list): Standard deviation values for input normalization. + args_dict (dict): Dictionary containing hyperparameter sweep configurations. + """ + super().__init__() + self.args = config.postprocessor.postprocessor_args + + self.temperature = 1 + self.noise = 0.0014 + try: + self.input_std = [0.2470, 0.2435, 0.2616] # // to change + except KeyError: + self.input_std = [0.5, 0.5, 0.5] + self.args_dict = config.postprocessor.postprocessor_sweep + + def forward(self, net: nn.Module, data: Any): + data.requires_grad = True + output = net(data) + + # Calculating the perturbation we need to add, that is, + # the sign of gradient of cross entropy loss w.r.t. input + criterion = nn.CrossEntropyLoss() + + labels = output.detach().argmax(axis=1) + + # Using temperature scaling + output = output / self.temperature + + loss = criterion(output, labels) + loss.backward() + + # Normalizing the gradient to binary in {0, 1} + gradient = torch.ge(data.grad.detach(), 0) + gradient = (gradient.float() - 0.5) * 2 + + # Scaling values taken from original code + gradient[:, 0] = (gradient[:, 0]) / self.input_std[0] + gradient[:, 1] = (gradient[:, 1]) / self.input_std[1] + gradient[:, 2] = (gradient[:, 2]) / self.input_std[2] + + # Adding small perturbations to images + temp_inputs = torch.add(data.detach(), gradient, alpha=-self.noise) + output = net(temp_inputs) + output = output / self.temperature + + # Calculating the confidence after adding perturbations + nn_output = output.detach() + nn_output = nn_output - nn_output.max(dim=1, keepdims=True).values + nn_output = nn_output.exp() / nn_output.exp().sum(dim=1, keepdims=True) + + conf, _ = nn_output.max(dim=1) + + return -conf + + def set_hyperparam(self, hyperparam: list): + self.temperature = hyperparam[0] + self.noise = hyperparam[1] + + def get_hyperparam(self): + return [self.temperature, self.noise] + + +class KNNCriterion(TUOODCriterion): + """OOD criterion based on the K-Nearest Neighbors (KNN) method. + + This criterion uses a KNN-based approach to compute OOD scores. It builds a feature + bank from the in-distribution training data and calculates the distance of test + samples to their K-th nearest neighbor in the feature space. Lower distances + indicate higher confidence, while higher distances indicate greater uncertainty. + + Heavily inspired by the OpenOOD repository: + https://github.com/Jingkang50/OpenOOD + + Attributes: + input_type (OODCriterionInputType): Expected input type is dataset. + K (int): Number of nearest neighbors to consider. + activation_log (np.ndarray): Log of activations from the in-distribution training data. + args_dict (dict): Dictionary containing hyperparameter sweep configurations. + """ + + input_type = OODCriterionInputType.DATASET + + def __init__(self, config) -> None: + super().__init__() + self.args = config.postprocessor.postprocessor_args + self.K = self.args.K + self.activation_log = None + self.args_dict = config.postprocessor.postprocessor_sweep + + def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict): + if not self.setup_flag: + activation_log = [] + net.eval() + with torch.no_grad(): + for batch in tqdm(id_loader_dict["train"], desc="Setup: ", position=0, leave=True): + data = batch[0].cuda().float() + _, feature = net(data, return_feature=True) + activation_log.append(normalizer(feature.data.cpu().numpy())) + + self.activation_log = np.concatenate(activation_log, axis=0) + self.index = faiss.IndexFlatL2(feature.shape[1]) + self.index.add(self.activation_log) + self.setup_flag = True + else: + pass + + @torch.no_grad() + def forward(self, net: nn.Module, data: Any): + _, feature = net(data, return_feature=True) + feature_normed = normalizer(feature.data.cpu().numpy()) + dis, _ = self.index.search( + feature_normed, + self.K, + ) + kth_dist = -dis[:, -1] + return -torch.from_numpy(kth_dist) + + def set_hyperparam(self, hyperparam: list): + self.K = hyperparam[0] + + def get_hyperparam(self): + return self.K + + +class GENCriterion(TUOODCriterion): + input_type = OODCriterionInputType.DATASET + + def __init__(self, config) -> None: + """OOD criterion based on the Generalized Entropy (GEN) method. + + This criterion uses a generalized entropy-based approach to compute OOD scores. + It applies a power transformation to the top-M softmax probabilities and computes + the generalized entropy score. Lower scores indicate higher uncertainty. + + Heavily inspired by the OpenOOD repository: + https://github.com/Jingkang50/OpenOOD + + Attributes: + input_type (OODCriterionInputType): Expected input type is dataset. + gamma (float): Power transformation parameter for generalized entropy. + m (int): Number of top-M probabilities considered for the computation. + args_dict (dict): Dictionary containing hyperparameter sweep configurations. + """ + super().__init__() + self.args = config.postprocessor.postprocessor_args + self.gamma = self.args.gamma + self.m = self.args.m + self.args_dict = config.postprocessor.postprocessor_sweep + + @torch.no_grad() + def forward(self, net: nn.Module, data: Any): + output = net(data) + score = torch.softmax(output, dim=1) + conf = self.generalized_entropy(score, self.gamma, self.m) + return -conf + + def set_hyperparam(self, hyperparam: list): + self.gamma = hyperparam[0] + self.m = hyperparam[1] + + def get_hyperparam(self): + return [self.gamma, self.m] + + def generalized_entropy(self, softmax_id_val, gamma=0.1, m=100): + probs = softmax_id_val + probs_sorted = torch.sort(probs, dim=1)[0][:, -m:] + scores = torch.sum(probs_sorted**gamma * (1 - probs_sorted) ** (gamma), dim=1) + + return -scores + + +def knn_score(bankfeas, queryfeas, k=100, use_min=False): + bankfeas = deepcopy(np.array(bankfeas)) + queryfeas = deepcopy(np.array(queryfeas)) + + index = faiss.IndexFlatIP(bankfeas.shape[-1]) + index.add(bankfeas) + dist, _ = index.search(queryfeas, k) + return np.array(dist.use_min(axis=1)) if use_min else np.array(dist.mean(axis=1)) + + +class NNGuideCriterion(TUOODCriterion): + """NNGuideCriterion is a criterion for out-of-distribution (OOD) detection + that utilizes nearest neighbor guidance based on features and logits + extracted from a neural network. This class is heavily inspired by the + OpenOOD repository: https://github.com/Jingkang50/OpenOOD. + + Attributes: + input_type (OODCriterionInputType): Specifies the type of input for the criterion. + args (Namespace): Arguments related to the postprocessor configuration. + K (int): Number of nearest neighbors to consider for the k-NN score. + alpha (float): Fraction of the in-distribution training data to use for setup. + activation_log (Any): Placeholder for activation logs (currently unused). + args_dict (dict): Dictionary of postprocessor sweep arguments. + setup_flag (bool): Indicates whether the setup process has been completed. + bank_guide (np.ndarray): Precomputed guidance bank combining features and confidence scores. + + Methods: + setup(net, id_loader_dict, ood_loader_dict): + Prepares the guidance bank using in-distribution training data. + + forward(net, data): + Computes the OOD score for the given data using the guidance bank. + + set_hyperparam(hyperparam): + Sets the hyperparameters K and alpha. + + get_hyperparam(): + Retrieves the current hyperparameters K and alpha. + """ + + input_type = OODCriterionInputType.DATASET + + def __init__(self, config) -> None: + super().__init__() + self.args = config.postprocessor.postprocessor_args + self.K = self.args.K + self.alpha = self.args.alpha + self.activation_log = None + self.args_dict = config.postprocessor.postprocessor_sweep + + def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict): + if not self.setup_flag: + net.eval() + bank_feas = [] + bank_logits = [] + with torch.no_grad(): + for batch in tqdm(id_loader_dict["train"], desc="Setup: ", position=0, leave=True): + data = batch[0].cuda().float() + + logit, feature = net(data, return_feature=True) + bank_feas.append(normalizer(feature.data.cpu().numpy())) + bank_logits.append(logit.data.cpu().numpy()) + if len(bank_feas) * id_loader_dict["train"].batch_size > int( + len(id_loader_dict["train"].dataset) * self.alpha + ): + break + + bank_feas = np.concatenate(bank_feas, axis=0) + bank_confs = logsumexp(np.concatenate(bank_logits, axis=0), axis=-1) + self.bank_guide = bank_feas * bank_confs[:, None] + + self.setup_flag = True + else: + pass + + @torch.no_grad() + def forward(self, net: nn.Module, data: Any): + logit, feature = net(data, return_feature=True) + feas_norm = normalizer(feature.data.cpu().numpy()) + energy = logsumexp(logit.data.cpu().numpy(), axis=-1) + + conf = knn_score(self.bank_guide, feas_norm, k=self.K) + score = conf * energy + + return -torch.from_numpy(score) + + def set_hyperparam(self, hyperparam: list): + self.K = hyperparam[0] + self.alpha = hyperparam[1] + + def get_hyperparam(self): + return [self.K, self.alpha] + + +def get_ood_criterion(ood_criterion): + """Get an OOD criterion instance based on a string identifier or class type. + + Args: + ood_criterion (str or type): A string identifier for a predefined OOD criterion + or a subclass of `TUOODCriterion`. + + Returns: + TUOODCriterion: An instance of the requested OOD criterion. + + Raises: + ValueError: If the input string or class type is invalid. + """ + config_dir = Path(__file__).parent / "configs" + if isinstance(ood_criterion, str): + if ood_criterion not in [ + "logit", + "energy", + "msp", + "entropy", + "mutual_information", + "variation_ratio", + ]: + config_path = config_dir / f"{ood_criterion}.yml" + if not config_path.is_file(): + raise ValueError( + f"No configuration file found for OOD criterion '{ood_criterion}'. " + f"Expected {config_path!r}." + ) + config = load_config(str(config_path)) + if ood_criterion == "logit": + return MaxLogitCriterion() + if ood_criterion == "energy": + return EnergyCriterion() + if ood_criterion == "msp": + return MaxSoftmaxCriterion() + if ood_criterion == "entropy": + return EntropyCriterion() + if ood_criterion == "mutual_information": + return MutualInformationCriterion() + if ood_criterion == "variation_ratio": + return VariationRatioCriterion() + if ood_criterion == "scale": + return ScaleCriterion(config) + if ood_criterion == "ash": + return ASHCriterion(config) + if ood_criterion == "react": + return ReactCriterion(config) + if ood_criterion == "adascale_a": + return AdaScaleCriterion(config) + if ood_criterion == "vim": + return VIMCriterion(config) + if ood_criterion == "odin": + return ODINCriterion(config) + if ood_criterion == "knn": + return KNNCriterion(config) + if ood_criterion == "gen": + return GENCriterion(config) + if ood_criterion == "nnguide": + return NNGuideCriterion(config) + raise ValueError( + "The OOD criterion must be one of 'msp', 'logit', 'energy', 'entropy'," + f" 'mutual_information' or 'variation_ratio'. Got {ood_criterion}." + ) + if isinstance(ood_criterion, type) and issubclass(ood_criterion, TUOODCriterion): + return ood_criterion() + raise ValueError( + f"The OOD criterion should be a string or a subclass of TUOODCriterion. Got {type(ood_criterion)}." + ) diff --git a/torch_uncertainty/ood/utils.py b/torch_uncertainty/ood/utils.py new file mode 100644 index 00000000..9b17de5b --- /dev/null +++ b/torch_uncertainty/ood/utils.py @@ -0,0 +1,55 @@ +from collections.abc import Iterator +from pathlib import Path +from typing import Any + +import yaml + + +class ConfigNamespace: + """Wrap a dict so you get BOTH attribute access (cfg.foo) + and a dict API (cfg.keys(), cfg['foo'], cfg.get('foo')). + """ + + def __init__(self, d: dict): + for k, v in d.items(): + setattr(self, k, _to_ns(v)) + + # dict-style + def keys(self) -> Iterator[str]: + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def get(self, key, default=None): + return self.__dict__.get(key, default) + + def __getitem__(self, key): + """Allow dict-style access: cfg[key].""" + return self.__dict__[key] + + def __repr__(self): + """Return the canonical string representation.""" + return f"ConfigNamespace({self.__dict__!r})" + + +def _to_ns(obj: Any) -> Any: + if isinstance(obj, dict): + return ConfigNamespace(obj) + if isinstance(obj, list): + return [_to_ns(v) for v in obj] + return obj + + +def load_config(path: str) -> ConfigNamespace: + """Load any YAML file into a ConfigNamespace. + + You can then do cfg.foo.bar, cfg.foo.keys(), cfg['foo'], etc. + """ + path = Path(path) + with path.open() as f: + raw = yaml.safe_load(f) + return _to_ns(raw) diff --git a/torch_uncertainty/ood_criteria.py b/torch_uncertainty/ood_criteria.py deleted file mode 100644 index e4a202e3..00000000 --- a/torch_uncertainty/ood_criteria.py +++ /dev/null @@ -1,293 +0,0 @@ -from abc import ABC, abstractmethod -from enum import Enum - -import torch -from torch import Tensor, nn - -from torch_uncertainty.metrics import MutualInformation, VariationRatio - - -class OODCriterionInputType(Enum): - """Enum representing the type of input expected by the OOD (Out-of-Distribution) criteria. - - Attributes: - LOGIT (int): The input of the OOD Criterion is in the form of logits (pre-softmax values). - PROB (int): The input is in the form of probabilities (post-softmax values), also called - likelihoods. - ESTIMATOR_PROB (int): The input is in the form of estimated probabilities from an ensemble - or another probabilistic model. - POST_PROCESSING (int): The input is the prediction score given by the post-processing - method. - """ - - LOGIT = 1 - PROB = 2 - ESTIMATOR_PROB = 3 - POST_PROCESSING = 4 - - -class TUOODCriterion(ABC, nn.Module): - input_type: OODCriterionInputType - single_only = False - ensemble_only = False - - def __init__(self) -> None: - """Abstract base class for Out-of-Distribution (OOD) criteria. - - This class defines a common interface for implementing various OOD detection - criteria. Subclasses must implement the `forward` method. - - Attributes: - input_type (OODCriterionInputType): Type of input expected by the criterion. - ensemble_only (bool): Whether the criterion requires ensemble outputs. - """ - super().__init__() - - @abstractmethod - def forward(self, inputs: Tensor) -> Tensor: - """Forward pass for the OOD criterion. - - Args: - inputs (Tensor): The input tensor representing model outputs. - - Returns: - Tensor: OOD score computed according to the criterion. - """ - - -class MaxLogitCriterion(TUOODCriterion): - single_only = True - input_type = OODCriterionInputType.LOGIT - - def __init__(self) -> None: - """OOD criterion based on the maximum logit value. - - This criterion computes the negative of the highest logit value across - the output dimensions. Lower maximum logits indicate greater uncertainty. - - Attributes: - input_type (OODCriterionInputType): Expected input type is logits. - """ - super().__init__() - - def forward(self, inputs: Tensor) -> Tensor: - """Compute the negative of the maximum logit value. - - Args: - inputs (Tensor): Tensor of logits with shape (batch_size, num_classes). - - Returns: - Tensor: Negative of the maximum logit value for each sample. - """ - return -inputs.mean(dim=1).max(dim=-1).values - - -class EnergyCriterion(TUOODCriterion): - single_only = True - input_type = OODCriterionInputType.LOGIT - - def __init__(self) -> None: - r"""OOD criterion based on the energy function. - - This criterion computes the negative log-sum-exp of the logits. - Higher energy values indicate greater uncertainty. - - .. math:: - E(\mathbf{z}) = -\log\left(\sum_{i=1}^{C} \exp(z_i)\right) - - where :math:`\mathbf{z} = [z_1, z_2, \dots, z_C]` is the logit vector. - - Attributes: - input_type (OODCriterionInputType): Expected input type is logits. - """ - super().__init__() - - def forward(self, inputs: Tensor) -> Tensor: - """Compute the negative energy score. - - Args: - inputs (Tensor): Tensor of logits with shape (batch_size, num_classes). - - Returns: - Tensor: Negative energy score for each sample. - """ - return -inputs.mean(dim=1).logsumexp(dim=-1) - - -class MaxSoftmaxCriterion(TUOODCriterion): - input_type = OODCriterionInputType.PROB - - def __init__(self) -> None: - r"""OOD criterion based on maximum softmax probability. - - This criterion computes the negative of the highest softmax probability. - Lower maximum probabilities indicate greater uncertainty. Probabilities are also called* - likelihoods in a more formal context. - - .. math:: - \text{score} = -\max_{i}(p_i) - - where :math:`\mathbf{p} = [p_1, p_2, \dots, p_C]` is the probability vector. - - Attributes: - input_type (OODCriterionInputType): Expected input type is probabilities. - """ - super().__init__() - - def forward(self, inputs: Tensor) -> Tensor: - """Compute the negative of the maximum softmax probability. - - Args: - inputs (Tensor): Tensor of probabilities with shape (batch_size, num_classes). - - Returns: - Tensor: Negative of the highest softmax probability for each sample. - """ - return -inputs.max(-1)[0] - - -class PostProcessingCriterion(MaxSoftmaxCriterion): - input_type = OODCriterionInputType.POST_PROCESSING - - -class EntropyCriterion(TUOODCriterion): - input_type = OODCriterionInputType.ESTIMATOR_PROB - - def __init__(self) -> None: - r"""OOD criterion based on entropy. - - This criterion computes the mean entropy of the predicted probability distribution. - Higher entropy values indicate greater uncertainty. - - .. math:: - H(\mathbf{p}) = -\sum_{i=1}^{C} p_i \log(p_i) - - where :math:`\mathbf{p} = [p_1, p_2, \dots, p_C]` is the probability vector. - - Attributes: - input_type (OODCriterionInputType): Expected input type is estimated probabilities. - """ - super().__init__() - - def forward(self, inputs: Tensor) -> Tensor: - """Compute the entropy of the predicted probability distribution. - - Args: - inputs (Tensor): Tensor of estimated probabilities with shape (batch_size, num_classes). - - Returns: - Tensor: Mean entropy value for each sample. - """ - return torch.special.entr(inputs).sum(dim=-1).mean(dim=1) - - -class MutualInformationCriterion(TUOODCriterion): - ensemble_only = True - input_type = OODCriterionInputType.ESTIMATOR_PROB - - def __init__(self) -> None: - r"""OOD criterion based on mutual information. - - This criterion computes the mutual information between ensemble predictions. - Higher mutual information values indicate lower uncertainty. - - Given ensemble predictions :math:`\{\mathbf{p}^{(k)}\}_{k=1}^{K}`, the mutual information is computed as: - - .. math:: - I(y, \theta) = H\Big(\frac{1}{K}\sum_{k=1}^{K} \mathbf{p}^{(k)}\Big) - \frac{1}{K}\sum_{k=1}^{K} H(\mathbf{p}^{(k)}) - - Attributes: - ensemble_only (bool): Requires ensemble predictions. - input_type (OODCriterionInputType): Expected input type is estimated probabilities. - """ - super().__init__() - self.mi_metric = MutualInformation(reduction="none") - - def forward(self, inputs: Tensor) -> Tensor: - """Compute mutual information from ensemble predictions. - - Args: - inputs (Tensor): Tensor of ensemble probabilities with shape - (ensemble_size, batch_size, num_classes). - - Returns: - Tensor: Mutual information for each sample. - """ - return self.mi_metric(inputs) - - -class VariationRatioCriterion(TUOODCriterion): - ensemble_only = True - input_type = OODCriterionInputType.ESTIMATOR_PROB - - def __init__(self) -> None: - r"""OOD criterion based on variation ratio. - - This criterion computes the variation ratio from ensemble predictions. - Higher variation ratio values indicate greater uncertainty. - - Given ensemble predictions where :math:`n_{\text{mode}}` is the count of the most frequently - predicted class among :math:`K` predictions, the variation ratio is computed as: - - .. math:: - \text{VR} = 1 - \frac{n_{\text{mode}}}{K} - - Attributes: - ensemble_only (bool): Requires ensemble predictions. - input_type (OODCriterionInputType): Expected input type is estimated probabilities. - """ - super().__init__() - self.vr_metric = VariationRatio(reduction="none", probabilistic=False) - - def forward(self, inputs: Tensor) -> Tensor: - """Compute variation ratio from ensemble predictions. - - Args: - inputs (Tensor): Tensor of ensemble probabilities with shape - (ensemble_size, batch_size, num_classes). - - Returns: - Tensor: Variation ratio for each sample. - """ - return self.vr_metric(inputs.transpose(0, 1)) - - -def get_ood_criterion(ood_criterion: type[TUOODCriterion] | str) -> TUOODCriterion: - """Get an OOD criterion instance based on a string identifier or class type. - - Args: - ood_criterion (str or type): A string identifier for a predefined OOD criterion - or a subclass of `TUOODCriterion`. - - Returns: - TUOODCriterion: An instance of the requested OOD criterion. - - Raises: - ValueError: If the input string or class type is invalid. - """ - if isinstance(ood_criterion, str): - if ood_criterion == "logit": - return MaxLogitCriterion() - if ood_criterion == "energy": - return EnergyCriterion() - if ood_criterion == "msp": - return MaxSoftmaxCriterion() - if ood_criterion == "post_processing": - return PostProcessingCriterion() - if ood_criterion == "entropy": - return EntropyCriterion() - if ood_criterion == "mutual_information": - return MutualInformationCriterion() - if ood_criterion == "variation_ratio": - return VariationRatioCriterion() - raise ValueError( - "The OOD criterion must be one of 'msp', 'logit', 'energy', 'entropy'," - f" 'mutual_information' or 'variation_ratio'. Got {ood_criterion}." - ) - if isinstance(ood_criterion, type): - return ood_criterion() - if isinstance(ood_criterion, TUOODCriterion): - return ood_criterion - raise ValueError( - f"The OOD criterion should be a string or a subclass of TUOODCriterion. Got {ood_criterion}." - ) diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index 20c14124..824d7bb8 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -3,6 +3,8 @@ from typing import Literal import torch +import torch.nn.functional as F +from einops import rearrange from torch import Tensor, nn from torch.optim import LBFGS from torch.utils.data import DataLoader @@ -73,14 +75,14 @@ def fit( logging.warning( "model is None. Fitting post_processing method on the dataloader's data directly." ) - self.model = nn.Identity() all_logits = [] all_labels = [] with torch.no_grad(): for inputs, labels in tqdm(dataloader, disable=not progress): logits = self.model(inputs.to(self.device)) - all_logits.append(logits) + log_probs = self._ensemble_log_probs(logits, batch_size=inputs.size(0)) + all_logits.append(log_probs) all_labels.append(labels) all_logits = torch.cat(all_logits).to(self.device) all_labels = torch.cat(all_labels).to(self.device) @@ -120,7 +122,14 @@ def forward(self, inputs: Tensor) -> Tensor: ) return self._scale(self.model(inputs)) - @abstractmethod + def _ensemble_probs(self, logits: Tensor, batch_size: int) -> Tensor: + logits = rearrange(logits, "(m b) c -> b m c", b=batch_size) + return F.softmax(logits, dim=-1).mean(dim=1) + + def _ensemble_log_probs(self, logits: Tensor, batch_size: int) -> Tensor: + probs = self._ensemble_probs(logits, batch_size) + return torch.log(probs) + def _scale(self, logits: Tensor) -> Tensor: """Scale the logits with the optimal temperature. @@ -130,7 +139,6 @@ def _scale(self, logits: Tensor) -> Tensor: Returns: Tensor: Scaled logits. """ - ... def fit_predict( self, diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py index 9aaaf81d..b485e436 100644 --- a/torch_uncertainty/post_processing/laplace.py +++ b/torch_uncertainty/post_processing/laplace.py @@ -83,4 +83,6 @@ def forward( self, inputs: Tensor, ) -> Tensor: - return self.la(inputs, pred_type=self.pred_type, link_approx=self.link_approx) + return self.la(inputs, pred_type=self.pred_type, link_approx=self.link_approx).to( + inputs.device + ) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index ce7c7a52..7acaca18 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -1,12 +1,16 @@ +import itertools +import logging from collections.abc import Callable from pathlib import Path +import numpy as np import torch import torch.nn.functional as F from einops import rearrange from lightning.pytorch import LightningModule from lightning.pytorch.loggers import Logger from lightning.pytorch.utilities.types import STEP_OUTPUT +from sklearn.metrics import roc_auc_score from timm.data import Mixup as timm_Mixup from torch import Tensor, nn from torch.optim import Optimizer @@ -39,9 +43,8 @@ EPOCH_UPDATE_MODEL, STEP_UPDATE_MODEL, ) -from torch_uncertainty.ood_criteria import ( +from torch_uncertainty.ood.ood_criteria import ( OODCriterionInputType, - PostProcessingCriterion, TUOODCriterion, get_ood_criterion, ) @@ -55,6 +58,14 @@ ) from torch_uncertainty.utils import csv_writer, plot_hist +logger = logging.getLogger(__name__) + +logging.basicConfig( + level=logging.INFO, + format="%(message)s", +) + + MIXUP_PARAMS = { "mixtype": "erm", "mixmode": "elem", @@ -261,16 +272,20 @@ def _init_metrics(self) -> None: self.test_id_entropy = Entropy() if self.eval_ood: - ood_metrics = MetricCollection( + self.ood_metrics_template = MetricCollection( { "AUROC": BinaryAUROC(), "AUPR": BinaryAveragePrecision(), "FPR95": FPR95(pos_label=1), - }, - compute_groups=[["AUROC", "AUPR"], ["FPR95"]], + } ) - self.test_ood_metrics = ood_metrics.clone(prefix="ood/") - self.test_ood_entropy = Entropy() + + if self.is_ensemble: + self.test_ood_ens_metrics_near = {} + self.test_ood_ens_metrics_far = {} + else: + self.test_ood_metrics_near = {} + self.test_ood_metrics_far = {} if self.eval_shift: self.test_shift_metrics = cls_metrics.clone(prefix="shift/") @@ -287,9 +302,6 @@ def _init_metrics(self) -> None: self.test_id_ens_metrics = ens_metrics.clone(prefix="test/ens_") - if self.eval_ood: - self.test_ood_ens_metrics = ens_metrics.clone(prefix="ood/ens_") - if self.eval_shift: self.test_shift_ens_metrics = ens_metrics.clone(prefix="shift/ens_") @@ -376,7 +388,92 @@ def _apply_mixup(self, batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe - def on_train_start(self) -> None: # coverage: ignore + def setup(self, stage: str) -> None: + super().setup(stage) + + if stage == "test" and self.eval_ood and not self.ood_criterion.setup_flag: + self.trainer.datamodule.setup(stage="fit") + dm = self.trainer.datamodule + id_loader = {} + try: + train_loader = dm.train_dataloader() + except (AttributeError, RuntimeError): + red = "\033[31m" + reset = "\033[0m" + + logger.info( + "%sNo train loader detected, you are probably using ImageNet and need to download the training split manually. If the OOD criteria chosen rely on the train loader the code will fail.%s", + red, + reset, + ) + else: + id_loader["train"] = train_loader + id_loader["val"] = dm.val_dataloader() + self.ood_criterion.setup(self.model, id_loader, None) + self._hyperparam_search_ood() + + def _hyperparam_search_ood(self): + crit: TUOODCriterion = self.ood_criterion + # nothing to do if criterion has no grid or already done + if not hasattr(crit, "args_dict") or crit.hyperparam_search_done: + return + + names = list(crit.args_dict.keys()) + values = [crit.args_dict[n] for n in names] + combos = list(itertools.product(*values)) + + id_val = self.trainer.datamodule.val_dataloader() + ood_val = self.trainer.datamodule.test_dataloader()[1] + + best_auc = -float("inf") + best_combo = None + + logger.info("Starting hyperparameter search for selected OOD eval method...") + for combo in combos: + crit.set_hyperparam(list(combo)) + + # collect scores & binary labels (0 for ID, 1 for OOD) + all_scores = [] + all_labels = [] + + with torch.no_grad(): + # ID val + for x, _ in id_val: + x = x.to(self.device) + + with torch.inference_mode(False), torch.enable_grad(): + x_input = x.detach().clone().requires_grad_(True) + s = crit(self.model, x_input).cpu().numpy() + + all_scores.append(s) + all_labels.append(np.zeros_like(s)) + + # OODval splits + for x, _ in ood_val: + x = x.to(self.device) + + with torch.inference_mode(False), torch.enable_grad(): + x_input = x.detach().clone().requires_grad_(True) + s = crit(self.model, x_input).cpu().numpy() + + all_scores.append(s) + all_labels.append(np.ones_like(s)) + + scores = np.concatenate(all_scores).ravel() + labels = np.concatenate(all_labels).ravel() + auc = roc_auc_score(labels, scores) + + logger.info("Tried %s → VAL AUROC = %.4f", dict(zip(names, combo, strict=False)), auc) + if auc > best_auc: + best_auc, best_combo = auc, combo + + crit.set_hyperparam(list(best_combo)) + crit.hyperparam_search_done = True + logger.info( + "✓ Selected %s with AUROC=%.4f", dict(zip(names, best_combo, strict=False)), best_auc + ) + + def on_train_start(self) -> None: """Put the hyperparameters in tensorboard.""" if self.loss is None: raise ValueError( @@ -409,7 +506,12 @@ def on_test_start(self) -> None: if self.eval_ood and self.log_plots and isinstance(self.logger, Logger): self.id_score_storage = [] - self.ood_score_storage = [] + self.ood_score_storage = { + ds.dataset_name: [] + for ds in itertools.chain( + self.trainer.datamodule.near_oods, self.trainer.datamodule.far_oods + ) + } if hasattr(self.model, "need_bn_update"): self.model.bn_update(self.trainer.train_dataloader, device=self.device) @@ -500,17 +602,37 @@ def test_step( batch_idx: int, dataloader_idx: int = 0, ) -> None: - """Perform a single test step based on the input tensors. + # skip non necessary test loaders + indices = self.trainer.datamodule.get_indices() + + if self.eval_ood and dataloader_idx == indices.get("val_ood"): + return + + if not self.eval_ood and dataloader_idx in indices.get("near_oods", []) + indices.get( + "far_oods", [] + ): + return + + if not self.eval_shift and dataloader_idx in indices.get("shift", []): + return + + if not self.eval_ood: + near_ood_indices = indices.get("near_oods", []) + far_ood_indices = indices.get("far_oods", []) + if near_ood_indices or far_ood_indices: + logger.info( + "You set `eval_ood` to `True` in the datamodule and not in the routine. " + "You should remove it from the datamodule to avoid unnecessary overhead." + ) - Compute the prediction of the model and the value of the metrics on the test batch. Also - handle OOD and distribution-shifted images. + if not self.eval_shift: + shift_indices = indices.get("shift", []) + if shift_indices: + logger.info( + "You set `eval_shift` to `True` in the datamodule and not in the routine. " + "You should remove it from the datamodule to avoid unnecessary overhead." + ) - Args: - batch (tuple[Tensor, Tensor]): the test data and their corresponding targets. - batch_idx (int): the number of the current batch (unused). - dataloader_idx (int): 0 if in-distribution, 1 if out-of-distribution and 2 if - distribution-shifted. - """ inputs, targets = batch if self.test_num_flops is None: @@ -539,11 +661,14 @@ def test_step( ood_scores = self.ood_criterion(probs) elif self.ood_criterion.input_type == OODCriterionInputType.ESTIMATOR_PROB: ood_scores = self.ood_criterion(probs_per_est) - elif self.ood_criterion.input_type == OODCriterionInputType.POST_PROCESSING: - ood_scores = self.ood_criterion(pp_probs) + elif self.ood_criterion.input_type == OODCriterionInputType.DATASET: + with torch.inference_mode(False), torch.enable_grad(): + x = inputs.detach().clone().requires_grad_(True) + ood_scores = self.ood_criterion(self.model, x).to(self.device) + + indices = self.trainer.datamodule.get_indices() if dataloader_idx == 0: - # squeeze if binary classification only for binary metrics self.test_cls_metrics.update( probs.squeeze(-1) if self.binary_cls else probs, targets, @@ -553,29 +678,108 @@ def test_step( if self.eval_grouping_loss: self.test_grouping_loss.update(probs, targets, self.features) + if self.id_score_storage is not None: + self.id_score_storage.append(-ood_scores.detach().cpu()) + + self.log_dict(self.test_cls_metrics, on_epoch=True, add_dataloader_idx=False) + self.test_id_entropy(probs) + self.log( + "test/cls/Entropy", + self.test_id_entropy, + on_epoch=True, + add_dataloader_idx=False, + ) + if self.is_ensemble: self.test_id_ens_metrics.update(probs_per_est) - if self.eval_ood: - self.test_ood_entropy.update(probs) - self.test_ood_metrics.update(ood_scores, torch.zeros_like(targets)) - - if self.id_score_storage is not None: - self.id_score_storage.append(ood_scores.detach().cpu()) - if self.post_processing is not None: + pp_logits = self.post_processing(inputs) + pp_probs = ( + F.softmax(pp_logits, dim=-1) + if not isinstance(self.post_processing, LaplaceApprox) + else pp_logits + ) self.post_cls_metrics.update(pp_probs, targets) if self.eval_ood and dataloader_idx == 1: - self.test_ood_metrics.update(ood_scores, torch.ones_like(targets)) - + for ds in self.trainer.datamodule.near_oods: + ds_name = ds.dataset_name + if self.is_ensemble: + if ds_name not in self.test_ood_ens_metrics_near: + self.test_ood_ens_metrics_near[ds_name] = self.ood_metrics_template.clone( + prefix=f"ood_near_{ds_name}_" + ) + self.test_ood_ens_metrics_near[ds_name].update( + ood_scores, torch.zeros_like(targets) + ) + else: + if ds_name not in self.test_ood_metrics_near: + self.test_ood_metrics_near[ds_name] = self.ood_metrics_template.clone( + prefix=f"ood_near_{ds_name}_" + ) + self.test_ood_metrics_near[ds_name].update( + ood_scores, torch.zeros_like(targets) + ) + + for ds in self.trainer.datamodule.far_oods: + ds_name = ds.dataset_name + if self.is_ensemble: + if ds_name not in self.test_ood_ens_metrics_far: + self.test_ood_ens_metrics_far[ds_name] = self.ood_metrics_template.clone( + prefix=f"ood_far_{ds_name}_" + ) + self.test_ood_ens_metrics_far[ds_name].update( + ood_scores, torch.zeros_like(targets) + ) + else: + if ds_name not in self.test_ood_metrics_far: + self.test_ood_metrics_far[ds_name] = self.ood_metrics_template.clone( + prefix=f"ood_far_{ds_name}_" + ) + self.test_ood_metrics_far[ds_name].update(ood_scores, torch.zeros_like(targets)) + + if self.eval_ood and dataloader_idx in indices.get("near_oods", []): + ds_index = indices["near_oods"].index(dataloader_idx) + ds_name = self.trainer.datamodule.near_oods[ds_index].dataset_name if self.is_ensemble: - self.test_ood_ens_metrics.update(probs_per_est) - - if self.ood_score_storage is not None: - self.ood_score_storage.append(ood_scores.detach().cpu()) - - if self.eval_shift and dataloader_idx == (2 if self.eval_ood else 1): + if ds_name not in self.test_ood_ens_metrics_near: + self.test_ood_ens_metrics_near[ds_name] = self.ood_metrics_template.clone( + prefix=f"ood_near_{ds_name}_" + ) + self.test_ood_ens_metrics_near[ds_name].update(ood_scores, torch.ones_like(targets)) + if self.log_plots: + self.ood_score_storage[ds_name].append(-ood_scores.detach().cpu()) + else: + if ds_name not in self.test_ood_metrics_near: + self.test_ood_metrics_near[ds_name] = self.ood_metrics_template.clone( + prefix=f"ood_near_{ds_name}_" + ) + self.test_ood_metrics_near[ds_name].update(ood_scores, torch.ones_like(targets)) + if self.log_plots: + self.ood_score_storage[ds_name].append(-ood_scores.detach().cpu()) + + if self.eval_ood and dataloader_idx in indices.get("far_oods", []): + ds_index = indices["far_oods"].index(dataloader_idx) + ds_name = self.trainer.datamodule.far_oods[ds_index].dataset_name + if self.is_ensemble: + if ds_name not in self.test_ood_ens_metrics_far: + self.test_ood_ens_metrics_far[ds_name] = self.ood_metrics_template.clone( + prefix=f"ood_far_{ds_name}_" + ) + self.test_ood_ens_metrics_far[ds_name].update(ood_scores, torch.ones_like(targets)) + if self.log_plots: + self.ood_score_storage[ds_name].append(-ood_scores.detach().cpu()) + else: + if ds_name not in self.test_ood_metrics_far: + self.test_ood_metrics_far[ds_name] = self.ood_metrics_template.clone( + prefix=f"ood_far_{ds_name}_" + ) + self.test_ood_metrics_far[ds_name].update(ood_scores, torch.ones_like(targets)) + if self.log_plots: + self.ood_score_storage[ds_name].append(-ood_scores.detach().cpu()) + + if self.eval_shift and dataloader_idx in indices.get("shift", []): self.test_shift_metrics.update(probs, targets) if self.is_ensemble: self.test_shift_ens_metrics.update(probs_per_est) @@ -605,6 +809,8 @@ def on_test_epoch_end(self) -> None: "test/cplx/flops": self.test_num_flops, "test/cplx/params": self.num_params, } + id_metrics = self.test_cls_metrics.compute() + self.log_dict(id_metrics) if self.post_processing is not None: result_dict |= self.post_cls_metrics.compute() @@ -615,12 +821,27 @@ def on_test_epoch_end(self) -> None: if self.is_ensemble: result_dict |= self.test_id_ens_metrics.compute() - if self.eval_ood: - result_dict |= self.test_ood_metrics.compute() | { - "ood/Entropy": self.test_ood_entropy.compute() - } - if self.is_ensemble: - result_dict |= self.test_ood_ens_metrics.compute() + if self.is_ensemble and self.eval_ood: + for near_metrics in self.test_ood_ens_metrics_near.values(): + result_near = near_metrics.compute() + self.log_dict(result_near, sync_dist=True) + result_dict.update(result_near) + + for far_metrics in self.test_ood_ens_metrics_far.values(): + result_far = far_metrics.compute() + self.log_dict(result_far) + result_dict.update(result_far) + + elif self.eval_ood: + for near_metrics in self.test_ood_metrics_near.values(): + result_near = near_metrics.compute() + self.log_dict(result_near, sync_dist=True) + result_dict.update(result_near) + + for far_metrics in self.test_ood_metrics_far.values(): + result_far = far_metrics.compute() + self.log_dict(result_far) + result_dict.update(result_far) if self.eval_shift: result_dict |= self.test_shift_metrics.compute() | { @@ -651,17 +872,16 @@ def on_test_epoch_end(self) -> None: self.post_cls_metrics["cal/ECE"].plot()[0], ) - # plot histograms of logits and likelihoods - if self.eval_ood: - id_scores = torch.cat(self.id_score_storage, dim=0) - ood_scores = torch.cat(self.ood_score_storage, dim=0) + # plot histograms of ood scores + if isinstance(self.logger, Logger) and self.log_plots and self.eval_ood: + id_scores = torch.cat(self.id_score_storage, dim=0).numpy() + for name, batches in self.ood_score_storage.items(): + ood_scores = torch.cat(batches, dim=0).numpy() - score_fig = plot_hist( - [id_scores, ood_scores], - 20, - "Histogram of the OOD scores", - )[0] - self.logger.experiment.add_figure("OOD Score Histogram", score_fig) + fig_score = plot_hist( + [id_scores, ood_scores], 20, f"OOD Score Histogram ({name})" + )[0] + self.logger.experiment.add_figure(f"OOD Score/{name}", fig_score) # reset metrics self.test_cls_metrics.reset() @@ -672,11 +892,19 @@ def on_test_epoch_end(self) -> None: self.test_grouping_loss.reset() if self.is_ensemble: self.test_id_ens_metrics.reset() + if self.eval_ood: - self.test_ood_metrics.reset() - self.test_ood_entropy.reset() if self.is_ensemble: - self.test_ood_ens_metrics.reset() + for near_metrics in self.test_ood_ens_metrics_near.values(): + near_metrics.reset() + for far_metrics in self.test_ood_ens_metrics_far.values(): + far_metrics.reset() + else: + for near_metrics in self.test_ood_metrics_near.values(): + near_metrics.reset() + for far_metrics in self.test_ood_metrics_far.values(): + far_metrics.reset() + if self.eval_shift: self.test_shift_metrics.reset() if self.is_ensemble: @@ -724,11 +952,6 @@ def _classification_routine_checks( "Logit-based criteria are not implemented for ensembles. Raise an issue if needed." ) - if isinstance(ood_criterion, PostProcessingCriterion) and post_processing is None: - raise ValueError( - "You cannot set ood_criterion=PostProcessingCriterion when post_processing is None." - ) - if is_ensemble and eval_grouping_loss: raise NotImplementedError( "Grouping loss for ensembles is not yet implemented. Raise an issue if needed." diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index 9c358766..258e6630 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -31,7 +31,7 @@ EPOCH_UPDATE_MODEL, STEP_UPDATE_MODEL, ) -from torch_uncertainty.ood_criteria import ( +from torch_uncertainty.ood.ood_criteria import ( OODCriterionInputType, TUOODCriterion, get_ood_criterion, diff --git a/torch_uncertainty/utils/evaluation_loop.py b/torch_uncertainty/utils/evaluation_loop.py index 608c1767..b7f3aac3 100644 --- a/torch_uncertainty/utils/evaluation_loop.py +++ b/torch_uncertainty/utils/evaluation_loop.py @@ -1,10 +1,11 @@ -from collections import OrderedDict +from collections import OrderedDict, defaultdict from lightning.pytorch.loops.evaluation_loop import _EvaluationLoop from lightning.pytorch.trainer.connectors.logger_connector.result import ( _OUT_DICT, ) from rich import get_console +from rich.box import HEAVY_EDGE from rich.console import Group from rich.table import Table from torch import Tensor @@ -72,10 +73,30 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: metric_name = key.split("/")[-1] metrics["cal"].update({metric_name: value}) elif key.startswith("ood"): + # Initialize the ood dict if it isnt already. if "ood" not in metrics: - metrics["ood"] = {} - metric_name = key.split("/")[-1] - metrics["ood"].update({metric_name: value}) + metrics["ood"] = { + "individual": {"near": {}, "far": {}}, + "NearOOD": {"auroc": [], "fpr95": [], "aupr": []}, + "FarOOD": {"auroc": [], "fpr95": [], "aupr": []}, + } + # Here, we expect keys of the form: "ood_{group}_{datasetName}_{metric}" + parts = key.split("_") + group = parts[1].lower() # either "near" or "far" + dataset_name = parts[2] + metric_name = parts[-1].lower() # e.g. "auroc", "fpr95", "aupr" + + # Store individual dataset results grouped by near or far. + if dataset_name not in metrics["ood"]["individual"][group]: + metrics["ood"]["individual"][group][dataset_name] = {} + metrics["ood"]["individual"][group][dataset_name][metric_name] = value + + # Also, add this value to the corresponding average accumulator. + if group == "near": + metrics["ood"]["NearOOD"][metric_name].append(value) + elif group == "far": + metrics["ood"]["FarOOD"][metric_name].append(value) + elif key.startswith("shift"): if "shift" not in metrics: metrics["shift"] = {} @@ -148,12 +169,73 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: tables.append(table) if "ood" in metrics: - table = Table() - table.add_column(first_col_name, justify="center", style="cyan", width=12) - table.add_column("OOD Detection", justify="center", style="magenta", width=25) - ood_metrics = OrderedDict(sorted(metrics["ood"].items())) - for metric_name, value in ood_metrics.items(): - _add_row(table, metric_name, value) + final_ood_results = defaultdict(lambda: {"auroc": None, "fpr95": None, "aupr": None}) + + for key, val in metrics["ood"].items(): + parts = key.split("_") + + if len(parts) == 2: + dataset_name, metric_postfix = parts[0], parts[1].lower() + + if metric_postfix in ["auroc", "fpr95", "aupr"]: + final_ood_results[dataset_name][metric_postfix] = val + + for key in ["NearOOD", "FarOOD"]: + if key in metrics["ood"]: + for key2, val2 in metrics["ood"][key].items(): + final_ood_results[key][key2] = val2 + + table = Table( + title="[bold]OOD Results[/bold]", + box=HEAVY_EDGE, + show_header=True, + show_lines=False, + ) + table.add_column("Dataset", justify="center", style="cyan", width=16) + table.add_column("AUROC", justify="center", style="magenta", width=12) + table.add_column("FPR95", justify="center", style="magenta", width=12) + table.add_column("AUPR", justify="center", style="magenta", width=12) + + def format_val(val): + if val is None: + return "N/A" + # If we have a list, compute the average. + if isinstance(val, list) and len(val) > 0: + val = sum(val) / len(val) + return f"{val.item() * 100:.3f}%" if val is not None else "N/A" + + # First output the Near OOD individual rows. + for dataset_name, m_dict in metrics["ood"]["individual"]["near"].items(): + row_auroc = format_val(m_dict.get("auroc")) + row_fpr95 = format_val(m_dict.get("fpr95")) + row_aupr = format_val(m_dict.get("aupr")) + table.add_row(f"{dataset_name}", row_auroc, row_fpr95, row_aupr) + + # Then add the NearOOD average row. + near_avg = metrics["ood"]["NearOOD"] + table.add_row( + "NearOOD Average", + format_val(near_avg.get("auroc")), + format_val(near_avg.get("fpr95")), + format_val(near_avg.get("aupr")), + ) + + # Next output the Far OOD individual rows. + for dataset_name, m_dict in metrics["ood"]["individual"]["far"].items(): + row_auroc = format_val(m_dict.get("auroc")) + row_fpr95 = format_val(m_dict.get("fpr95")) + row_aupr = format_val(m_dict.get("aupr")) + table.add_row(f"{dataset_name}", row_auroc, row_fpr95, row_aupr) + + # And add the FarOOD average row. + far_avg = metrics["ood"]["FarOOD"] + table.add_row( + "FarOOD Average", + format_val(far_avg.get("auroc")), + format_val(far_avg.get("fpr95")), + format_val(far_avg.get("aupr")), + ) + tables.append(table) if "sc" in metrics: