2 changes: 2 additions & 0 deletions .gitignore
@@ -165,3 +165,5 @@ runs
*.pth

*zarr/*

monai-dev/
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,8 @@ All notable changes to MONAI are documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [Unreleased]
### Added
* Added `RandNonCentralChiNoise` and `RandNonCentralChiNoised` for generalized Rician noise simulation in MRI.

## [1.5.1] - 2025-09-22

4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
@@ -117,6 +117,7 @@
RandHistogramShift,
RandIntensityRemap,
RandKSpaceSpikeNoise,
RandNonCentralChiNoise,
RandRicianNoise,
RandScaleIntensity,
RandScaleIntensityFixedMean,
@@ -199,6 +200,9 @@
RandKSpaceSpikeNoised,
RandKSpaceSpikeNoiseD,
RandKSpaceSpikeNoiseDict,
RandNonCentralChiNoised,
RandNonCentralChiNoiseD,
RandNonCentralChiNoiseDict,
RandRicianNoised,
RandRicianNoiseD,
RandRicianNoiseDict,
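
The dictionary wrappers (`RandNonCentralChiNoised` and its `D`/`Dict` aliases) are only re-exported here, so their signature is not visible in this diff. The sketch below is hypothetical: it assumes the wrapper follows MONAI's usual dictionary-transform convention of `keys` plus the array transform's arguments, as `RandRicianNoised` does.

```python
import torch

from monai.transforms import Compose, RandNonCentralChiNoised

# assumed signature: keys first, then the same arguments as RandNonCentralChiNoise
pipeline = Compose(
    [RandNonCentralChiNoised(keys=["image"], prob=0.5, std=0.05, degrees_of_freedom=64)]
)

sample = {"image": torch.zeros(1, 32, 32)}
out = pipeline(sample)  # out["image"] has noise added with probability 0.5
```
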
200 changes: 162 additions & 38 deletions monai/transforms/intensity/array.py
@@ -41,6 +41,7 @@

__all__ = [
"RandGaussianNoise",
"RandNonCentralChiNoise",
"RandRicianNoise",
"ShiftIntensity",
"RandShiftIntensity",
@@ -140,6 +141,149 @@ def __call__(self, img: NdarrayOrTensor, mean: float | None = None, randomize: b
return img + noise


class RandNonCentralChiNoise(RandomizableTransform):
"""
Add non-central chi noise to an image.
This distribution is the square root of the sum of squares of k independent
Gaussian random variables, where one of the variables has a non-zero mean
(the signal).
This is a generalization of Rician noise. `degrees_of_freedom=2` is Rician noise.
See: https://en.wikipedia.org/wiki/Noncentral_chi_distribution and https://archive.ismrm.org/2024/3123_NZkvJdQat.html

Args:
prob: Probability to add noise.
mean: Mean or "centre" of the Gaussian noise distributions.
std: Standard deviation (spread) of the Gaussian noise distributions.
degrees_of_freedom: Number of Gaussian distributions (degrees of freedom).
`degrees_of_freedom=2` is Rician noise.
channel_wise: If True, treats each channel of the image separately.
relative: If True, the spread of the sampled Gaussian distributions will
be std times the standard deviation of the image or channel's intensity
histogram.
sample_std: If True, sample the spread of the Gaussian distributions
uniformly from 0 to std.
dtype: output data type, if None, same as input image. defaults to float32.

"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
prob: float = 0.1,
mean: Sequence[float] | float = 0.0,
std: Sequence[float] | float = 1.0,
degrees_of_freedom: int = 64, # default 64 because a typical modern brain MRI uses 32 quadrature coils (2 degrees of freedom per coil)
channel_wise: bool = False,
relative: bool = False,
sample_std: bool = True,
dtype: DtypeLike = np.float32,
) -> None:
"""
Initializes the transform.

Args:
prob: Probability to add noise.
mean: Mean of the Gaussian noise distributions.
std: Standard deviation (spread) of the Gaussian noise distributions.
degrees_of_freedom: Number of Gaussian distributions (degrees of freedom).
`degrees_of_freedom=2` is Rician noise. Defaults to 64 (32 quadrature coils).
channel_wise: If True, treats each channel of the image separately.
relative: If True, the spread of the sampled Gaussian distributions will
be std times the standard deviation of the image or channel's intensity
histogram.
sample_std: If True, sample the spread of the Gaussian distributions
uniformly from 0 to std.
dtype: output data type, if None, same as input image. defaults to float32.

Raises:
ValueError: If `degrees_of_freedom` is not an integer or is less than 1.
"""
RandomizableTransform.__init__(self, prob)
self.prob = prob
self.mean = mean
self.std = std
if not isinstance(degrees_of_freedom, int) or degrees_of_freedom < 1:
raise ValueError("degrees_of_freedom must be an integer >= 1.")
self.degrees_of_freedom = degrees_of_freedom
self.channel_wise = channel_wise
self.relative = relative
self.sample_std = sample_std
self.dtype = dtype

def _add_noise(self, img: NdarrayOrTensor, mean: float, std: float, k: int):
"""
Applies non-central chi noise to a single image or channel.

This method generates `k` Gaussian noise arrays, adds the input `img`
to the first one (as the non-centrality component), and then computes
the square root of the sum of squares.

Args:
img: Input image array.
mean: Mean for the Gaussian noise distributions.
std: Standard deviation for the Gaussian noise distributions.
k: Degrees of freedom (number of noise arrays).

Returns:
Image with non-central chi noise applied, with the same
backend (Numpy/Torch) as the input.
"""
dtype_np = get_equivalent_dtype(img.dtype, np.ndarray)
im_shape = img.shape
_std = self.R.uniform(0, std) if self.sample_std else std

# Create a stack of k noise arrays
noise_shape = (k, *im_shape)
all_noises_np = self.R.normal(mean, _std, size=noise_shape).astype(dtype_np, copy=False)

if isinstance(img, torch.Tensor):
all_noises = torch.tensor(all_noises_np, device=img.device)
all_noises[0] = all_noises[0] + img
sum_sq = torch.sum(all_noises**2, dim=0)
return torch.sqrt(sum_sq)

all_noises_np[0] = all_noises_np[0] + img
sum_sq = np.sum(all_noises_np**2, axis=0)
return np.sqrt(sum_sq)

def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
src = img
img = convert_to_tensor(img, track_meta=get_track_meta(), dtype=self.dtype)
if randomize:
super().randomize(None)

if not self._do_transform:
img, *_ = convert_to_dst_type(img, dst=src, dtype=self.dtype)
return img

if self.channel_wise:
_mean = ensure_tuple_rep(self.mean, len(img))
_std = ensure_tuple_rep(self.std, len(img))
for i, d in enumerate(img):
img[i] = self._add_noise(
d,
mean=_mean[i],
std=_std[i] * d.std() if self.relative else _std[i],
k=self.degrees_of_freedom,
)
Comment on lines +267 to +272

⚠️ Potential issue | 🔴 Critical

Fix CUDA crash for relative channel noise.

Passing `d.std()` straight into `np.random.uniform` works on CPU, but when the input lives on GPU (the common case for MONAI) it throws `TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.` Convert the statistic to a host-side Python float before handing it to the RNG.

-            for i, d in enumerate(img):
-                img[i] = self._add_noise(
-                    d,
-                    mean=_mean[i],
-                    std=_std[i] * d.std() if self.relative else _std[i],
-                    k=self.degrees_of_freedom,
-                )
+            for i, d in enumerate(img):
+                if self.relative:
+                    channel_std = (
+                        d.detach().std().cpu().item() if isinstance(d, torch.Tensor) else float(np.asarray(d).std())
+                    )
+                    std_arg = _std[i] * channel_std
+                else:
+                    std_arg = _std[i]
+                img[i] = self._add_noise(
+                    d,
+                    mean=_mean[i],
+                    std=std_arg,
+                    k=self.degrees_of_freedom,
+                )
🤖 Prompt for AI Agents
In monai/transforms/intensity/array.py around lines 267 to 272, the call that passes `d.std()` into the RNG can crash on CUDA tensors. Instead of passing the tensor directly, compute the standard deviation as a host-side Python float (e.g., `tensor.detach().cpu().item()` for torch tensors, or `float(np.asarray(d).std())` for numpy arrays) and pass that float into the RNG so it always receives a Python number rather than a CUDA tensor.
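
For concreteness, a small self-contained sketch of the host-side conversion being requested; the helper name `host_side_std` is ours and not part of the patch or of MONAI:

```python
import numpy as np
import torch


def host_side_std(d) -> float:
    """Return the standard deviation of one channel as a plain Python float,
    so numpy's RNG never receives a CUDA tensor."""
    if isinstance(d, torch.Tensor):
        # detach() drops autograd history; cpu().item() copies the scalar to host memory
        return d.detach().std().cpu().item()
    return float(np.asarray(d).std())


rng = np.random.RandomState(0)
d = torch.rand(3, 4)                    # stand-in for one channel; may live on "cuda"
std_arg = 0.1 * host_side_std(d)        # relative std, as in the channel_wise branch
sampled_std = rng.uniform(0, std_arg)   # safe: the RNG receives a Python float
```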

else:
if not isinstance(self.mean, (int, float)):
raise RuntimeError(f"If channel_wise is False, mean must be a float or int, got {type(self.mean)}.")
if not isinstance(self.std, (int, float)):
raise RuntimeError(f"If channel_wise is False, std must be a float or int, got {type(self.std)}.")
std = self.std * img.std().item() if self.relative else self.std
if not isinstance(std, (int, float)):
raise RuntimeError(f"std must be a float or int number, got {type(std)}.")
img = self._add_noise(img, mean=self.mean, std=std, k=self.degrees_of_freedom)

img, *_ = convert_to_dst_type(img, dst=src, dtype=self.dtype)
return img
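
As a quick orientation for reviewers, a minimal usage sketch of the array transform added above; parameter values are illustrative, and it assumes this branch is installed so the new export is available:

```python
import torch

from monai.transforms import RandNonCentralChiNoise

img = torch.zeros(1, 64, 64)  # channel-first image, as MONAI intensity transforms expect

# degrees_of_freedom=2 reproduces Rician noise; prob=1.0 applies the noise on every call
rician_like = RandNonCentralChiNoise(prob=1.0, std=0.1, degrees_of_freedom=2, sample_std=False)
noisy_rician = rician_like(img)

# the default of 64 degrees of freedom models magnitude data from a 32-coil acquisition
multi_coil = RandNonCentralChiNoise(prob=1.0, std=0.1, sample_std=False)
noisy_chi = multi_coil(img)

print(noisy_rician.shape, noisy_chi.shape)  # both torch.Size([1, 64, 64])
```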


class RandRicianNoise(RandomizableTransform):
"""
Add Rician noise to image.
@@ -344,9 +488,7 @@ class StdShiftIntensity(Transform):

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self, factor: float, nonzero: bool = False, channel_wise: bool = False, dtype: DtypeLike = np.float32
) -> None:
def __init__(self, factor: float, nonzero: bool = False, channel_wise: bool = False, dtype: DtypeLike = np.float32) -> None:
self.factor = factor
self.nonzero = nonzero
self.channel_wise = channel_wise
@@ -436,9 +578,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
if not self._do_transform:
return img

shifter = StdShiftIntensity(
factor=self.factor, nonzero=self.nonzero, channel_wise=self.channel_wise, dtype=self.dtype
)
shifter = StdShiftIntensity(factor=self.factor, nonzero=self.nonzero, channel_wise=self.channel_wise, dtype=self.dtype)
return shifter(img=img)


@@ -1128,12 +1268,16 @@ def _clip(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
(
lower_percentile
if lower_percentile is None
else lower_percentile.item() if hasattr(lower_percentile, "item") else lower_percentile
else lower_percentile.item()
if hasattr(lower_percentile, "item")
else lower_percentile
),
(
upper_percentile
if upper_percentile is None
else upper_percentile.item() if hasattr(upper_percentile, "item") else upper_percentile
else upper_percentile.item()
if hasattr(upper_percentile, "item")
else upper_percentile
),
)
)
@@ -1257,9 +1401,7 @@ def __init__(

if isinstance(gamma, (int, float)):
if gamma <= 0.5:
raise ValueError(
f"if gamma is a number, must greater than 0.5 and value is picked from (0.5, gamma), got {gamma}"
)
raise ValueError(f"if gamma is a number, must greater than 0.5 and value is picked from (0.5, gamma), got {gamma}")
self.gamma = (0.5, gamma)
elif len(gamma) != 2:
raise ValueError("gamma should be a number or pair of numbers.")
@@ -1270,9 +1412,7 @@ def __init__(
self.invert_image: bool = invert_image
self.retain_stats: bool = retain_stats

self.adjust_contrast = AdjustContrast(
self.gamma_value, invert_image=self.invert_image, retain_stats=self.retain_stats
)
self.adjust_contrast = AdjustContrast(self.gamma_value, invert_image=self.invert_image, retain_stats=self.retain_stats)

def randomize(self, data: Any | None = None) -> None:
super().randomize(None)
@@ -1398,9 +1538,7 @@ def _normalize(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
b_min = ((self.b_max - self.b_min) * (self.lower / 100.0)) + self.b_min
b_max = ((self.b_max - self.b_min) * (self.upper / 100.0)) + self.b_min

scalar = ScaleIntensityRange(
a_min=a_min, a_max=a_max, b_min=b_min, b_max=b_max, clip=self.clip, dtype=self.dtype
)
scalar = ScaleIntensityRange(a_min=a_min, a_max=a_max, b_min=b_min, b_max=b_max, clip=self.clip, dtype=self.dtype)
img = scalar(img)
img = convert_to_tensor(img, track_meta=False)
return img
@@ -1723,8 +1861,7 @@ def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32)

gf1, gf2 = (
GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx).to(img_t.device)
for sigma in (self.sigma1, self.sigma2)
GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx).to(img_t.device) for sigma in (self.sigma1, self.sigma2)
)
blurred_f = gf1(img_t.unsqueeze(0))
filter_blurred_f = gf2(blurred_f)
@@ -2082,9 +2219,7 @@ def __init__(self, loc: tuple | Sequence[tuple], k_intensity: Sequence[float] |
# assert one-to-one relationship between factors and locations
if isinstance(k_intensity, Sequence):
if not isinstance(loc[0], Sequence):
raise ValueError(
"If a sequence is passed to k_intensity, then a sequence of locations must be passed to loc"
)
raise ValueError("If a sequence is passed to k_intensity, then a sequence of locations must be passed to loc")
if len(k_intensity) != len(loc):
raise ValueError("There must be one intensity_factor value for each tuple of indices in loc.")
if isinstance(self.loc[0], Sequence) and k_intensity is not None and not isinstance(self.k_intensity, Sequence):
@@ -2424,9 +2559,7 @@ def __init__(
max_spatial_size: Sequence[int] | int | None = None,
prob: float = 0.1,
) -> None:
super().__init__(
holes=holes, spatial_size=spatial_size, max_holes=max_holes, max_spatial_size=max_spatial_size, prob=prob
)
super().__init__(holes=holes, spatial_size=spatial_size, max_holes=max_holes, max_spatial_size=max_spatial_size, prob=prob)
self.dropout_holes = dropout_holes
if isinstance(fill_value, (tuple, list)):
if len(fill_value) != 2:
@@ -2639,10 +2772,7 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor:
if self._do_transform:
if self.channel_wise:
img = torch.stack(
[
IntensityRemap(self.kernel_size, self.R.choice([-self.slope, self.slope]))(img[i])
for i in range(len(img))
]
[IntensityRemap(self.kernel_size, self.R.choice([-self.slope, self.slope]))(img[i]) for i in range(len(img))]
)
else:
img = IntensityRemap(self.kernel_size, self.R.choice([-self.slope, self.slope]))(img)
@@ -2698,9 +2828,7 @@ def __init__(

self.thresholds = {k: v for k, v in self.thresholds.items() if v is not None}
if self.thresholds.keys().isdisjoint(set("RGBHSV")):
raise ValueError(
f"Threshold for at least one channel of RGB or HSV needs to be set. {self.thresholds} is provided."
)
raise ValueError(f"Threshold for at least one channel of RGB or HSV needs to be set. {self.thresholds} is provided.")
self.invert = invert

def _set_threshold(self, threshold, mode):
@@ -2711,9 +2839,7 @@ def _set_threshold(self, threshold, mode):
elif isinstance(threshold, (float, int)):
self.thresholds[mode] = float(threshold)
else:
raise ValueError(
f"`threshold` should be either a callable, string, or float number, {type(threshold)} was given."
)
raise ValueError(f"`threshold` should be either a callable, string, or float number, {type(threshold)} was given.")

def _get_threshold(self, image, mode):
threshold = self.thresholds.get(mode)
@@ -2835,9 +2961,7 @@ def __init__(
raise ValueError(f"Unknown mode: {self.mode}. Supported modes are 'B' and 'RF'.")

if self.sink_mode not in ["all", "mid", "min", "mask"]:
raise ValueError(
f"Unknown sink mode: {self.sink_mode}. Supported modes are 'all', 'mid', 'min' and 'mask'."
)
raise ValueError(f"Unknown sink mode: {self.sink_mode}. Supported modes are 'all', 'mid', 'min' and 'mask'.")

self._compute_conf_map = UltrasoundConfidenceMap(
self.alpha, self.beta, self.gamma, self.mode, self.sink_mode, self.use_cg, self.cg_tol, self.cg_maxiter