Skip to content

Commit 7cfa2c9

Browse files
authored
Add sample_std parameter to RandGaussianNoise. (#7492)
Fixes issue #7425 ### Description Add a `sample_std` parameter to `RandGaussianNoise` and `RandGaussianNoised`. When True, the Gaussian's standard deviation is sampled uniformly from 0 to std (i.e., what is currently done). When False, the noise's standard deviation is non-random and set to std. The default for sample_std would be True for backwards compatibility. Changes were based on RandRicianNoise which already has a `sample_std` parameter and is similar to RandGaussianNoise in concept and implementation. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Timothy Baker <bakertim@umich.edu>
1 parent 20512d3 commit 7cfa2c9

File tree

4 files changed

+32
-15
lines changed

4 files changed

+32
-15
lines changed

monai/transforms/intensity/array.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,24 +91,33 @@ class RandGaussianNoise(RandomizableTransform):
9191
mean: Mean or “centre” of the distribution.
9292
std: Standard deviation (spread) of distribution.
9393
dtype: output data type, if None, same as input image. defaults to float32.
94+
sample_std: If True, sample the spread of the Gaussian distribution uniformly from 0 to std.
9495
9596
"""
9697

9798
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
9899

99-
def __init__(self, prob: float = 0.1, mean: float = 0.0, std: float = 0.1, dtype: DtypeLike = np.float32) -> None:
100+
def __init__(
101+
self,
102+
prob: float = 0.1,
103+
mean: float = 0.0,
104+
std: float = 0.1,
105+
dtype: DtypeLike = np.float32,
106+
sample_std: bool = True,
107+
) -> None:
100108
RandomizableTransform.__init__(self, prob)
101109
self.mean = mean
102110
self.std = std
103111
self.dtype = dtype
104112
self.noise: np.ndarray | None = None
113+
self.sample_std = sample_std
105114

106115
def randomize(self, img: NdarrayOrTensor, mean: float | None = None) -> None:
107116
super().randomize(None)
108117
if not self._do_transform:
109118
return None
110-
rand_std = self.R.uniform(0, self.std)
111-
noise = self.R.normal(self.mean if mean is None else mean, rand_std, size=img.shape)
119+
std = self.R.uniform(0, self.std) if self.sample_std else self.std
120+
noise = self.R.normal(self.mean if mean is None else mean, std, size=img.shape)
112121
# noise is float64 array, convert to the output dtype to save memory
113122
self.noise, *_ = convert_data_type(noise, dtype=self.dtype)
114123

monai/transforms/intensity/dictionary.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@
172172
class RandGaussianNoised(RandomizableTransform, MapTransform):
173173
"""
174174
Dictionary-based version :py:class:`monai.transforms.RandGaussianNoise`.
175-
Add Gaussian noise to image. This transform assumes all the expected fields have same shape, if want to add
175+
Add Gaussian noise to image. This transform assumes all the expected fields have same shape, if you want to add
176176
different noise for every field, please use this transform separately.
177177
178178
Args:
@@ -183,6 +183,7 @@ class RandGaussianNoised(RandomizableTransform, MapTransform):
183183
std: Standard deviation (spread) of distribution.
184184
dtype: output data type, if None, same as input image. defaults to float32.
185185
allow_missing_keys: don't raise exception if key is missing.
186+
sample_std: If True, sample the spread of the Gaussian distribution uniformly from 0 to std.
186187
"""
187188

188189
backend = RandGaussianNoise.backend
@@ -195,10 +196,11 @@ def __init__(
195196
std: float = 0.1,
196197
dtype: DtypeLike = np.float32,
197198
allow_missing_keys: bool = False,
199+
sample_std: bool = True,
198200
) -> None:
199201
MapTransform.__init__(self, keys, allow_missing_keys)
200202
RandomizableTransform.__init__(self, prob)
201-
self.rand_gaussian_noise = RandGaussianNoise(mean=mean, std=std, prob=1.0, dtype=dtype)
203+
self.rand_gaussian_noise = RandGaussianNoise(mean=mean, std=std, prob=1.0, dtype=dtype, sample_std=sample_std)
202204

203205
def set_random_state(
204206
self, seed: int | None = None, state: np.random.RandomState | None = None

tests/test_rand_gaussian_noise.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,24 @@
2222

2323
TESTS = []
2424
for p in TEST_NDARRAYS:
25-
TESTS.append(("test_zero_mean", p, 0, 0.1))
26-
TESTS.append(("test_non_zero_mean", p, 1, 0.5))
25+
TESTS.append(("test_zero_mean", p, 0, 0.1, True))
26+
TESTS.append(("test_non_zero_mean", p, 1, 0.5, True))
27+
TESTS.append(("test_no_sample_std", p, 1, 0.5, False))
2728

2829

2930
class TestRandGaussianNoise(NumpyImageTestCase2D):
3031

3132
@parameterized.expand(TESTS)
32-
def test_correct_results(self, _, im_type, mean, std):
33+
def test_correct_results(self, _, im_type, mean, std, sample_std):
3334
seed = 0
34-
gaussian_fn = RandGaussianNoise(prob=1.0, mean=mean, std=std)
35+
gaussian_fn = RandGaussianNoise(prob=1.0, mean=mean, std=std, sample_std=sample_std)
3536
gaussian_fn.set_random_state(seed)
3637
im = im_type(self.imt)
3738
noised = gaussian_fn(im)
3839
np.random.seed(seed)
3940
np.random.random()
40-
expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape)
41+
_std = np.random.uniform(0, std) if sample_std else std
42+
expected = self.imt + np.random.normal(mean, _std, size=self.imt.shape)
4143
if isinstance(noised, torch.Tensor):
4244
noised = noised.cpu()
4345
np.testing.assert_allclose(expected, noised, atol=1e-5)

tests/test_rand_gaussian_noised.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,28 @@
2222

2323
TESTS = []
2424
for p in TEST_NDARRAYS:
25-
TESTS.append(["test_zero_mean", p, ["img1", "img2"], 0, 0.1])
26-
TESTS.append(["test_non_zero_mean", p, ["img1", "img2"], 1, 0.5])
25+
TESTS.append(["test_zero_mean", p, ["img1", "img2"], 0, 0.1, True])
26+
TESTS.append(["test_non_zero_mean", p, ["img1", "img2"], 1, 0.5, True])
27+
TESTS.append(["test_no_sample_std", p, ["img1", "img2"], 1, 0.5, False])
2728

2829
seed = 0
2930

3031

3132
class TestRandGaussianNoised(NumpyImageTestCase2D):
3233

3334
@parameterized.expand(TESTS)
34-
def test_correct_results(self, _, im_type, keys, mean, std):
35-
gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64)
35+
def test_correct_results(self, _, im_type, keys, mean, std, sample_std):
36+
gaussian_fn = RandGaussianNoised(
37+
keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64, sample_std=sample_std
38+
)
3639
gaussian_fn.set_random_state(seed)
3740
im = im_type(self.imt)
3841
noised = gaussian_fn({k: im for k in keys})
3942
np.random.seed(seed)
4043
# simulate the randomize() of transform
4144
np.random.random()
42-
noise = np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape)
45+
_std = np.random.uniform(0, std) if sample_std else std
46+
noise = np.random.normal(mean, _std, size=self.imt.shape)
4347
for k in keys:
4448
expected = self.imt + noise
4549
if isinstance(noised[k], torch.Tensor):

0 commit comments

Comments
 (0)