@@ -202,9 +202,10 @@ def load_target(self, data_source, index, load_func):
202202
203203
204204class ImageSegmentationTTADataset (ImageSegmentationDataset ):
205- def __init__ (self , tta_params , * args , ** kwargs ):
205+ def __init__ (self , tta_params , tta_transform , * args , ** kwargs ):
206206 super ().__init__ (* args , ** kwargs )
207207 self .tta_params = tta_params
208+ self .tta_transform = tta_transform
208209
209210 def __getitem__ (self , index ):
210211 if self .image_source == 'memory' :
@@ -222,7 +223,7 @@ def __getitem__(self, index):
222223
223224 if self .tta_params is not None :
224225 tta_transform_specs = self .tta_params [index ]
225- Xi = test_time_augmentation_transform (Xi , tta_transform_specs )
226+ Xi = self . tta_transform (Xi , tta_transform_specs )
226227 Xi = to_pil (Xi )
227228
228229 if self .image_transform is not None :
@@ -320,6 +321,7 @@ def transform(self, X, tta_params, **kwargs):
320321
321322 def get_datagen (self , X , tta_params , loader_params ):
322323 dataset = self .dataset (tta_params = tta_params ,
324+ tta_transform = self .augmentation_params .tta_transform ,
323325 X = X ,
324326 y = None ,
325327 train_mode = False ,
@@ -369,8 +371,6 @@ def __init__(self, loader_params, dataset_params, augmentation_params):
369371 transforms .Normalize (mean = self .dataset_params .MEAN ,
370372 std = self .dataset_params .STD ),
371373 ])
372- self .mask_transform = transforms .Compose ([transforms .Lambda (preprocess_target ),
373- ])
374374
375375 self .image_augment_inference = ImgAug (self .augmentation_params ['image_augment_inference' ])
376376 self .image_augment_with_target_inference = ImgAug (
@@ -394,22 +394,18 @@ def transform(self, X, **kwargs):
394394 return {'X_tta' : X_tta , 'tta_params' : tta_params , 'img_ids' : img_ids }
395395
396396 def _get_tta_data (self , i , row ):
397- original_specs = {'ud_flip' : False , 'lr_flip' : False , 'rotation' : 0 , 'color_shift' : False }
397+ original_specs = {'ud_flip' : False , 'lr_flip' : False , 'rotation' : 0 }
398398 tta_specs = [original_specs ]
399399
400400 ud_options = [True , False ] if self .tta_transformations .flip_ud else [False ]
401401 lr_options = [True , False ] if self .tta_transformations .flip_lr else [False ]
402402 rot_options = [0 , 90 , 180 , 270 ] if self .tta_transformations .rotation else [0 ]
403- if self .tta_transformations .color_shift_runs :
404- color_shift_options = list (range (1 , self .tta_transformations .color_shift_runs + 1 , 1 ))
405- else :
406- color_shift_options = [False ]
407403
408- for ud , lr , rot , color in product (ud_options , lr_options , rot_options , color_shift_options ):
409- if ud is False and lr is False and rot == 0 and color is False :
404+ for ud , lr , rot in product (ud_options , lr_options , rot_options ):
405+ if ud is False and lr is False and rot == 0 is False :
410406 continue
411407 else :
412- tta_specs .append ({'ud_flip' : ud , 'lr_flip' : lr , 'rotation' : rot , 'color_shift' : color })
408+ tta_specs .append ({'ud_flip' : ud , 'lr_flip' : lr , 'rotation' : rot })
413409
414410 img_ids = [i ] * len (tta_specs )
415411 X_rows = [row ] * len (tta_specs )
@@ -431,30 +427,27 @@ def transform(self, X, **kwargs):
431427 return {'X_tta' : [X_tta ], 'tta_params' : tta_params , 'img_ids' : img_ids }
432428
433429 def _get_tta_data (self , i , row ):
434- original_specs = {'ud_flip' : False , 'lr_flip' : False , 'rotation' : 0 , 'color_shift' : False }
430+ original_specs = {'ud_flip' : False , 'lr_flip' : False , 'rotation' : 0 }
435431 tta_specs = [original_specs ]
436432
437433 ud_options = [True , False ] if self .tta_transformations .flip_ud else [False ]
438434 lr_options = [True , False ] if self .tta_transformations .flip_lr else [False ]
439435 rot_options = [0 , 90 , 180 , 270 ] if self .tta_transformations .rotation else [0 ]
440- if self .tta_transformations .color_shift_runs :
441- color_shift_options = list (range (1 , self .tta_transformations .color_shift_runs + 1 , 1 ))
442- else :
443- color_shift_options = [False ]
444436
445- for ud , lr , rot , color in product (ud_options , lr_options , rot_options , color_shift_options ):
446- if ud is False and lr is False and rot == 0 and color is False :
437+ for ud , lr , rot in product (ud_options , lr_options , rot_options ):
438+ if ud is False and lr is False and rot == 0 is False :
447439 continue
448440 else :
449- tta_specs .append ({'ud_flip' : ud , 'lr_flip' : lr , 'rotation' : rot , 'color_shift' : color })
441+ tta_specs .append ({'ud_flip' : ud , 'lr_flip' : lr , 'rotation' : rot })
450442
451443 img_ids = [i ] * len (tta_specs )
452444 X_rows = [row ] * len (tta_specs )
453445 return X_rows , tta_specs , img_ids
454446
455447
456448class TestTimeAugmentationAggregator (BaseTransformer ):
457- def __init__ (self , method , nthreads ):
449+ def __init__ (self , tta_inverse_transform , method , nthreads ):
450+ self .tta_inverse_transform = tta_inverse_transform
458451 self .method = method
459452 self .nthreads = nthreads
460453
@@ -471,6 +464,7 @@ def transform(self, images, tta_params, img_ids, **kwargs):
471464 _aggregate_augmentations = partial (aggregate_augmentations ,
472465 images = images ,
473466 tta_params = tta_params ,
467+ tta_inverse_transform = self .tta_inverse_transform ,
474468 img_ids = img_ids ,
475469 agg_method = self .agg_method )
476470 unique_img_ids = set (img_ids )
@@ -480,40 +474,18 @@ def transform(self, images, tta_params, img_ids, **kwargs):
480474 return {'aggregated_prediction' : averages_images }
481475
482476
483- def aggregate_augmentations (img_id , images , tta_params , img_ids , agg_method ):
477+ def aggregate_augmentations (img_id , images , tta_params , tta_inverse_transform , img_ids , agg_method ):
484478 tta_predictions_for_id = []
485479 for image , tta_param , ids in zip (images , tta_params , img_ids ):
486480 if ids == img_id :
487- tta_prediction = test_time_augmentation_inverse_transform (image , tta_param )
481+ tta_prediction = tta_inverse_transform (image , tta_param )
488482 tta_predictions_for_id .append (tta_prediction )
489483 else :
490484 continue
491485 tta_averaged = agg_method (np .stack (tta_predictions_for_id , axis = - 1 ))
492486 return tta_averaged
493487
494488
495- def test_time_augmentation_transform (image , tta_parameters ):
496- if tta_parameters ['ud_flip' ]:
497- image = np .flipud (image )
498- if tta_parameters ['lr_flip' ]:
499- image = np .fliplr (image )
500- if tta_parameters ['color_shift' ]:
501- random_color_shift = reseed (intensity_seq , deterministic = False )
502- image = random_color_shift .augment_image (image )
503- image = rotate (image , tta_parameters ['rotation' ])
504- return image
505-
506-
507- def test_time_augmentation_inverse_transform (image , tta_parameters ):
508- image = per_channel_rotation (image .copy (), - 1 * tta_parameters ['rotation' ])
509-
510- if tta_parameters ['lr_flip' ]:
511- image = per_channel_fliplr (image .copy ())
512- if tta_parameters ['ud_flip' ]:
513- image = per_channel_flipud (image .copy ())
514- return image
515-
516-
517489def per_channel_flipud (x ):
518490 x_ = x .copy ()
519491 for i , channel in enumerate (x ):
0 commit comments