Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 232 additions & 0 deletions tests/engines/test_feature_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
"""Test for feature extractor."""

import shutil
from collections.abc import Callable
from pathlib import Path

import numpy as np
import pytest
import torch
import zarr
from click.testing import CliRunner

from tiatoolbox import cli
from tiatoolbox.models import IOSegmentorConfig
from tiatoolbox.models.architecture.vanilla import CNNBackbone, TimmBackbone
from tiatoolbox.models.engine.deep_feature_extractor import DeepFeatureExtractor
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import WSIReader

ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu()

# -------------------------------------------------------------------------------------
# Engine
# -------------------------------------------------------------------------------------

device = "cuda" if toolbox_env.has_gpu() else "cpu"


def test_feature_extractor_patches(
remote_sample: Callable,
) -> None:
"""Tests DeepFeatureExtractor on image patches."""
extractor = DeepFeatureExtractor(
model="fcn-tissue_mask", batch_size=32, verbose=False, device=device
)

sample_image = remote_sample("thumbnail-1k-1k")

inputs = [sample_image, sample_image]

assert not extractor.patch_mode
output = extractor.run(
images=inputs,
return_probabilities=True,
return_labels=False,
device=device,
patch_mode=True,
)

assert 0.48 < np.mean(output["features"][:]) < 0.52

with pytest.raises(
ValueError,
match=r".*output_type: `annotationstore` is not supported "
r"for `DeepFeatureExtractor` engine",
):
_ = extractor.run(
images=inputs,
return_probabilities=True,
return_labels=False,
device=device,
patch_mode=True,
output_type="annotationstore",
)


def test_feature_extractor_wsi(remote_sample: Callable, track_tmp_path: Path) -> None:
"""Test feature extraction with DeepFeatureExtractor engine."""
save_dir = track_tmp_path / "output"
# # convert to pathlib Path to prevent wsireader complaint
mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs"))

# * test providing pretrained from torch vs pretrained_model.yaml
shutil.rmtree(save_dir, ignore_errors=True) # default output dir test

extractor = DeepFeatureExtractor(batch_size=1, model="fcn-tissue_mask")
output = extractor.run(
images=[mini_wsi_svs],
return_probabilities=False,
return_labels=False,
device=device,
patch_mode=False,
save_dir=track_tmp_path / "wsi_out_check",
batch_size=1,
output_type="zarr",
memory_threshold=1,
)

output_ = zarr.open(output[mini_wsi_svs], mode="r")
assert len(output_["coordinates"].shape) == 2
assert len(output_["features"].shape) == 3


@pytest.mark.parametrize(
"model", [CNNBackbone("resnet50"), TimmBackbone("efficientnet_b0", pretrained=True)]
)
def test_full_inference(
remote_sample: Callable, track_tmp_path: Path, model: Callable
) -> None:
"""Test full inference with CNNBackbone and TimmBackbone models."""
save_dir = track_tmp_path / "output"
# pre-emptive clean up
shutil.rmtree(save_dir, ignore_errors=True) # default output dir test

mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))

ioconfig = IOSegmentorConfig(
input_resolutions=[
{"units": "mpp", "resolution": 0.25},
],
output_resolutions=[
{"units": "mpp", "resolution": 0.25},
],
patch_input_shape=[512, 512],
patch_output_shape=[512, 512],
stride_shape=[256, 256],
save_resolution={"units": "mpp", "resolution": 8.0},
)

extractor = DeepFeatureExtractor(batch_size=4, model=model)
output = extractor.run(
images=[mini_wsi_svs],
device=device,
save_dir=track_tmp_path / "wsi_out_check",
batch_size=4,
output_type="zarr",
ioconfig=ioconfig,
patch_mode=False,
)

output_ = zarr.open(output[mini_wsi_svs], mode="r")

positions = output_["coordinates"]
features = output_["features"]

reader = WSIReader.open(mini_wsi_svs)
patches = [
reader.read_bounds(
positions[patch_idx],
resolution=0.25,
units="mpp",
pad_constant_values=255,
coord_space="resolution",
)
for patch_idx in range(4)
]
patches = np.array(patches)
patches = torch.from_numpy(patches) # NHWC
patches = patches.permute(0, 3, 1, 2).contiguous() # NCHW
patches = patches.to(device).type(torch.float32)
model = extractor.model
# Inference mode
model.eval()
with torch.inference_mode():
_features = model(patches).cpu().numpy()
# ! must maintain same batch size and likely same ordering
# ! else the output values will not exactly be the same (still < 1.0e-4
# ! of epsilon though)
assert np.mean(np.abs(features[:4] - _features)) < 1.0e-1


@pytest.mark.skipif(
toolbox_env.running_on_ci() or not ON_GPU,
reason="Local test on machine with GPU.",
)
def test_multi_gpu_feature_extraction(
remote_sample: Callable, track_tmp_path: Path
) -> None:
"""Local functionality test for feature extraction using multiple GPUs."""
save_dir = track_tmp_path / "output"
mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
shutil.rmtree(save_dir, ignore_errors=True)

# Use multiple GPUs
device = select_device(on_gpu=ON_GPU)

wsi_ioconfig = IOSegmentorConfig(
input_resolutions=[{"units": "mpp", "resolution": 0.5}],
patch_input_shape=[224, 224],
output_resolutions=[{"units": "mpp", "resolution": 0.5}],
patch_output_shape=[224, 224],
stride_shape=[224, 224],
)

model = TimmBackbone(backbone="UNI", pretrained=True)
extractor = DeepFeatureExtractor(
model=model,
batch_size=32,
num_workers=4,
)

