diff --git a/.gitignore b/.gitignore index fedd499d..651bdb0d 100644 --- a/.gitignore +++ b/.gitignore @@ -129,3 +129,5 @@ temp/ # VSCode .vscode/ *.zip + +models/lower_pelvic_reg/eval/ diff --git a/models/lower_pelvic_reg/configs/inference.yaml b/models/lower_pelvic_reg/configs/inference.yaml new file mode 100644 index 00000000..c5098895 --- /dev/null +++ b/models/lower_pelvic_reg/configs/inference.yaml @@ -0,0 +1,85 @@ +--- +imports: + - $import matplotlib.pyplot as plt +dataset_dir: "/Users/yiwenli/data/multiorgan_final" +bundle_root: "./" +device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')" +output_dir: "$@bundle_root + '/eval'" +ckpt: "$@bundle_root + '/lower_pelvic_reg_cpu_nonparallel-2.pth'" +cross_subjects: false # whether the input images are from the same subject + +dataset: + _target_: "scripts.dataset.RegDataset" + train: false + dataset_dir: "@dataset_dir" + pixdim: [0.75, 0.75, 2.5] + spatial_size: [256, 256, 40] + rotate_range: $np.pi / 36 + translate_range: [20, 20, 4] + scale_range: [0.15, 0.15, 0.15] + +data_loader: + _target_: "torch.utils.data.DataLoader" + dataset: "@dataset" + batch_size: 1 + num_workers: 0 + +# display first pair of data +first_pair: $@dataset[0] +display: + - $plt.subplot(2,2,1) + - $plt.gca().set_title("moving image") + - $plt.gca().axis('off') + - $plt.imshow(np.transpose(@first_pair[0]["image"][0, ..., @first_pair[0]["image"].shape[-1]//2])) + - $plt.subplot(2,2,2) + - $plt.gca().set_title("fixed image") + - $plt.gca().axis('off') + - $plt.imshow(np.transpose(@first_pair[1]["image"][0, ..., @first_pair[0]["image"].shape[-1]//2])) + - $plt.subplot(2,2,3) + - $plt.gca().set_title("moving label") + - $plt.gca().axis('off') + - $plt.imshow(np.transpose(@first_pair[0]["label"][0, ..., @first_pair[0]["image"].shape[-1]//2])) + - $plt.subplot(2,2,4) + - $plt.gca().set_title("fixed label") + - $plt.gca().axis('off') + - $plt.imshow(np.transpose(@first_pair[1]["label"][0, ..., @first_pair[0]["image"].shape[-1]//2])) + - $plt.show() + +network: + _target_: LocalNet + spatial_dims: 3 + in_channels: 2 + out_channels: 3 + num_channel_initial: 32 + extract_levels: [0, 1, 2, 3] + out_kernel_initializer: "zeros" + +handlers: + - _target_: CheckpointLoader + load_path: "@ckpt" + load_dict: {model: "@network"} + +inferer: + _target_: "scripts.inferer.RegistrationInferer" + +evaluator: + _target_: "scripts.evaluator.RegistrationEvaluator" + device: "@device" + val_data_loader: "@data_loader" + network: "@network" + epoch_length: $len(@dataset) // @data_loader#batch_size + inferer: "@inferer" + val_handlers: "@handlers" + postprocessing: + _target_: Compose + transforms: + - _target_: "scripts.visualise.SaveRegd" + keys: ["moving_image", "moving_label", "fixed_image", "fixed_label", "warped_image", "warped_label"] + pixdim: [ 0.75, 0.75, 2.5 ] + spatial_size: [ 256, 256, 40 ] + output_dir: "@output_dir" + +eval: + - $monai.utils.set_determinism(seed=123) + - "$setattr(torch.backends.cudnn, 'benchmark', True)" + - $@evaluator.run() diff --git a/models/lower_pelvic_reg/configs/logging.conf b/models/lower_pelvic_reg/configs/logging.conf new file mode 100644 index 00000000..91c1a21c --- /dev/null +++ b/models/lower_pelvic_reg/configs/logging.conf @@ -0,0 +1,21 @@ +[loggers] +keys=root + +[handlers] +keys=consoleHandler + +[formatters] +keys=fullFormatter + +[logger_root] +level=INFO +handlers=consoleHandler + +[handler_consoleHandler] +class=StreamHandler +level=INFO +formatter=fullFormatter +args=(sys.stdout,) + +[formatter_fullFormatter] 
+format=%(asctime)s - %(name)s - %(levelname)s - %(message)s diff --git a/models/lower_pelvic_reg/configs/metadata.json b/models/lower_pelvic_reg/configs/metadata.json new file mode 100644 index 00000000..c8a1483d --- /dev/null +++ b/models/lower_pelvic_reg/configs/metadata.json @@ -0,0 +1,64 @@ +{ + "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json", + "version": "0.0.3", + "changelog": { + "0.0.3": "update to use monai 1.1.0", + "0.0.2": "update to use rc1", + "0.0.1": "Initial version" + }, + "monai_version": "1.1.0", + "pytorch_version": "1.13.0", + "numpy_version": "1.22.2", + "optional_packages_version": { + "pytorch-ignite": "0.4.8" + }, + "task": "Spatial transformer for hand image registration from the MedNIST dataset", + "description": "This is an example of a ResNet and spatial transformer for hand xray image registration", + "authors": "MONAI team", + "copyright": "Copyright (c) MONAI Consortium", + "intended_use": "This is an example of image registration using MONAI, suitable for demonstration purposes only.", + "data_type": "jpeg", + "network_data_format": { + "inputs": { + "image": { + "type": "image", + "format": "magnitude", + "num_channels": 2, + "spatial_shape": [ + 64, + 64 + ], + "dtype": "float32", + "value_range": [ + 0, + 1 + ], + "is_patch_data": false, + "channel_def": { + "0": "moving image", + "1": "fixed image" + } + } + }, + "outputs": { + "pred": { + "type": "image", + "format": "magnitude", + "num_channels": 1, + "spatial_shape": [ + 64, + 64 + ], + "dtype": "float32", + "value_range": [ + 0, + 1 + ], + "is_patch_data": false, + "channel_def": { + "0": "image" + } + } + } + } +} diff --git a/models/lower_pelvic_reg/configs/train.yaml b/models/lower_pelvic_reg/configs/train.yaml new file mode 100644 index 00000000..061967de --- /dev/null +++ b/models/lower_pelvic_reg/configs/train.yaml @@ -0,0 +1,217 @@ +--- +imports: + - $import glob + - $import matplotlib.pyplot as plt + +# workflow parameters +bundle_root: "./" +device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')" +ckpt_dir: "$@bundle_root + '/models'" # folder to save new checkpoints +ckpt: "" # path to load an existing checkpoint +val_interval: 1 # every epoch +max_epochs: 300 +cross_subjects: false # whether the input images are from the same subject + +# construct the moving and fixed datasets +dataset_dir: "../MedNIST/Hand" +datalist: "$list(sorted(glob.glob(@dataset_dir + '/*.jpeg')))[:7000]" # training with 7000 images +val_datalist: "$list(sorted(glob.glob(@dataset_dir + '/*.jpeg')))[7000:8500]" # validation with 1500 images + +image_load: + - _target_: LoadImage + image_only: True + ensure_channel_first: True + + - _target_: ScaleIntensityRange + a_min: 0.0 + a_max: 255.0 + b_min: 0.0 + b_max: 1.0 + + - _target_: EnsureType + device: "@device" + +image_aug: + - _target_: RandAffine + spatial_size: [64, 64] + translate_range: 5 + scale_range: [-0.15, 0.15] + prob: 1.0 + rotate_range: $np.pi / 8 + mode: bilinear + padding_mode: border + cache_grid: True + device: "@device" + + - _target_: RandGridDistortion + prob: 0.2 + num_cells: 8 + device: "@device" + distort_limit: 0.1 + +preprocessing: + _target_: Compose + transforms: "$@image_load + @image_aug" + +cache_datasets: + - _target_: ShuffleBuffer + data: + _target_: CacheDataset + data: "@datalist" + transform: $@preprocessing.set_random_state(123) + hash_as_key: true + runtime_cache: threads + epochs: "@max_epochs" + seed: "$int(3) if @cross_subjects else 
int(2)" + - _target_: ShuffleBuffer + data: + _target_: CacheDataset + data: "@datalist" + transform: $@preprocessing.set_random_state(234) + hash_as_key: true + runtime_cache: threads + epochs: "@max_epochs" + seed: 2 + +zip_dataset: + _target_: IterableDataset + data: "$map(lambda t: dict(image=monai.transforms.concatenate(t), label=t[1]), zip(*@cache_datasets))" + +data_loader: + _target_: ThreadDataLoader + dataset: "@zip_dataset" + batch_size: 64 + num_workers: 0 + + +# components for debugging +first_pair: $monai.utils.misc.first(@data_loader) +display: + - $monai.utils.set_determinism(seed=123) + - $print(@first_pair.keys(), @first_pair['image'].meta['filename_or_obj']) + - "$print(@trainer#loss_function(@first_pair['image'][:, 0:1], @first_pair['image'][:, 1:2]))" # print loss + - $plt.subplot(1,2,1) + - $plt.imshow(@first_pair['image'][0, 0], cmap="gray") + - $plt.subplot(1,2,2) + - $plt.imshow(@first_pair['image'][0, 1], cmap="gray") + - $plt.show() + + +# network definition +net: + _target_: scripts.net.RegResNet + image_size: [64, 64] + spatial_dims: 2 + mode: "bilinear" + padding_mode: "border" + +optimizer: + _target_: torch.optim.Adam + params: $@net.parameters() + lr: 0.00001 + +# create a validation evaluator +val: + cache_datasets: + - _target_: ShuffleBuffer + data: + _target_: CacheDataset + data: "@val_datalist" + transform: $@preprocessing.set_random_state(123) + hash_as_key: true + runtime_cache: threads + epochs: -1 # infinite + seed: "$int(3) if @cross_subjects else int(2)" + - _target_: ShuffleBuffer + data: + _target_: CacheDataset + data: "@val_datalist" + transform: $@preprocessing.set_random_state(234) + hash_as_key: true + runtime_cache: threads + epochs: -1 # infinite + seed: 2 + + zip_dataset: + _target_: IterableDataset + data: "$map(lambda t: dict(image=monai.transforms.concatenate(t), label=t[1]), zip(*@val#cache_datasets))" + + data_loader: + _target_: ThreadDataLoader + dataset: "@val#zip_dataset" + batch_size: 64 + num_workers: 0 + + evaluator: + _target_: SupervisedEvaluator + device: "@device" + val_data_loader: "@val#data_loader" + network: "@net" + epoch_length: $len(@val_datalist) // @val#data_loader#batch_size + inferer: "$monai.inferers.SimpleInferer()" + metric_cmp_fn: "$lambda x, y: x < y" + key_val_metric: + val_mse: + _target_: MeanSquaredError + output_transform: "$monai.handlers.from_engine(['pred', 'label'])" + additional_metrics: {"mutual info loss": "@loss_metric#metric_handler"} + val_handlers: + - _target_: StatsHandler + iteration_log: false + - _target_: CheckpointSaver + save_dir: "@ckpt_dir" + save_dict: {model: "@net"} + save_key_metric: true + key_metric_negative_sign: true + # key_metric_filename: "model.pt" + +# training handlers +handlers: + - _target_: StatsHandler + tag_name: "train_loss" + output_transform: "$monai.handlers.from_engine(['loss'], first=True)" + - _target_: ValidationHandler + validator: "@val#evaluator" + epoch_level: true + interval: "@val_interval" + +loss_metric: + metric_handler: + _target_: IgniteMetric + output_transform: "$monai.handlers.from_engine(['pred', 'label'])" + metric_fn: + _target_: LossMetric + loss_fn: "@mutual_info_loss" + get_not_nans: true + +ckpt_loader: + - _target_: CheckpointLoader + load_path: "@ckpt" + load_dict: {model: "@net"} + +lncc_loss: + _target_: LocalNormalizedCrossCorrelationLoss + spatial_dims: 2 + kernel_size: 5 + kernel_type: rectangular + reduction: mean + +mutual_info_loss: + _target_: GlobalMutualInformationLoss + +# create the primary trainer +trainer: + 
_target_: SupervisedTrainer + device: "@device" + train_data_loader: "@data_loader" + network: "@net" + max_epochs: "@max_epochs" + epoch_length: $len(@datalist) // @data_loader#batch_size + loss_function: "@lncc_loss" + optimizer: "@optimizer" + train_handlers: "$@handlers + @ckpt_loader if @ckpt else @handlers" + +training: + - $monai.utils.set_determinism(seed=23) + - "$setattr(torch.backends.cudnn, 'benchmark', True)" + - $@trainer.run() diff --git a/models/lower_pelvic_reg/docs/README.md b/models/lower_pelvic_reg/docs/README.md new file mode 100644 index 00000000..cb70af13 --- /dev/null +++ b/models/lower_pelvic_reg/docs/README.md @@ -0,0 +1,32 @@ +# MedNIST Hand Image Registration + +## Downloading the Dataset +Download the dataset [from here](https://zenodo.org/record/7013610) and extract the contents to a convenient location. + +The data set includes 589 T2-weighted images acquired from the same number of patients collected by seven studies, +INDEX, the SmartTarget Biopsy Trial, PICTURE, TCIA Prostate3T, Promise12, TCIA ProstateDx (Diagnosis) and the Prostate +MR Image Database. Further details are reported in the respective study references. + +If you find this labelled data set useful for your research please consider to acknowledge the work: Li, Y., et al. +"Prototypical few-shot segmentation for cross-institution male pelvic structures with spatial registration." +arXiv preprint arXiv:2209.05160 (2022). + +## Inference + +## Using other data +To train or inference the model on your own data, organise the file directories as following: +``` +dataset_dir +├── data + ├── $PatientID1$_img.nii + ├── $PatientID1$_mask.nii + ├── $PatientID2$_img.nii + ├── $PatientID2$_mask.nii + ├── ... +``` + +## Visualize the first pair of images for debugging (requires `matplotlib`) +![fixed](./examples/display.png) +```bash +python -m monai.bundle run display --config_file configs/inference.yaml +``` diff --git a/models/lower_pelvic_reg/docs/examples/display.png b/models/lower_pelvic_reg/docs/examples/display.png new file mode 100644 index 00000000..cb926a9e Binary files /dev/null and b/models/lower_pelvic_reg/docs/examples/display.png differ diff --git a/models/lower_pelvic_reg/lower_pelvic_reg.pth b/models/lower_pelvic_reg/lower_pelvic_reg.pth new file mode 100644 index 00000000..0f6c51a6 Binary files /dev/null and b/models/lower_pelvic_reg/lower_pelvic_reg.pth differ diff --git a/models/lower_pelvic_reg/lower_pelvic_reg_cpu-2.pth b/models/lower_pelvic_reg/lower_pelvic_reg_cpu-2.pth new file mode 100644 index 00000000..867c886e Binary files /dev/null and b/models/lower_pelvic_reg/lower_pelvic_reg_cpu-2.pth differ diff --git a/models/lower_pelvic_reg/lower_pelvic_reg_cpu.pth b/models/lower_pelvic_reg/lower_pelvic_reg_cpu.pth new file mode 100644 index 00000000..9da6534e Binary files /dev/null and b/models/lower_pelvic_reg/lower_pelvic_reg_cpu.pth differ diff --git a/models/lower_pelvic_reg/lower_pelvic_reg_cpu_nonparallel-2.pth b/models/lower_pelvic_reg/lower_pelvic_reg_cpu_nonparallel-2.pth new file mode 100644 index 00000000..7f609484 Binary files /dev/null and b/models/lower_pelvic_reg/lower_pelvic_reg_cpu_nonparallel-2.pth differ diff --git a/models/lower_pelvic_reg/requirements.txt b/models/lower_pelvic_reg/requirements.txt new file mode 100644 index 00000000..549695ef --- /dev/null +++ b/models/lower_pelvic_reg/requirements.txt @@ -0,0 +1,9 @@ +protobuf==3.17.0 +monai==0.9.0 +nibabel==3.2.1 +numpy==1.20.3 +pylatex==1.4.1 +torch==1.8.1 +tqdm==4.60.0 +tensorboard==2.5.0 +pyyaml diff --git 
a/models/lower_pelvic_reg/scripts/dataset.py b/models/lower_pelvic_reg/scripts/dataset.py new file mode 100644 index 00000000..5b53a685 --- /dev/null +++ b/models/lower_pelvic_reg/scripts/dataset.py @@ -0,0 +1,357 @@ +import os +from typing import Optional, Sequence, Tuple, Union + +import numpy as np + +import torch +from monai.transforms import ( + LoadImaged, + AddChanneld, + Spacingd, + NormalizeIntensityd, + ScaleIntensityd, + ToTensord, + RandAffined, + CenterSpatialCropd, + SpatialPadd, + Compose +) + +from torch.nn import functional as F +from torch.utils.data import Dataset + + +def get_institution_patient_dict(dataset_dir, train): + """ + divide images by institution, take 3/4 for training and 1/4 for inference + :param dataset_dir: str + :param train: bool, specify training or not + :return: dict + """ + if os.path.exists(f'{dataset_dir}/institution.txt'): + # divide images by institution + institution_patient_dict = {i: [] for i in range(1, 8)} + with open(f'{dataset_dir}/institution.txt') as f: + patient_ins_list = f.readlines() + for patient_ins in patient_ins_list: + patient, ins = patient_ins[:-1].split(" ") + institution_patient_dict[int(ins)].append(patient) + else: + # if no institution info, consider all patients as from the same institution + patient_list = [p.replace("_mask.nii", "") for p in os.listdir(f'{dataset_dir}/data') if "mask" in p] + institution_patient_dict = {1: patient_list} + + # take 3/4 for training and 1/4 for inference + for k, v in institution_patient_dict.items(): + if train: + institution_patient_dict[k] = v[:-len(v)//4] + else: + institution_patient_dict[k] = v[-len(v)//4:] + return institution_patient_dict + + +def sample_pair(idx, image_list_len): + """ + given query index, sample a support index + :param idx: int, query index + :param image_list_len: int, number of training images + :return: int + """ + out = idx + while out == idx: + out = np.random.randint(image_list_len) + return out + + +class RegDataset(Dataset): + + def __init__(self, + train: bool, + dataset_dir: str, + pixdim: Sequence[float], + spatial_size: Optional[Union[Sequence[int], int]] = None, + rotate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + translate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None + ) -> None: + """ + Args: + train: bool, specify if training or not + dataset_dir: directory storing the t2w images and labels. + pixdim: output voxel spacing. if providing a single number, will use it for the first dimension. + items of the pixdim sequence map to the spatial dimensions of input image, if length + of pixdim sequence is longer than image spatial dimensions, will ignore the longer part, + if shorter, will pad with `1.0`. + if the components of the `pixdim` are non-positive values, the transform will use the + corresponding components of the original pixdim, which is computed from the `affine` + matrix of input image. + spatial_size: output image spatial size. + if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, + the transform will use the spatial size of `image`. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of image size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of image is `64`. + rotate_range: angle range in radians. 
If element `i` is a pair of (min, max) values, then + `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter + for the `i`th spatial dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used. + This can be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be + in range `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]` + for dim0 and nothing for the remaining dimensions. + translate_range: translate range with format matching `rotate_range`, it defines the range to randomly + select pixel/voxel to translate for every spatial dims. + scale_range: scaling range with format matching `rotate_range`. it defines the range to randomly select + the scale factor to translate for every spatial dims. A value of 1.0 is added to the result. + This allows 0 to correspond to no change (i.e., a scaling of 1.0). + """ + super(RegDataset, self).__init__() + self.train = train + + # divide images by institution, take 3/4 for training and 1/4 for inference + institution_patient_dict = get_institution_patient_dict( + dataset_dir=dataset_dir, + train=train, + ) + self.image_list = [] + for ins, patient_list in institution_patient_dict.items(): + self.image_list.extend([(p, ins) for p in patient_list]) + + # sample inference pairs if not training + if not train: + self.val_pair = [] + # for each query image + for moving_p, moving_ins in self.image_list: + # for each institution + for fixed_ins, patient_list in institution_patient_dict.items(): + while True: + fixed_p = patient_list[np.random.randint(0, len(patient_list))] + if fixed_p != moving_p: + break + self.val_pair.append([(moving_p, moving_ins), (fixed_p, fixed_ins)]) + + # initialise transformation + self.image_loader = LoadImages( + dataset_dir=dataset_dir, + augmentation=train, + spatial_size=spatial_size, + pixdim=pixdim, + rotate_range=rotate_range, + translate_range=translate_range, + scale_range=scale_range + ) + + def __len__(self): + return len(self.image_list) if self.train else len(self.val_pair) + + def __getitem__(self, idx): + if self.train: + moving = idx + fixed = sample_pair(idx, len(self.image_list)) + moving, fixed = self.image_list[moving], self.image_list[fixed] + else: + moving, fixed = self.val_pair[idx] + + moving = self.image_loader(moving) + fixed = self.image_loader(fixed) + + return moving, fixed + + +def get_transform(augmentation, spatial_size, pixdim, rotate_range, translate_range, scale_range): + """ + Args: + pixdim: output voxel spacing. if providing a single number, will use it for the first dimension. + items of the pixdim sequence map to the spatial dimensions of input image, if length + of pixdim sequence is longer than image spatial dimensions, will ignore the longer part, + if shorter, will pad with `1.0`. + if the components of the `pixdim` are non-positive values, the transform will use the + corresponding components of the original pixdim, which is computed from the `affine` + matrix of input image. + spatial_size: output image spatial size. + if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, + the transform will use the spatial size of `image`. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of image size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of image is `64`. 
+ augmentation: bool, specifying apply augmentation or not. + rotate_range: angle range in radians. If element `i` is a pair of (min, max) values, then + `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter + for the `i`th spatial dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used. + This can be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be + in range `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]` + for dim0 and nothing for the remaining dimensions. + translate_range: translate range with format matching `rotate_range`, it defines the range to randomly + select pixel/voxel to translate for every spatial dims. + scale_range: scaling range with format matching `rotate_range`. it defines the range to randomly select + the scale factor to translate for every spatial dims. A value of 1.0 is added to the result. + This allows 0 to correspond to no change (i.e., a scaling of 1.0). + """ + pre_augmentation = [ + LoadImaged(keys=["image", "label"]), + AddChanneld(keys=["image", "label"]), + Spacingd( + keys=["image", "label"], + pixdim=pixdim, + mode=("bilinear", "nearest"), + ), + ] + + post_augmentation = [ + NormalizeIntensityd(keys=["image"]), + ScaleIntensityd(keys=["image"]), + ToTensord(keys=["image", "label"]) + ] + + if augmentation: + middle_transform = [ + RandAffined( + keys=["image", "label"], + spatial_size=spatial_size, + prob=1.0, + rotate_range=(rotate_range, rotate_range, rotate_range), + shear_range=None, + translate_range=translate_range, + scale_range=scale_range, + mode=("bilinear", "nearest"), + padding_mode="zeros", + as_tensor_output=False, + device=torch.device('cpu'), + allow_missing_keys=False + ) + ] + else: + middle_transform = [ + CenterSpatialCropd(keys=["image", "label"], roi_size=spatial_size), + SpatialPadd( + keys=["image", "label"], + spatial_size=spatial_size, + method='symmetric', + mode='constant', + allow_missing_keys=False + ) + ] + + return Compose(pre_augmentation + middle_transform + post_augmentation) + + +class LoadImages: + """ + Transform customised for registration + given a dictionary specifying moving and fixed image names, output a dictionary of dictionaries each containing the + following keys: + "image": tensor of shape (1, ...) the t2w image + "label": tensor of shape (1, ...) the label of the corresponding image + "image_name": the name of the image + """ + + def __init__(self, + dataset_dir: str, + pixdim: Union[Sequence[float], float], + spatial_size: Optional[Union[Sequence[int], int]] = None, + augmentation: bool = False, + rotate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + translate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + ) -> None: + """ + Args: + dataset_dir: directory storing the t2w images and labels. + pixdim: output voxel spacing. if providing a single number, will use it for the first dimension. + items of the pixdim sequence map to the spatial dimensions of input image, if length + of pixdim sequence is longer than image spatial dimensions, will ignore the longer part, + if shorter, will pad with `1.0`. + if the components of the `pixdim` are non-positive values, the transform will use the + corresponding components of the original pixdim, which is computed from the `affine` + matrix of input image. 
+ spatial_size: output image spatial size. + if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, + the transform will use the spatial size of `image`. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of image size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of image is `64`. + augmentation: bool, specifying apply augmentation or not. + rotate_range: angle range in radians. If element `i` is a pair of (min, max) values, then + `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter + for the `i`th spatial dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used. + This can be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be + in range `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]` + for dim0 and nothing for the remaining dimensions. + translate_range: translate range with format matching `rotate_range`, it defines the range to randomly + select pixel/voxel to translate for every spatial dims. + scale_range: scaling range with format matching `rotate_range`. it defines the range to randomly select + the scale factor to translate for every spatial dims. A value of 1.0 is added to the result. + This allows 0 to correspond to no change (i.e., a scaling of 1.0). + """ + self.dataset_dir = dataset_dir + self.spatial_size = spatial_size + self.transform = get_transform( + pixdim=pixdim, + spatial_size=spatial_size, + augmentation=augmentation, + rotate_range=rotate_range, + translate_range=translate_range, + scale_range=scale_range + ) + + def __call__(self, patient_ins_tuple): + patient_name, ins = patient_ins_tuple + # print({ + # "image": f"{self.dataset_dir}/data/{patient_name}_img.nii", + # "label": f"{self.dataset_dir}/data/{patient_name}_mask.nii", + # "name": patient_name, + # }) + x = self.transform({ + "image": f"{self.dataset_dir}/data/{patient_name}_img.nii", + "label": f"{self.dataset_dir}/data/{patient_name}_mask.nii", + "name": patient_name, + }) + # print(x["image_meta_dict"]["affine"]) + # crop and resize foregrounds depth-wise + target_slice = torch.sum(x["label"], dim=(0, 1, 2)) != 0 + for k in ["image", "label"]: + x[k] = x[k][..., target_slice] + x[k] = F.interpolate( + x[k].unsqueeze(0).to(torch.float), + size=self.spatial_size, + mode="trilinear" if k == "image" else "nearest" + ).squeeze(0) + # print(x["image_meta_dict"]["affine"]) + # print(x["image_meta_dict"]["original_affine"]) + # print(x["image_meta_dict"]["pixdim"]) + # print(x["label_meta_dict"]["pixdim"]) + return x + + +if __name__ == '__main__': + from matplotlib import pyplot as plt + transform = get_transform( + augmentation=False, + spatial_size=[256, 256, 40], + pixdim=[0.75, 0.75, 2.5], + rotate_range=np.pi / 36, + translate_range=[20, 20, 4], + scale_range=[0.15, 0.15, 0.15] + ) + dataset_dir = "/Users/yiwenli/data/multiorgan_final" + patient_name = "005082" + img = transform({ + "image": f"{dataset_dir}/data/{patient_name}_img.nii", + "label": f"{dataset_dir}/data/{patient_name}_mask.nii", + "name": patient_name, + }) + print(img["image"].shape) + + plt.subplot(2, 2, 1) + plt.gca().set_title("img") + plt.gca().axis('off') + img_slice = img["image"][0, ..., 20] + plt.imshow(np.transpose(img_slice)) + + spacingd = Spacingd(keys=["image"], pixdim=[5, 5, 2.5]) + transformed_img = spacingd(img) + plt.subplot(2, 
2, 2)
+    plt.gca().set_title("spaced_img")
+    plt.gca().axis('off')
+    print(transformed_img["image"].shape)
+    transformed_slice = transformed_img["image"][0, ..., 20]
+    plt.imshow(np.transpose(transformed_slice))
+    plt.show()
diff --git a/models/lower_pelvic_reg/scripts/dice_metric.py b/models/lower_pelvic_reg/scripts/dice_metric.py
new file mode 100644
index 00000000..83fce88a
--- /dev/null
+++ b/models/lower_pelvic_reg/scripts/dice_metric.py
@@ -0,0 +1,41 @@
+import torch
+from ignite.exceptions import NotComputableError
+from ignite.metrics import Metric
+from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
+
+
+class Dice(Metric):
+    """Accumulates voxel-wise agreement between argmax predictions and labels, excluding `ignored_class`."""
+
+    def __init__(self, ignored_class, output_transform=lambda x: x, device="cpu"):
+        self.ignored_class = ignored_class
+        self._num_correct = None
+        self._num_examples = None
+        super(Dice, self).__init__(output_transform=output_transform, device=device)
+
+    @reinit__is_reduced
+    def reset(self):
+        self._num_correct = torch.tensor(0, device=self._device)
+        self._num_examples = 0
+        super(Dice, self).reset()
+
+    @reinit__is_reduced
+    def update(self, output):
+        y_pred, y = output[0].detach(), output[1].detach()
+
+        indices = torch.argmax(y_pred, dim=1)
+
+        mask = (y != self.ignored_class)
+        mask &= (indices != self.ignored_class)
+        y = y[mask]
+        indices = indices[mask]
+        correct = torch.eq(indices, y).view(-1)
+
+        self._num_correct += torch.sum(correct).to(self._device)
+        self._num_examples += correct.shape[0]
+
+    @sync_all_reduce("_num_examples", "_num_correct:SUM")
+    def compute(self):
+        if self._num_examples == 0:
+            raise NotComputableError('Dice must have at least one example before it can be computed.')
+        return self._num_correct.item() / self._num_examples
diff --git a/models/lower_pelvic_reg/scripts/evaluator.py b/models/lower_pelvic_reg/scripts/evaluator.py
new file mode 100644
index 00000000..836d92d0
--- /dev/null
+++ b/models/lower_pelvic_reg/scripts/evaluator.py
@@ -0,0 +1,209 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence, Tuple, Dict, Optional, Union
+
+import torch
+from monai.data import MetaObj, MetaTensor
+from monai.engines import Evaluator
+from torch.utils.data import DataLoader
+
+from monai.config import IgniteInfo, KeysCollection
+from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
+from monai.engines.workflow import Workflow
+from monai.inferers import Inferer, SimpleInferer
+from monai.networks.utils import eval_mode, train_mode
+from monai.transforms import Transform
+from monai.utils import ForwardMode, ensure_tuple, min_version, optional_import
+from monai.utils.enums import CommonKeys as Keys
+from monai.utils.module import look_up_option
+
+if TYPE_CHECKING:
+    from ignite.engine import Engine, EventEnum
+    from ignite.metrics import Metric
+else:
+    Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+    Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
+    EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
+
+
+def prepare_reg_batch(
+    batchdata: Tuple[Dict],
+    device: Optional[Union[str, torch.device]] = None,
+    non_blocking: bool = False,
+    **kwargs,
+) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+    """
+    Default function to prepare the data for current iteration.
+
+    The input `batchdata` is a pair of dictionaries both with keys "image" and "label".
+ All returned tensors are moved to the given device using the given non-blocking argument before being returned. + + This function implements the expected API for a `prepare_batch` callable in Ignite: + https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html + + Args: + batchdata: a pair of dictionaries both with keys "image" and "label" + device: device to move every returned tensor to + non_blocking: equivalent argument for `Tensor.to` + kwargs: further arguments for `Tensor.to` + + Returns: + moving, fixed: a pair of dictionaries both with keys "image" and "label". + """ + moving, fixed = batchdata + for k in ["image", "label"]: + moving[k].to(device=device, non_blocking=non_blocking, **kwargs) + fixed[k].to(device=device, non_blocking=non_blocking, **kwargs) + return moving, fixed + + +class RegistrationEvaluator(Evaluator): + """ + Standard registration evaluation method with moving and fixed images and labels(optional), + inherits from evaluator and Workflow. + + Args: + device: an object representing the device on which to run. + val_data_loader: Ignite engine use data_loader to run, must be Iterable. Each batch input should be a pair of + dictionaries both with keys "image" and "label". + network: network to evaluate in the evaluator, should be regular PyTorch `torch.nn.Module`. + epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`. + non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously + with respect to the host. For other cases, this argument has no effect. + prepare_batch: function to parse expected data (usually `image`, `label` and other network args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. + iteration_update: the callable function for every iteration, expect to accept `engine` + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. + inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. + postprocessing: execute additional transformation for the model output data. + Typically, several Tensor based transforms composed by `Compose`. + key_val_metric: compute metric when every iteration completed, and save average value to + engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the + checkpoint into files. + additional_metrics: more Ignite metrics that also attach to Ignite Engine. + metric_cmp_fn: function to compare current key metric with previous best key metric value, + it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update + `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`. + val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: + CheckpointHandler, StatsHandler, etc. + amp: whether to enable auto-mixed-precision evaluation, default is False. + mode: model forward mode during evaluation, should be 'eval' or 'train', + which maps to `model.eval()` or `model.train()`, default to 'eval'. + event_names: additional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. 
+ event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html + #ignite.engine.engine.Engine.register_events. + decollate: whether to decollate the batch-first data to a list of data after model computation, + recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. + default to `True`. + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. + amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + + """ + + def __init__( + self, + device: torch.device, + val_data_loader: Iterable | DataLoader, + network: torch.nn.Module, + epoch_length: int | None = None, + non_blocking: bool = False, + prepare_batch: Callable = prepare_reg_batch, + iteration_update: Callable[[Engine, Any], Any] | None = None, + inferer: Inferer | None = None, + postprocessing: Transform | None = None, + key_val_metric: dict[str, Metric] | None = None, + additional_metrics: dict[str, Metric] | None = None, + metric_cmp_fn: Callable = default_metric_cmp_fn, + val_handlers: Sequence | None = None, + amp: bool = False, + mode: ForwardMode | str = ForwardMode.EVAL, + event_names: list[str | EventEnum] | None = None, + event_to_attr: dict | None = None, + decollate: bool = True, + to_kwargs: dict | None = None, + amp_kwargs: dict | None = None, + ) -> None: + super().__init__( + device=device, + val_data_loader=val_data_loader, + epoch_length=epoch_length, + non_blocking=non_blocking, + prepare_batch=prepare_batch, + iteration_update=iteration_update, + postprocessing=postprocessing, + key_val_metric=key_val_metric, + additional_metrics=additional_metrics, + metric_cmp_fn=metric_cmp_fn, + val_handlers=val_handlers, + amp=amp, + mode=mode, + event_names=event_names, + event_to_attr=event_to_attr, + decollate=decollate, + to_kwargs=to_kwargs, + amp_kwargs=amp_kwargs, + ) + + self.network = network + self.inferer = SimpleInferer() if inferer is None else inferer + + def _iteration(self, engine, batchdata): + """ + callback function for the Supervised Registration Evaluation processing logic of 1 iteration in Ignite Engine. + Return below items in a dictionary: + - moving_image: image Tensor data for model input, already moved to device. + - moving_label: label Tensor data corresponding to the image, already moved to device. + - fixed_image: image Tensor data for model input, already moved to device. + - fixed_label: label Tensor data corresponding to the image, already moved to device. + - ddf: dense displacement field which registers the moving towards fixed. + - warped_image: moving image warped by the predicted ddf + - warped_label: moving label warped by the predicted ddf + + Args: + engine: `SupervisedEvaluator` to execute operation for an iteration, should be a pair of dictionaries both + with keys "image" and "label". + batchdata: input data for this iteration. + + Raises: + ValueError: When ``batchdata`` is None. 
+ + """ + if batchdata is None: + raise ValueError("Must provide batch data for current iteration.") + moving, fixed = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) + + # put iteration outputs into engine.state + engine.state.output = { + "moving_image": moving["image"], + "moving_label": moving["label"], + "fixed_image": fixed["image"], + "fixed_label": fixed["label"], + "moving_name": moving["name"], + "fixed_name": fixed["name"] + } + + # execute forward computation + with engine.mode(engine.network): + if engine.amp: + with torch.cuda.amp.autocast(**engine.amp_kwargs): + engine.state.output.update( + engine.inferer((moving, fixed), engine.network) + ) + else: + engine.state.output.update( + engine.inferer((moving, fixed), engine.network) + ) + for k, v in engine.state.output.items(): + if isinstance(v, MetaTensor): + engine.state.output[k] = torch.tensor(v.get_array()) + engine.state.batch = None + engine.fire_event(IterationEvents.FORWARD_COMPLETED) + engine.fire_event(IterationEvents.MODEL_COMPLETED) + return engine.state.output diff --git a/models/lower_pelvic_reg/scripts/inferer.py b/models/lower_pelvic_reg/scripts/inferer.py new file mode 100644 index 00000000..b950d26c --- /dev/null +++ b/models/lower_pelvic_reg/scripts/inferer.py @@ -0,0 +1,41 @@ +from typing import Dict, Callable, Any, Tuple + +import torch +from monai.inferers import Inferer +from monai.networks import one_hot +from monai.networks.blocks import Warp + + +class RegistrationInferer(Inferer): + def __init__(self) -> None: + Inferer.__init__(self) + self.warp = Warp() + + def __call__(self, input: Tuple[Dict, Dict], network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any): + """Unified callable function API of Inferers. + + Args: + input: a pair of dictionaries specifying moving and fixed, both dictionaries are expected to have keys + "image" and "label" with values both of shape (B, C, ...) + network: target model to execute inference. + supports callables such as ``lambda x: my_torch_model(x, additional_config)`` + args: optional args to be passed to ``network``. + kwargs: optional keyword args to be passed to ``network``. + """ + moving, fixed = input + ddf = network( + torch.cat([moving["image"], fixed["image"]], dim=1) + ) + warped_image = self.warp(image=moving["image"], ddf=ddf) + moving_label_onehot = one_hot( + moving["label"], + num_classes=int(max(torch.unique(moving["label"])) + 1) + ) + warped_label_onehot = self.warp(image=moving_label_onehot, ddf=ddf) + warped_label = torch.argmax(warped_label_onehot, dim=1) + output = { + "ddf": ddf, + "warped_image": warped_image, + "warped_label": warped_label + } + return output diff --git a/models/lower_pelvic_reg/scripts/visualise.py b/models/lower_pelvic_reg/scripts/visualise.py new file mode 100644 index 00000000..1ee37408 --- /dev/null +++ b/models/lower_pelvic_reg/scripts/visualise.py @@ -0,0 +1,81 @@ +import numpy as np +from pathlib import Path +from typing import Optional, Type, Union, Sequence +import nibabel as nib +import torch + +from monai.config import DtypeLike, KeysCollection +from monai.data import image_writer +from monai.transforms.io.array import SaveImage +from monai.transforms.transform import MapTransform +from monai.utils import GridSamplePadMode, ensure_tuple_rep +from monai.utils.enums import PostFix + +DEFAULT_POST_FIX = PostFix.meta() + + +class SaveRegd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.SaveImage`. 
+ + Note: + Image should be channel-first shape: [C,H,W,[D]]. + If the data is a patch of an image, the patch index will be appended to the filename. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + pixdim: output voxel spacing. if providing a single number, will use it for the first dimension. + items of the pixdim sequence map to the spatial dimensions of input image, if length + of pixdim sequence is longer than image spatial dimensions, will ignore the longer part, + if shorter, will pad with `1.0`. + if the components of the `pixdim` are non-positive values, the transform will use the + corresponding components of the original pixdim, which is computed from the `affine` + matrix of input image. + spatial_size: output image spatial size. + if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, + the transform will use the spatial size of `image`. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of image size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of image is `64`. + output_dir: output image directory. + allow_missing_keys: don't raise exception if key is missing. + + """ + + def __init__( + self, + keys: KeysCollection, + pixdim: Sequence[float], + spatial_size: Optional[Union[Sequence[int], int]] = None, + output_dir: Union[Path, str] = "./", + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.affine = torch.diag(torch.tensor([*pixdim, 1])) + self.spatial_size = spatial_size + self.output_dir = output_dir + + def __call__(self, data): + """ + A dictionary with following items: + - moving_image: image Tensor data for model input, already moved to device. + - moving_label: label Tensor data corresponding to the image, already moved to device. + - fixed_image: image Tensor data for model input, already moved to device. + - fixed_label: label Tensor data corresponding to the image, already moved to device. + - ddf: dense displacement field which registers the moving towards fixed. 
+ - warped_image: moving image warped by the predicted ddf + - warped_label: moving label warped by the predicted ddf + """ + for k in self.keys: + print(f"{k}: {torch.tensor(data[k]).shape}") + for key in self.keys: + print(type(data[key])) + print(torch.tensor(data[key]).shape) + img = nib.Nifti1Image( + torch.tensor(data[key]).reshape(*self.spatial_size).detach().cpu().numpy().astype(dtype=np.float32), + affine=self.affine + ) + name = data["moving_name"] + "_" + data["fixed_name"] + print(name) + nib.save(img, f"{self.output_dir}/{name}_{key}.nii") diff --git a/models/lower_pelvic_reg/test/vis.py b/models/lower_pelvic_reg/test/vis.py new file mode 100644 index 00000000..5556a9ec --- /dev/null +++ b/models/lower_pelvic_reg/test/vis.py @@ -0,0 +1,17 @@ +import numpy as np +import matplotlib.pyplot as plt + + +def vis(moving, fixed): + axs = plt.figure(constrained_layout=True).subplots(2, 2, sharex=True, sharey=True) + middle_index = moving["t2w"].shape[-1] // 2 + vis_dict = { + "moving_t2w": (axs[0, 0], moving["t2w"][0, ..., middle_index]), + "fixed_t2w": (axs[0, 1], fixed["t2w"][0, ..., middle_index]), + "moving_seg": (axs[1, 0], moving["seg"][0, ..., middle_index]), + "fixed_seg": (axs[1, 1], fixed["seg"][0, ..., middle_index]), + } + for title, (ax, img) in vis_dict.items(): + ax.set(title=title, aspect=1, xticks=[], yticks=[]) + ax.matshow(np.array(img)) + plt.show() diff --git a/models/mednist_reg/configs/inference.yaml b/models/mednist_reg/configs/inference.yaml index 0d18ae48..b20fda66 100644 --- a/models/mednist_reg/configs/inference.yaml +++ b/models/mednist_reg/configs/inference.yaml @@ -113,12 +113,11 @@ evaluator: keys: [m_img] resample: False output_dir: "@output_dir" - output_ext: "png" + output_ext: "nii" output_postfix: "moving" - output_dtype: "$np.uint8" + output_dtype: "$np.float32" scale: 255 separate_folder: False - writer: "PILWriter" output_name_formatter: "$lambda x, s: dict(idx=s._data_index, subject=x['filename_or_obj'])" - _target_: SaveImaged keys: [label]
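The registration forward pass added in this diff (the `LocalNet` settings in `configs/inference.yaml` plus `scripts/inferer.py`'s `RegistrationInferer` and its `Warp` block) can also be exercised outside the Ignite workflow. Below is a minimal sketch, not part of the bundle: it assumes the working directory is `models/lower_pelvic_reg` so that `scripts` is importable, and it uses a small random volume and label-class count that are illustrative only, not the bundle's 256x256x40 preprocessed pairs. The configured `eval` workflow itself would presumably be launched with `python -m monai.bundle run eval --config_file configs/inference.yaml`, by analogy with the `display` command shown in the README.

```python
# Hedged sketch of the DDF prediction + warping path wired up by this diff.
# Assumed: run from models/lower_pelvic_reg; toy shapes, not bundle values.
import torch
from monai.networks.nets import LocalNet

from scripts.inferer import RegistrationInferer

# network hyper-parameters mirror configs/inference.yaml
net = LocalNet(
    spatial_dims=3,
    in_channels=2,          # moving + fixed images concatenated on the channel dim
    out_channels=3,         # one displacement component per spatial dimension
    num_channel_initial=32,
    extract_levels=[0, 1, 2, 3],
    out_kernel_initializer="zeros",
)
net.eval()

# toy moving/fixed pair; spatial sizes must divide by 2**max(extract_levels) = 8
shape = (1, 1, 64, 64, 16)
moving = {"image": torch.rand(shape), "label": torch.randint(0, 9, shape).float()}
fixed = {"image": torch.rand(shape), "label": torch.randint(0, 9, shape).float()}

inferer = RegistrationInferer()
with torch.no_grad():
    out = inferer((moving, fixed), net)

# ddf: (1, 3, 64, 64, 16); warped_image keeps the moving-image shape;
# warped_label is the argmax over the warped one-hot channels: (1, 64, 64, 16)
print({k: tuple(v.shape) for k, v in out.items()})
```

Inside the bundle, `RegistrationEvaluator._iteration` stores these same `ddf`/`warped_image`/`warped_label` outputs (together with the input images, labels and names) in `engine.state.output`, which is what the `SaveRegd` postprocessing transform then writes out as NIfTI files under `@output_dir`.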