Skip to content

Commit 9a45627

Browse files
Fix #8350: Clarify LocalNormalizedCrossCorrelationLoss docstring (#8639)
## Description This PR improves the docstring for `LocalNormalizedCrossCorrelationLoss` to address the ambiguities identified in #8350. ## Problem The current docstring does not clearly document: - The range of returned loss values - Whether the loss should be minimized or maximized - How to interpret high vs. low loss values ## Solution Enhanced the class docstring with comprehensive documentation including Returns section (value range, optimization direction) and Note section (implementation details, interpretation guidelines). ## Changes - Added `Returns` section with explicit value range and optimization direction - Added `Note` section explaining transformations and interpretation - Reorganized Args to class level for better discoverability - Followed MONAI formatting conventions ## Testing - [x] Verified docstring syntax is correct - [x] Confirmed technical accuracy by analyzing implementation - [x] Validated all three issue requirements are addressed Signed-off-by: Mohamed Salah <eng.mohamed.tawab@gmail.com>
1 parent 865b0e7 commit 9a45627

File tree

1 file changed

+30
-15
lines changed

1 file changed

+30
-15
lines changed

monai/losses/image_dissimilarity.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def make_gaussian_kernel(kernel_size: int) -> torch.Tensor:
5151
class LocalNormalizedCrossCorrelationLoss(_Loss):
5252
"""
5353
Local squared zero-normalized cross-correlation.
54+
5455
The loss is based on a moving kernel/window over the y_true/y_pred,
5556
within the window the square of zncc is calculated.
5657
The kernel can be a rectangular / triangular / gaussian window.
@@ -59,6 +60,35 @@ class LocalNormalizedCrossCorrelationLoss(_Loss):
5960
Adapted from:
6061
https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
6162
DeepReg (https://github.com/DeepRegNet/DeepReg)
63+
64+
Args:
65+
spatial_dims: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3.
66+
kernel_size: kernel spatial size, must be odd.
67+
kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``.
68+
reduction: {``"none"``, ``"mean"``, ``"sum"``}
69+
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
70+
71+
- ``"none"``: no reduction will be applied.
72+
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
73+
- ``"sum"``: the output will be summed.
74+
smooth_nr: a small constant added to the numerator to avoid nan.
75+
smooth_dr: a small constant added to the denominator to avoid nan.
76+
77+
Returns:
78+
torch.Tensor: The computed loss value. The output range is approximately [-1, 0], where:
79+
- Values closer to -1 indicate higher correlation (better match)
80+
- Values closer to 0 indicate lower correlation (worse match)
81+
- This loss should be **minimized** during optimization
82+
83+
Note:
84+
The implementation computes the squared normalized cross-correlation coefficient
85+
and then negates it, transforming the correlation maximization problem into a
86+
loss minimization problem suitable for standard PyTorch optimizers.
87+
88+
Interpretation:
89+
- Loss ≈ -1: Perfect correlation between images
90+
- Loss ≈ 0: No correlation between images
91+
- Lower (more negative) values indicate better alignment
6292
"""
6393

6494
def __init__(
@@ -70,21 +100,6 @@ def __init__(
70100
smooth_nr: float = 0.0,
71101
smooth_dr: float = 1e-5,
72102
) -> None:
73-
"""
74-
Args:
75-
spatial_dims: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3.
76-
kernel_size: kernel spatial size, must be odd.
77-
kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``.
78-
reduction: {``"none"``, ``"mean"``, ``"sum"``}
79-
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
80-
81-
- ``"none"``: no reduction will be applied.
82-
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
83-
- ``"sum"``: the output will be summed.
84-
smooth_nr: a small constant added to the numerator to avoid nan.
85-
smooth_dr: a small constant added to the denominator to avoid nan.
86-
87-
"""
88103
super().__init__(reduction=LossReduction(reduction).value)
89104

90105
self.ndim = spatial_dims

0 commit comments

Comments
 (0)