output = extractor.run(
[mini_wsi_svs],
patch_mode=False,
device=device,
ioconfig=wsi_ioconfig,
save_dir=save_dir,
auto_get_mask=True,
output_type="zarr",
)
output_ = zarr.open(output[mini_wsi_svs], mode="r")

positions = output_["coordinates"]
features = output_["features"]
assert len(positions.shape) == 2
assert len(features.shape) == 2


# -------------------------------------------------------------------------------------
# Command Line Interface
# -------------------------------------------------------------------------------------


def test_cli_model_single_file(sample_svs: Path, track_tmp_path: Path) -> None:
"""Test for feature extractor CLI single file."""
runner = CliRunner()
models_wsi_result = runner.invoke(
cli.main,
[
"deep-feature-extractor",
"--img-input",
str(sample_svs),
"--patch-mode",
"False",
"--output-path",
str(track_tmp_path / "output"),
],
)

assert models_wsi_result.exit_code == 0
assert (track_tmp_path / "output" / (sample_svs.stem + ".zarr")).exists()
2 changes: 1 addition & 1 deletion tests/engines/test_semantic_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def test_wsi_segmentor_annotationstore(


def test_cli_model_single_file(sample_svs: Path, track_tmp_path: Path) -> None:
"""Test for models CLI single file."""
"""Test semantic segmentor CLI single file."""
runner = CliRunner()
models_wsi_result = runner.invoke(
cli.main,
Expand Down
2 changes: 2 additions & 0 deletions tiatoolbox/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from tiatoolbox import __version__
from tiatoolbox.cli.common import tiatoolbox_cli
from tiatoolbox.cli.deep_feature_extractor import deep_feature_extractor
from tiatoolbox.cli.nucleus_instance_segment import nucleus_instance_segment
from tiatoolbox.cli.patch_predictor import patch_predictor
from tiatoolbox.cli.read_bounds import read_bounds
Expand Down Expand Up @@ -43,6 +44,7 @@ def main() -> click.BaseCommand:
main.add_command(read_bounds)
main.add_command(save_tiles)
main.add_command(semantic_segmentor)
main.add_command(deep_feature_extractor)
main.add_command(slide_info)
main.add_command(slide_thumbnail)
main.add_command(tissue_mask)
Expand Down
113 changes: 113 additions & 0 deletions tiatoolbox/cli/deep_feature_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""Command line interface for deep feature extractor."""

from __future__ import annotations

from tiatoolbox.cli.common import (
cli_auto_get_mask,
cli_batch_size,
cli_device,
cli_file_type,
cli_img_input,
cli_masks,
cli_memory_threshold,
cli_model,
cli_num_workers,
cli_output_path,
cli_output_type,
cli_patch_mode,
cli_return_labels,
cli_return_probabilities,
cli_verbose,
cli_weights,
cli_yaml_config_path,
prepare_ioconfig,
prepare_model_cli,
tiatoolbox_cli,
)


@tiatoolbox_cli.command()
@cli_img_input()
@cli_output_path(
usage_help="Output directory where model features will be saved.",
default="deep_feature_extractor",
)
@cli_file_type(
default="*.png, *.jpg, *.jpeg, *.tif, *.tiff, *.svs, *.ndpi, *.jp2, *.mrxs",
)
@cli_model(default="fcn-tissue_mask")
@cli_weights()
@cli_device(default="cpu")
@cli_batch_size(default=1)
@cli_yaml_config_path()
@cli_masks(default=None)
@cli_num_workers(default=0)
@cli_output_type(
default="zarr",
)
@cli_memory_threshold(default=80)
@cli_patch_mode(default=False)
@cli_return_probabilities(default=True)
@cli_return_labels(default=False)
@cli_auto_get_mask(default=True)
@cli_verbose(default=True)
def deep_feature_extractor(
model: str,
weights: str,
img_input: str,
file_types: str,
masks: str | None,
output_path: str,
batch_size: int,
yaml_config_path: str,
num_workers: int,
device: str,
output_type: str,
memory_threshold: int,
*,
patch_mode: bool,
return_probabilities: bool,
return_labels: bool,
auto_get_mask: bool,
verbose: bool,
) -> None:
"""Process a set of input images with a deep feature extractor engine."""
from tiatoolbox.models import ( # noqa: PLC0415
DeepFeatureExtractor,
IOSegmentorConfig,
)

files_all, masks_all, output_path = prepare_model_cli(
img_input=img_input,
output_path=output_path,
masks=masks,
file_types=file_types,
)

ioconfig = prepare_ioconfig(
IOSegmentorConfig,
pretrained_weights=weights,
yaml_config_path=yaml_config_path,
)

extractor = DeepFeatureExtractor(
model=model,
weights=weights,
batch_size=batch_size,
num_workers=num_workers,
verbose=verbose,
)

_ = extractor.run(
images=files_all,
masks=masks_all,
patch_mode=patch_mode,
ioconfig=ioconfig,
device=device,
save_dir=output_path,
output_type=output_type,
return_probabilities=return_probabilities,
return_labels=return_labels,
auto_get_mask=auto_get_mask,
memory_threshold=memory_threshold,
)
2 changes: 2 additions & 0 deletions tiatoolbox/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .architecture.nuclick import NuClick
from .architecture.sccnn import SCCNN
from .dataset import PatchDataset, WSIPatchDataset, WSIStreamDataset
from .engine.deep_feature_extractor import DeepFeatureExtractor
from .engine.io_config import (
IOInstanceSegmentorConfig,
IOPatchPredictorConfig,
Expand All @@ -24,6 +25,7 @@

__all__ = [
"SCCNN",
"DeepFeatureExtractor",
"HoVerNet",
"HoVerNetPlus",
"IDaRS",
Expand Down
Loading