Skip to content

Commit b1e4a50

Browse files
author
Virginia Fernandez
committed
Perceptual loss changes.
1 parent 946cfdf commit b1e4a50

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

monai/losses/perceptual.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,16 @@
1818

1919
from monai.utils import optional_import
2020
from monai.utils.enums import StrEnum
21+
from huggingface_hub import hf_hub_download
2122

2223
LPIPS, _ = optional_import("lpips", name="LPIPS")
2324
torchvision, _ = optional_import("torchvision")
2425

2526

2627
class PercetualNetworkType(StrEnum):
28+
"""Types of neural networks that are supported by perceptua loss.
29+
"""
30+
2731
alex = "alex"
2832
vgg = "vgg"
2933
squeeze = "squeeze"
@@ -108,9 +112,12 @@ def __init__(
108112

109113
self.spatial_dims = spatial_dims
110114
self.perceptual_function: nn.Module
115+
116+
# If spatial_dims is 3, only MedicalNet supports 3D models, otherwise, spatial_dims=2 and fake_3D must be used.
111117
if spatial_dims == 3 and is_fake_3d is False:
112118
self.perceptual_function = MedicalNetPerceptualSimilarity(
113-
net=network_type, verbose=False, channel_wise=channel_wise
119+
net=network_type, verbose=False, channel_wise=channel_wise,
120+
cache_dir=cache_dir
114121
)
115122
elif "radimagenet_" in network_type:
116123
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
@@ -122,7 +129,9 @@ def __init__(
122129
pretrained_state_dict_key=pretrained_state_dict_key,
123130
)
124131
else:
132+
# VGG, AlexNet and SqueezeNet are independently handled by LPIPS.
125133
self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False)
134+
126135
self.is_fake_3d = is_fake_3d
127136
self.fake_3d_ratio = fake_3d_ratio
128137
self.channel_wise = channel_wise
@@ -194,7 +203,7 @@ class MedicalNetPerceptualSimilarity(nn.Module):
194203
"""
195204
Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer
196205
Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from
197-
"Warvito/MedicalNet-models".
206+
"Project-MONAI/perceptual-models".
198207
199208
Args:
200209
net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
@@ -205,11 +214,12 @@ class MedicalNetPerceptualSimilarity(nn.Module):
205214
"""
206215

207216
def __init__(
208-
self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False
217+
self, net: str = "medicalnet_resnet_10_23datasets", verbose: bool = False, channel_wise: bool = False,
218+
cache_dir: str | None = None,
209219
) -> None:
210220
super().__init__()
211221
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
212-
self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose)
222+
self.model = torch.hub.load("Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir)
213223
self.eval()
214224

215225
self.channel_wise = channel_wise
@@ -287,17 +297,20 @@ class RadImageNetPerceptualSimilarity(nn.Module):
287297
"""
288298
Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et
289299
al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class
290-
uses torch Hub to download the networks from "Warvito/radimagenet-models".
300+
uses torch Hub to download the networks from "Project-MONAI/perceptual-models".
291301
292302
Args:
293303
net: {``"radimagenet_resnet50"``}
294304
Specifies the network architecture to use. Defaults to ``"radimagenet_resnet50"``.
295305
verbose: if false, mute messages from torch Hub load function.
296306
"""
297307

298-
def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None:
308+
def __init__(self, net: str = "radimagenet_resnet50",
309+
verbose: bool = False,
310+
cache_dir: str | None = None) -> None:
299311
super().__init__()
300-
self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose)
312+
self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose,
313+
cache_dir=cache_dir)
301314
self.eval()
302315

303316
for param in self.parameters():

0 commit comments

Comments
 (0)