1818
1919from monai .utils import optional_import
2020from monai .utils .enums import StrEnum
21+ from huggingface_hub import hf_hub_download
2122
2223LPIPS , _ = optional_import ("lpips" , name = "LPIPS" )
2324torchvision , _ = optional_import ("torchvision" )
2425
2526
2627class PercetualNetworkType (StrEnum ):
28+ """Types of neural networks that are supported by perceptual loss.
29+ """
30+
2731 alex = "alex"
2832 vgg = "vgg"
2933 squeeze = "squeeze"
@@ -108,9 +112,12 @@ def __init__(
108112
109113 self .spatial_dims = spatial_dims
110114 self .perceptual_function : nn .Module
115+
116+ # If spatial_dims is 3, only MedicalNet supports 3D models; otherwise, spatial_dims=2 and fake_3D must be used.
111117 if spatial_dims == 3 and is_fake_3d is False :
112118 self .perceptual_function = MedicalNetPerceptualSimilarity (
113- net = network_type , verbose = False , channel_wise = channel_wise
119+ net = network_type , verbose = False , channel_wise = channel_wise ,
120+ cache_dir = cache_dir
114121 )
115122 elif "radimagenet_" in network_type :
116123 self .perceptual_function = RadImageNetPerceptualSimilarity (net = network_type , verbose = False )
@@ -122,7 +129,9 @@ def __init__(
122129 pretrained_state_dict_key = pretrained_state_dict_key ,
123130 )
124131 else :
132+ # VGG, AlexNet and SqueezeNet are independently handled by LPIPS.
125133 self .perceptual_function = LPIPS (pretrained = pretrained , net = network_type , verbose = False )
134+
126135 self .is_fake_3d = is_fake_3d
127136 self .fake_3d_ratio = fake_3d_ratio
128137 self .channel_wise = channel_wise
@@ -194,7 +203,7 @@ class MedicalNetPerceptualSimilarity(nn.Module):
194203 """
195204 Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer
196205 Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from
197- "Warvito/MedicalNet -models".
206+ "Project-MONAI/perceptual-models".
198207
199208 Args:
200209 net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
@@ -205,11 +214,12 @@ class MedicalNetPerceptualSimilarity(nn.Module):
205214 """
206215
207216 def __init__ (
208- self , net : str = "medicalnet_resnet10_23datasets" , verbose : bool = False , channel_wise : bool = False
217+ self , net : str = "medicalnet_resnet_10_23datasets" , verbose : bool = False , channel_wise : bool = False ,
218+ cache_dir : str | None = None ,
209219 ) -> None :
210220 super ().__init__ ()
211221 torch .hub ._validate_not_a_forked_repo = lambda a , b , c : True
212- self .model = torch .hub .load ("warvito/MedicalNet -models" , model = net , verbose = verbose )
222+ self .model = torch .hub .load ("Project-MONAI/perceptual -models:main " , model = net , verbose = verbose , cache_dir = cache_dir )
213223 self .eval ()
214224
215225 self .channel_wise = channel_wise
@@ -287,17 +297,20 @@ class RadImageNetPerceptualSimilarity(nn.Module):
287297 """
288298 Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et
289299 al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class
290- uses torch Hub to download the networks from "Warvito/radimagenet -models".
300+ uses torch Hub to download the networks from "Project-MONAI/perceptual-models".
291301
292302 Args:
293303 net: {``"radimagenet_resnet50"``}
294304 Specifies the network architecture to use. Defaults to ``"radimagenet_resnet50"``.
295305 verbose: if false, mute messages from torch Hub load function.
296306 """
297307
298- def __init__ (self , net : str = "radimagenet_resnet50" , verbose : bool = False ) -> None :
308+ def __init__ (self , net : str = "radimagenet_resnet50" ,
309+ verbose : bool = False ,
310+ cache_dir : str | None = None ) -> None :
299311 super ().__init__ ()
300- self .model = torch .hub .load ("Warvito/radimagenet-models" , model = net , verbose = verbose )
312+ self .model = torch .hub .load ("Project-MONAI/perceptual-models" , model = net , verbose = verbose ,
313+ cache_dir = cache_dir )
301314 self .eval ()
302315
303316 for param in self .parameters ():
0 commit comments