|
33 | 33 |
|
34 | 34 |
|
35 | 35 | class PercetualNetworkType(StrEnum): |
36 | | - """Types of neural networks that are supported by perceptua loss.""" |
| 36 | + """Types of neural networks that are supported by perceptual loss.""" |
37 | 37 |
|
38 | 38 | alex = "alex" |
39 | 39 | vgg = "vgg" |
@@ -131,7 +131,7 @@ def __init__( |
131 | 131 | net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir |
132 | 132 | ) |
133 | 133 | elif "radimagenet_" in network_type: |
134 | | - self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False) |
| 134 | + self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False, cache_dir=cache_dir) |
135 | 135 | elif network_type == "resnet50": |
136 | 136 | self.perceptual_function = TorchvisionModelPerceptualSimilarity( |
137 | 137 | net=network_type, |
@@ -226,7 +226,7 @@ class MedicalNetPerceptualSimilarity(nn.Module): |
226 | 226 |
|
227 | 227 | def __init__( |
228 | 228 | self, |
229 | | - net: str = "medicalnet_resnet_10_23datasets", |
| 229 | + net: str = "medicalnet_resnet10_23datasets", |
230 | 230 | verbose: bool = False, |
231 | 231 | channel_wise: bool = False, |
232 | 232 | cache_dir: str | None = None, |
@@ -333,7 +333,7 @@ def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cac |
333 | 333 | raise ValueError( |
334 | 334 | f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}." |
335 | 335 | ) |
336 | | - self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, cache_dir=cache_dir, |
| 336 | + self.model = torch.hub.load("Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, |
337 | 337 | trust_repo=True) |
338 | 338 | self.eval() |
339 | 339 |
|
|
0 commit comments