@@ -1,6 +1,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+from torch.nn import functional as F
 import torch.optim as optim
 from toolkit.pytorch_transformers.models import Model
 from torch.autograd import Variable
@@ -167,10 +168,11 @@ def set_loss(self):
         if self.activation_func == 'softmax':
             raise NotImplementedError('No softmax loss defined')
         elif self.activation_func == 'sigmoid':
-            loss_function = lovasz_loss
-            # loss_function = DiceLoss()
-            # loss_function = FocalWithLogitsLoss()
+            loss_function = weighted_sum_loss
             # loss_function = nn.BCEWithLogitsLoss()
+            # loss_function = DiceWithLogitsLoss()
+            # loss_function = lovasz_loss
+            # loss_function = FocalWithLogitsLoss()
         else:
             raise Exception('Only softmax and sigmoid activations are allowed')
         self.loss_function = [('mask', loss_function, 1.0)]
@@ -191,34 +193,49 @@ def load(self, filepath):
 
 
 class FocalWithLogitsLoss(nn.Module):
-    def __init__(self, alpha=1.0, gamma=1.0):
+    def __init__(self, alpha=1.0, gamma=1.0, reduction='elementwise_mean'):
         super().__init__()
         self.alpha = alpha
         self.gamma = gamma
+        self.reduction = reduction
 
-    def forward(self, input, target):
-        if not (target.size() == input.size()):
-            raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
+    def forward(self, output, target):
+        if not (target.size() == output.size()):
+            raise ValueError(
+                "Target size ({}) must be the same as input size ({})".format(target.size(), output.size()))
 
-        max_val = (-input).clamp(min=0)
-        logpt = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
+        max_val = (-output).clamp(min=0)
+        logpt = output - output * target + max_val + ((-max_val).exp() + (-output - max_val).exp()).log()
         pt = torch.exp(-logpt)
         at = self.alpha * target + (1 - target)
         loss = at * ((1 - pt).pow(self.gamma)) * logpt
-        return loss
 
+        if self.reduction == 'none':
+            return loss
+        elif self.reduction == 'elementwise_mean':
+            return loss.mean()
+        else:
+            return loss.sum()
 
-class DiceLoss(nn.Module):
+
+class DiceWithLogitsLoss(nn.Module):
     def __init__(self, smooth=0, eps=1e-7):
         super().__init__()
         self.smooth = smooth
         self.eps = eps
 
     def forward(self, output, target):
+        output = F.sigmoid(output)
         return 1 - (2 * torch.sum(output * target) + self.smooth) / (
             torch.sum(output) + torch.sum(target) + self.smooth + self.eps)
 
 
+def weighted_sum_loss(output, target):
+    bce = nn.BCEWithLogitsLoss()(output, target)
+    dice = DiceWithLogitsLoss()(output, target)
+    return bce + 0.25 * dice
+
+
 def lovasz_loss(output, target):
     target = target.long()
     return lovasz_hinge(output, target)
@@ -246,6 +263,6 @@ def callbacks_network(callbacks_config):
     init_lr_finder = cbk.InitialLearningRateFinder()
     return cbk.CallbackList(
         callbacks=[experiment_timing, training_monitor, validation_monitor,
-                   model_checkpoints, lr_scheduler, neptune_monitor, early_stopping,
-                   # init_lr_finder
+                   model_checkpoints, neptune_monitor, early_stopping,
+                   lr_scheduler,  # init_lr_finder,
                    ])
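
A minimal usage sketch (not part of the commit) of the losses introduced above, assuming the definitions from this file are in scope; the tensor shapes and parameter values are illustrative only:

import torch

# Illustrative dummy batch: binary masks of shape (batch, 1, H, W).
logits = torch.randn(4, 1, 101, 101)
targets = torch.randint(0, 2, (4, 1, 101, 101)).float()

# weighted_sum_loss combines BCE-with-logits and 0.25 * soft Dice; both take raw logits.
combined = weighted_sum_loss(logits, targets)

# Focal loss with the new 'reduction' argument ('none', 'elementwise_mean', anything else sums).
focal = FocalWithLogitsLoss(alpha=1.0, gamma=1.0, reduction='elementwise_mean')
focal_value = focal(logits, targets)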