|
18 | 18 |
|
19 | 19 | from monai.utils import optional_import |
20 | 20 | from monai.utils.enums import StrEnum |
21 | | -from huggingface_hub import hf_hub_download |
22 | 21 |
|
23 | 22 | LPIPS, _ = optional_import("lpips", name="LPIPS") |
24 | 23 | torchvision, _ = optional_import("torchvision") |
25 | 24 |
|
26 | 25 |
|
27 | 26 | class PercetualNetworkType(StrEnum): |
28 | | - """Types of neural networks that are supported by perceptua loss. |
29 | | - """ |
| 27 | + """Types of neural networks that are supported by perceptua loss.""" |
30 | 28 |
|
31 | 29 | alex = "alex" |
32 | 30 | vgg = "vgg" |
@@ -116,8 +114,7 @@ def __init__( |
116 | 114 | # If spatial_dims is 3, only MedicalNet supports 3D models, otherwise, spatial_dims=2 and fake_3D must be used. |
117 | 115 | if spatial_dims == 3 and is_fake_3d is False: |
118 | 116 | self.perceptual_function = MedicalNetPerceptualSimilarity( |
119 | | - net=network_type, verbose=False, channel_wise=channel_wise, |
120 | | - cache_dir=cache_dir |
| 117 | + net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir |
121 | 118 | ) |
122 | 119 | elif "radimagenet_" in network_type: |
123 | 120 | self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False) |
@@ -214,12 +211,17 @@ class MedicalNetPerceptualSimilarity(nn.Module): |
214 | 211 | """ |
215 | 212 |
|
216 | 213 | def __init__( |
217 | | - self, net: str = "medicalnet_resnet_10_23datasets", verbose: bool = False, channel_wise: bool = False, |
| 214 | + self, |
| 215 | + net: str = "medicalnet_resnet_10_23datasets", |
| 216 | + verbose: bool = False, |
| 217 | + channel_wise: bool = False, |
218 | 218 | cache_dir: str | None = None, |
219 | 219 | ) -> None: |
220 | 220 | super().__init__() |
221 | 221 | torch.hub._validate_not_a_forked_repo = lambda a, b, c: True |
222 | | - self.model = torch.hub.load("Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir) |
| 222 | + self.model = torch.hub.load( |
| 223 | + "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir |
| 224 | + ) |
223 | 225 | self.eval() |
224 | 226 |
|
225 | 227 | self.channel_wise = channel_wise |
@@ -305,12 +307,9 @@ class RadImageNetPerceptualSimilarity(nn.Module): |
305 | 307 | verbose: if false, mute messages from torch Hub load function. |
306 | 308 | """ |
307 | 309 |
|
308 | | - def __init__(self, net: str = "radimagenet_resnet50", |
309 | | - verbose: bool = False, |
310 | | - cache_dir: str | None = None) -> None: |
| 310 | + def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None: |
311 | 311 | super().__init__() |
312 | | - self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, |
313 | | - cache_dir=cache_dir) |
| 312 | + self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, cache_dir=cache_dir) |
314 | 313 | self.eval() |
315 | 314 |
|
316 | 315 | for param in self.parameters(): |
|
0 commit comments