This repository was archived by the owner on Jun 22, 2022. It is now read-only.

Commit ba37aed

minerva-ml committed
fixed initial pipe
1 parent 3f1ad37 commit ba37aed
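
In short: this commit adds a PSPNet architecture (a new file plus a registry entry in models.py), makes the TTA transform and its inverse injectable instead of hard-coded in loaders.py, drops the color-shift TTA option, removes a redundant ReLU from the scSE decoder block, and appends a bounding-box step to the simplified postprocessing pipeline.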

12 files changed (+303 / -198 lines)


common_blocks/architectures/base.py

Lines changed: 1 addition & 2 deletions
@@ -68,7 +68,6 @@ def __init__(self, in_channels, middle_channels, out_channels):
         self.conv1 = Conv2dBnRelu(in_channels, middle_channels)
         self.conv2 = Conv2dBnRelu(middle_channels, out_channels)
         self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
-        self.relu = nn.ReLU(inplace=True)
         self.channel_se = ChannelSELayer(out_channels, reduction=16)
         self.spatial_se = SpatialSELayer(out_channels)

@@ -82,7 +81,7 @@ def forward(self, x, e=None):
         channel_se = self.channel_se(x)
         spatial_se = self.spatial_se(x)

-        x = self.relu(channel_se + spatial_se)
+        x = channel_se + spatial_se
         return x
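
The decoder block now returns the raw sum of the two squeeze-and-excitation branches. A minimal sketch of the changed forward pass, assuming ChannelSELayer and SpatialSELayer follow the standard scSE formulation (each returns the input rescaled by a sigmoid-gated attention map, so for the non-negative post-BN-ReLU features in this block the sum is already non-negative and the trailing ReLU was a no-op):

def scse_forward(x, channel_se_layer, spatial_se_layer):
    # Each branch rescales x by gates in [0, 1] (sigmoid outputs).
    channel_se = channel_se_layer(x)   # channel-wise recalibration
    spatial_se = spatial_se_layer(x)   # pixel-wise recalibration
    return channel_se + spatial_se     # trailing ReLU removed by this commit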

common_blocks/architectures/pspnet.py (new file; the path follows from the `from .architectures import ... pspnet` change in models.py below)

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+from torch import nn
+from torch.nn import functional as F
+import torch
+
+from .base import Conv2dBnRelu
+from .encoders import ResNetEncoders
+
+
+class PSPModule(nn.Module):
+    def __init__(self, features, out_features=1024, sizes=(1, 2, 3, 6)):
+        super().__init__()
+        self.stages = []  # immediately overwritten below
+        self.stages = nn.ModuleList([self._make_stage(features, size) for size in sizes])
+        self.bottleneck = nn.Conv2d(features * (len(sizes) + 1), out_features, kernel_size=1)
+        self.relu = nn.ReLU()
+
+    def _make_stage(self, features, size):
+        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
+        conv = nn.Conv2d(features, features, kernel_size=1, bias=False)
+        return nn.Sequential(prior, conv)
+
+    def forward(self, feats):
+        h, w = feats.size(2), feats.size(3)
+        priors = [F.upsample(input=stage(feats), size=(h, w), mode='bilinear') for stage in self.stages] + [feats]
+        bottle = self.bottleneck(torch.cat(priors, 1))
+        return self.relu(bottle)
+
+
+class PSPUpsample(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, 3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.PReLU()
+        )
+
+    def forward(self, x):
+        p = F.upsample(input=x, scale_factor=2, mode='bilinear')
+        return self.conv(p)
+
+
+class PSPNet(nn.Module):
+    def __init__(self,
+                 encoder_depth,
+                 num_classes=2,
+                 sizes=(1, 2, 3, 6),
+                 deep_features_size=1024,
+                 dropout_2d=0.2,
+                 pretrained=False,
+                 use_hypercolumn=False,
+                 pool0=False):
+        super().__init__()
+        self.num_classes = num_classes
+        self.dropout_2d = dropout_2d
+        self.use_hypercolumn = use_hypercolumn
+
+        self.encoders = ResNetEncoders(encoder_depth, pretrained=pretrained, pool0=pool0)
+
+        if encoder_depth in [18, 34]:
+            bottom_channel_nr = 512
+        elif encoder_depth in [50, 101, 152]:
+            bottom_channel_nr = 2048
+        else:
+            raise NotImplementedError('only the 18, 34, 50, 101, 152 versions of ResNet are implemented')
+
+        self.psp = PSPModule(bottom_channel_nr, deep_features_size, sizes)
+
+        self.up4 = PSPUpsample(deep_features_size, deep_features_size // 2)
+        self.up3 = PSPUpsample(deep_features_size // 2, deep_features_size // 4)
+        self.up2 = PSPUpsample(deep_features_size // 4, deep_features_size // 8)
+        self.up1 = PSPUpsample(deep_features_size // 8, deep_features_size // 16)
+
+        if self.use_hypercolumn:
+            self.final = nn.Sequential(Conv2dBnRelu(15 * bottom_channel_nr // 8, bottom_channel_nr // 8),
+                                       nn.Conv2d(bottom_channel_nr // 8, num_classes, kernel_size=1, padding=0))
+        else:
+            self.final = nn.Sequential(Conv2dBnRelu(bottom_channel_nr // 8, bottom_channel_nr // 8),
+                                       nn.Conv2d(bottom_channel_nr // 8, num_classes, kernel_size=1, padding=0))
+
+    def forward(self, x):
+        encoder2, encoder3, encoder4, encoder5 = self.encoders(x)
+        encoder5 = F.dropout2d(encoder5, p=self.dropout_2d)
+
+        psp = self.psp(encoder5)
+
+        up4 = self.up4(psp)
+        up3 = self.up3(up4)
+        up2 = self.up2(up3)
+        up1 = self.up1(up2)
+        if self.use_hypercolumn:
+            hypercolumn = torch.cat([up1,
+                                     F.upsample(up2, scale_factor=2, mode='bilinear'),
+                                     F.upsample(up3, scale_factor=4, mode='bilinear'),
+                                     F.upsample(up4, scale_factor=8, mode='bilinear'),
+                                     ], 1)
+            drop = F.dropout2d(hypercolumn, p=self.dropout_2d)
+        else:
+            drop = F.dropout2d(up1, p=self.dropout_2d)  # up1, not up4: only up1's channel count matches what self.final expects here
+        return self.final(drop)
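
A quick smoke test of the new model might look like the following (a sketch, not part of the commit; it assumes the file lands at the path above and that ResNetEncoders returns the usual four torchvision-ResNet feature maps):

import torch

from common_blocks.architectures.pspnet import PSPNet

# Mirrors the 'PSPNet' registry entry added in models.py, minus pretrained weights.
model = PSPNet(encoder_depth=34, pretrained=False, use_hypercolumn=True, pool0=False)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 256, 256))  # 1 x num_classes x H' x W'
# H' and W' depend on the encoder strides (pool0) and the four 2x PSPUpsample stages.
print(logits.shape)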

common_blocks/augmentation.py

Lines changed: 0 additions & 3 deletions
@@ -110,9 +110,6 @@ def test_time_augmentation_transform(image, tta_parameters):
         image = np.flipud(image)
     if tta_parameters['lr_flip']:
         image = np.fliplr(image)
-    if tta_parameters['color_shift']:
-        tta_intensity = reseed(tta_intensity_seq, deterministic=False)
-        image = tta_intensity.augment_image(image)
     image = rotate(image, tta_parameters['rotation'])
     return image
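
TTA parameter dicts are now purely geometric. A small usage sketch against the trimmed function (illustrative input; rotate is the module's own helper):

import numpy as np

from common_blocks.augmentation import test_time_augmentation_transform

image = np.random.rand(256, 256, 3).astype(np.float32)
tta_parameters = {'ud_flip': True, 'lr_flip': False, 'rotation': 90}  # the 'color_shift' key is gone
augmented = test_time_augmentation_transform(image, tta_parameters)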

common_blocks/callbacks.py

Lines changed: 14 additions & 5 deletions
@@ -18,7 +18,7 @@
 from common_blocks.utils.misc import get_logger, sigmoid, softmax, make_apply_transformer, get_list_of_image_predictions
 from common_blocks.utils.io import read_masks
 from .metrics import intersection_over_union_thresholds
-from .postprocessing import crop_image, resize_image, binarize, label
+from .postprocessing import crop_image, resize_image, binarize, label, masks_to_bounding_boxes

 logger = get_logger()

@@ -670,7 +670,16 @@ def postprocessing_pipeline_simplified(cache_dirpath, loader_mode):
                     input_steps=[binarizer],
                     adapter=Adapter({'images': E(binarizer.name, 'binarized_images'),
                                      }))
-    labeler.set_mode_inference()
-    labeler.set_parameters_upstream({'experiment_directory': cache_dirpath,
-                                     'is_fittable': False})
-    return labeler
+    bounding_boxer = Step(name='bounding_boxer',
+                          transformer=make_apply_transformer(masks_to_bounding_boxes,
+                                                             output_name='labeled_images',
+                                                             apply_on=['images']),
+                          input_steps=[labeler],
+                          adapter=Adapter({'images': E(labeler.name, 'labeled_images'),
+                                           }))
+
+    bounding_boxer.set_mode_inference()
+    bounding_boxer.set_parameters_upstream({'experiment_directory': cache_dirpath,
+                                            'is_fittable': False
+                                            })
+    return bounding_boxer
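
The simplified postprocessing pipeline now ends one step later: binarizer -> labeler -> bounding_boxer, so callers get bounding-box masks instead of labeled masks. A rough sketch of driving it; the input-dict layout is dictated by upstream steps not shown in this hunk, so treat the keys as placeholders:

from common_blocks.callbacks import postprocessing_pipeline_simplified

pipeline = postprocessing_pipeline_simplified(cache_dirpath='/tmp/experiment', loader_mode='resize')
output = pipeline.transform(data)   # data: dict feeding the upstream prediction steps
boxes = output['labeled_images']    # key set by output_name in the bounding_boxer step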

common_blocks/loaders.py

Lines changed: 17 additions & 45 deletions
@@ -202,9 +202,10 @@ def load_target(self, data_source, index, load_func):


 class ImageSegmentationTTADataset(ImageSegmentationDataset):
-    def __init__(self, tta_params, *args, **kwargs):
+    def __init__(self, tta_params, tta_transform, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.tta_params = tta_params
+        self.tta_transform = tta_transform

     def __getitem__(self, index):
         if self.image_source == 'memory':

@@ -222,7 +223,7 @@ def __getitem__(self, index):

         if self.tta_params is not None:
             tta_transform_specs = self.tta_params[index]
-            Xi = test_time_augmentation_transform(Xi, tta_transform_specs)
+            Xi = self.tta_transform(Xi, tta_transform_specs)
         Xi = to_pil(Xi)

         if self.image_transform is not None:

@@ -320,6 +321,7 @@ def transform(self, X, tta_params, **kwargs):

     def get_datagen(self, X, tta_params, loader_params):
         dataset = self.dataset(tta_params=tta_params,
+                               tta_transform=self.augmentation_params.tta_transform,
                                X=X,
                                y=None,
                                train_mode=False,

@@ -369,8 +371,6 @@ def __init__(self, loader_params, dataset_params, augmentation_params):
                                                   transforms.Normalize(mean=self.dataset_params.MEAN,
                                                                        std=self.dataset_params.STD),
                                                   ])
-        self.mask_transform = transforms.Compose([transforms.Lambda(preprocess_target),
-                                                  ])

         self.image_augment_inference = ImgAug(self.augmentation_params['image_augment_inference'])
         self.image_augment_with_target_inference = ImgAug(

@@ -394,22 +394,18 @@ def transform(self, X, **kwargs):
         return {'X_tta': X_tta, 'tta_params': tta_params, 'img_ids': img_ids}

     def _get_tta_data(self, i, row):
-        original_specs = {'ud_flip': False, 'lr_flip': False, 'rotation': 0, 'color_shift': False}
+        original_specs = {'ud_flip': False, 'lr_flip': False, 'rotation': 0}
         tta_specs = [original_specs]

         ud_options = [True, False] if self.tta_transformations.flip_ud else [False]
         lr_options = [True, False] if self.tta_transformations.flip_lr else [False]
         rot_options = [0, 90, 180, 270] if self.tta_transformations.rotation else [0]
-        if self.tta_transformations.color_shift_runs:
-            color_shift_options = list(range(1, self.tta_transformations.color_shift_runs + 1, 1))
-        else:
-            color_shift_options = [False]

-        for ud, lr, rot, color in product(ud_options, lr_options, rot_options, color_shift_options):
-            if ud is False and lr is False and rot == 0 and color is False:
+        for ud, lr, rot in product(ud_options, lr_options, rot_options):
+            if ud is False and lr is False and rot == 0:
                 continue
             else:
-                tta_specs.append({'ud_flip': ud, 'lr_flip': lr, 'rotation': rot, 'color_shift': color})
+                tta_specs.append({'ud_flip': ud, 'lr_flip': lr, 'rotation': rot})

         img_ids = [i] * len(tta_specs)
         X_rows = [row] * len(tta_specs)

@@ -431,30 +427,27 @@ def transform(self, X, **kwargs):
         return {'X_tta': [X_tta], 'tta_params': tta_params, 'img_ids': img_ids}

     def _get_tta_data(self, i, row):
-        original_specs = {'ud_flip': False, 'lr_flip': False, 'rotation': 0, 'color_shift': False}
+        original_specs = {'ud_flip': False, 'lr_flip': False, 'rotation': 0}
         tta_specs = [original_specs]

         ud_options = [True, False] if self.tta_transformations.flip_ud else [False]
         lr_options = [True, False] if self.tta_transformations.flip_lr else [False]
         rot_options = [0, 90, 180, 270] if self.tta_transformations.rotation else [0]
-        if self.tta_transformations.color_shift_runs:
-            color_shift_options = list(range(1, self.tta_transformations.color_shift_runs + 1, 1))
-        else:
-            color_shift_options = [False]

-        for ud, lr, rot, color in product(ud_options, lr_options, rot_options, color_shift_options):
-            if ud is False and lr is False and rot == 0 and color is False:
+        for ud, lr, rot in product(ud_options, lr_options, rot_options):
+            if ud is False and lr is False and rot == 0:
                 continue
             else:
-                tta_specs.append({'ud_flip': ud, 'lr_flip': lr, 'rotation': rot, 'color_shift': color})
+                tta_specs.append({'ud_flip': ud, 'lr_flip': lr, 'rotation': rot})

         img_ids = [i] * len(tta_specs)
         X_rows = [row] * len(tta_specs)
         return X_rows, tta_specs, img_ids


 class TestTimeAugmentationAggregator(BaseTransformer):
-    def __init__(self, method, nthreads):
+    def __init__(self, tta_inverse_transform, method, nthreads):
+        self.tta_inverse_transform = tta_inverse_transform
         self.method = method
         self.nthreads = nthreads

@@ -471,6 +464,7 @@ def transform(self, images, tta_params, img_ids, **kwargs):
         _aggregate_augmentations = partial(aggregate_augmentations,
                                            images=images,
                                            tta_params=tta_params,
+                                           tta_inverse_transform=self.tta_inverse_transform,
                                            img_ids=img_ids,
                                            agg_method=self.agg_method)
         unique_img_ids = set(img_ids)

@@ -480,40 +474,18 @@ def transform(self, images, tta_params, img_ids, **kwargs):
         return {'aggregated_prediction': averages_images}


-def aggregate_augmentations(img_id, images, tta_params, img_ids, agg_method):
+def aggregate_augmentations(img_id, images, tta_params, tta_inverse_transform, img_ids, agg_method):
     tta_predictions_for_id = []
     for image, tta_param, ids in zip(images, tta_params, img_ids):
         if ids == img_id:
-            tta_prediction = test_time_augmentation_inverse_transform(image, tta_param)
+            tta_prediction = tta_inverse_transform(image, tta_param)
             tta_predictions_for_id.append(tta_prediction)
         else:
             continue
     tta_averaged = agg_method(np.stack(tta_predictions_for_id, axis=-1))
     return tta_averaged


-def test_time_augmentation_transform(image, tta_parameters):
-    if tta_parameters['ud_flip']:
-        image = np.flipud(image)
-    if tta_parameters['lr_flip']:
-        image = np.fliplr(image)
-    if tta_parameters['color_shift']:
-        random_color_shift = reseed(intensity_seq, deterministic=False)
-        image = random_color_shift.augment_image(image)
-    image = rotate(image, tta_parameters['rotation'])
-    return image
-
-
-def test_time_augmentation_inverse_transform(image, tta_parameters):
-    image = per_channel_rotation(image.copy(), -1 * tta_parameters['rotation'])
-
-    if tta_parameters['lr_flip']:
-        image = per_channel_fliplr(image.copy())
-    if tta_parameters['ud_flip']:
-        image = per_channel_flipud(image.copy())
-    return image
-
-
 def per_channel_flipud(x):
     x_ = x.copy()
     for i, channel in enumerate(x):
common_blocks/models.py

Lines changed: 7 additions & 1 deletion
@@ -5,7 +5,7 @@
 from toolkit.pytorch_transformers.models import Model
 from torch.autograd import Variable

-from .architectures import unet, large_kernel_matters
+from .architectures import unet, large_kernel_matters, pspnet
 from . import callbacks as cbk
 from .lovasz_losses import lovasz_hinge
 from common_blocks.utils.misc import sigmoid, softmax, get_list_of_image_predictions

@@ -36,6 +36,11 @@
                                 'dropout_2d': 0.0, 'use_relu': True, 'pool0': False
                                 },
               'init_weights': False},
+    'PSPNet': {'model': pspnet.PSPNet,
+               'model_config': {'encoder_depth': 34, 'pretrained': True,
+                                'use_hypercolumn': True, 'pool0': False
+                                },
+               }
 }

@@ -164,6 +169,7 @@ def set_loss(self):
         elif self.activation_func == 'sigmoid':
             loss_function = lovasz_loss
             # loss_function = DiceLoss()
+            # loss_function = FocalWithLogitsLoss()
             # loss_function = nn.BCEWithLogitsLoss()
         else:
             raise Exception('Only softmax and sigmoid activations are allowed')
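
Registering the architecture makes model selection a dictionary lookup. A sketch of that lookup; the registry's variable name is not visible in this hunk, so ARCHITECTURES is a stand-in:

entry = ARCHITECTURES['PSPNet']                    # the dict extended above
network = entry['model'](**entry['model_config'])  # PSPNet(encoder_depth=34, pretrained=True, ...)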
