This repository was archived by the owner on Jun 22, 2022. It is now read-only.

Commit b496220

Author: minerva-ml
Message: fixed loss
Parent: ceb7d16

3 files changed: +39 −22 lines

common_blocks/models.py

Lines changed: 30 additions & 13 deletions
@@ -1,6 +1,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+from torch.nn import functional as F
 import torch.optim as optim
 from toolkit.pytorch_transformers.models import Model
 from torch.autograd import Variable
@@ -167,10 +168,11 @@ def set_loss(self):
         if self.activation_func == 'softmax':
             raise NotImplementedError('No softmax loss defined')
         elif self.activation_func == 'sigmoid':
-            loss_function = lovasz_loss
-            # loss_function = DiceLoss()
-            # loss_function = FocalWithLogitsLoss()
+            loss_function = weighted_sum_loss
             # loss_function = nn.BCEWithLogitsLoss()
+            # loss_function = DiceWithLogitsLoss()
+            # loss_function = lovasz_loss
+            # loss_function = FocalWithLogitsLoss()
         else:
             raise Exception('Only softmax and sigmoid activations are allowed')
         self.loss_function = [('mask', loss_function, 1.0)]
@@ -191,34 +193,49 @@ def load(self, filepath):
 
 
 class FocalWithLogitsLoss(nn.Module):
-    def __init__(self, alpha=1.0, gamma=1.0):
+    def __init__(self, alpha=1.0, gamma=1.0, reduction='elementwise_mean'):
         super().__init__()
         self.alpha = alpha
         self.gamma = gamma
+        self.reduction = reduction
 
-    def forward(self, input, target):
-        if not (target.size() == input.size()):
-            raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
+    def forward(self, output, target):
+        if not (target.size() == output.size()):
+            raise ValueError(
+                "Target size ({}) must be the same as input size ({})".format(target.size(), output.size()))
 
-        max_val = (-input).clamp(min=0)
-        logpt = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
+        max_val = (-output).clamp(min=0)
+        logpt = output - output * target + max_val + ((-max_val).exp() + (-output - max_val).exp()).log()
         pt = torch.exp(-logpt)
         at = self.alpha * target + (1 - target)
         loss = at * ((1 - pt).pow(self.gamma)) * logpt
-        return loss
 
+        if self.reduction == 'none':
+            return loss
+        elif self.reduction == 'elementwise_mean':
+            return loss.mean()
+        else:
+            return loss.sum()
 
-class DiceLoss(nn.Module):
+
+class DiceWithLogitsLoss(nn.Module):
     def __init__(self, smooth=0, eps=1e-7):
         super().__init__()
         self.smooth = smooth
         self.eps = eps
 
     def forward(self, output, target):
+        output = F.sigmoid(output)
         return 1 - (2 * torch.sum(output * target) + self.smooth) / (
             torch.sum(output) + torch.sum(target) + self.smooth + self.eps)
 
 
+def weighted_sum_loss(output, target):
+    bce = nn.BCEWithLogitsLoss()(output, target)
+    dice = DiceWithLogitsLoss()(output, target)
+    return bce + 0.25 * dice
+
+
 def lovasz_loss(output, target):
     target = target.long()
     return lovasz_hinge(output, target)
@@ -246,6 +263,6 @@ def callbacks_network(callbacks_config):
     init_lr_finder = cbk.InitialLearningRateFinder()
     return cbk.CallbackList(
         callbacks=[experiment_timing, training_monitor, validation_monitor,
-                   model_checkpoints, lr_scheduler, neptune_monitor, early_stopping,
-                   # init_lr_finder
+                   model_checkpoints, neptune_monitor, early_stopping,
+                   lr_scheduler, #init_lr_finder,
                    ])
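A note on the FocalWithLogitsLoss.forward math above: the clamp/exp/log expression for logpt is the numerically stable form of elementwise binary cross-entropy with logits (it avoids overflow in exp for large-magnitude logits), so pt = exp(-logpt) is the predicted probability of the true class and (1 - pt) ** gamma downweights easy pixels. A minimal sketch verifying this against PyTorch's built-in; the toy shapes are assumptions:

    import torch
    import torch.nn.functional as F

    output = torch.randn(4, 1, 8, 8)                   # toy logits (assumed shape)
    target = (torch.rand(4, 1, 8, 8) > 0.5).float()    # toy binary mask

    # Same expression as in FocalWithLogitsLoss.forward:
    max_val = (-output).clamp(min=0)
    logpt = output - output * target + max_val + ((-max_val).exp() + (-output - max_val).exp()).log()

    # It should match elementwise BCE-with-logits:
    reference = F.binary_cross_entropy_with_logits(output, target, reduction='none')
    print(torch.allclose(logpt, reference, atol=1e-6))  # True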

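The substantive change in this commit is switching the sigmoid-branch training loss from the Lovász hinge to weighted_sum_loss, a fixed-weight sum of BCE-with-logits and soft Dice (BCE + 0.25 * Dice). Below is a self-contained sketch for trying the combination outside the repo; the toy shapes and the use of torch.sigmoid (rather than the since-deprecated F.sigmoid in the diff) are assumptions:

    import torch
    import torch.nn as nn

    def dice_with_logits(output, target, smooth=0.0, eps=1e-7):
        # Soft Dice on sigmoid probabilities, mirroring DiceWithLogitsLoss above.
        probs = torch.sigmoid(output)
        return 1 - (2 * torch.sum(probs * target) + smooth) / (
            torch.sum(probs) + torch.sum(target) + smooth + eps)

    def weighted_sum_loss(output, target):
        # BCE provides stable per-pixel gradients; the 0.25-weighted Dice term
        # directly rewards mask overlap.
        bce = nn.BCEWithLogitsLoss()(output, target)
        return bce + 0.25 * dice_with_logits(output, target)

    logits = torch.randn(2, 1, 256, 256)               # toy batch (assumed shape)
    mask = (torch.rand(2, 1, 256, 256) > 0.5).float()
    print(weighted_sum_loss(logits, mask))             # scalar loss tensor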
main.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@
 CLONE_EXPERIMENT_DIR_FROM = ''  # When running eval in the cloud specify this as for example /input/SAL-14/output/experiment
 OVERWRITE_EXPERIMENT_DIR = False
 DEV_MODE = False
-USE_TTA = True
+USE_TTA = False
 
 if OVERWRITE_EXPERIMENT_DIR and os.path.isdir(EXPERIMENT_DIR):
     shutil.rmtree(EXPERIMENT_DIR)
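main.py now disables test-time augmentation. For context, flip-averaged TTA for binary segmentation typically looks like the sketch below; this is a generic pattern, not the repo's actual TTA code, which is not shown in this diff:

    import torch

    def predict_with_tta(model, batch, use_tta=True):
        # Average sigmoid probabilities over the identity and a horizontal flip;
        # predictions on the flipped batch are flipped back before averaging.
        probs = torch.sigmoid(model(batch))
        if use_tta:
            flipped = torch.flip(batch, dims=[-1])
            probs_flipped = torch.sigmoid(model(flipped))
            probs = 0.5 * (probs + torch.flip(probs_flipped, dims=[-1]))
        return probs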

neptune.yaml

Lines changed: 8 additions & 8 deletions
@@ -1,7 +1,7 @@
 #project: USERNAME/PROJECT
 
 name: airbus_ships_challenge
-tags: [solution-1]
+tags: [solution-2]
 
 metric:
   channel: 'f2'
@@ -25,7 +25,7 @@ parameters:
   # Data Paths
   train_images_dir: /public/challenges/kaggle-ship-detection/train
   test_images_dir: /public/challenges/kaggle-ship-detection/test
-  metadata_filepath: /outputs/ships_metadata.csv
+  metadata_filepath: /output/metadata.csv
   annotation_file: /public/challenges/kaggle-ship-detection/train_ship_segmentations.csv
   masks_overlayed_dir: /output/masks_overlayed
 
@@ -38,17 +38,17 @@ parameters:
   resize_target_size: 256
   pad_method: symmetric
   target_format: 'joblib'
-  dev_mode_size: 500
+  dev_mode_size: 50
 
   # General parameters
   image_h: 256
   image_w: 256
   image_channels: 3
-  training_sampler_size: 2000
+  training_sampler_size: 2500
   training_sampler_empty_fraction: 0.0
   evaluation_size: 10000
   evaluation_empty_fraction: 0.52
-  in_train_evaluation_size: 1000
+  in_train_evaluation_size: 500
   fine_tuning: 1
 
   # Network parameters
@@ -57,13 +57,13 @@ parameters:
   architecture: UNetSeResNetXt
 
   # Training schedule
-  epochs_nr: 10
+  epochs_nr: 1000
   batch_size_train: 4
   batch_size_inference: 4
-  lr: 0.0007
+  lr: 0.0006
   momentum: 0.9
   gamma: 0.95
-  patience: 10
+  patience: 20
   validation_metric_name: 'f2'
   minimize_validation_metric: 0
   reduce_factor: 0.5
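The schedule changes (patience 10 → 20, lr 0.0007 → 0.0006, epochs_nr 10 → 1000) are consistent with the lr_scheduler callback reducing the learning rate when the 'f2' validation channel plateaus. A hypothetical wiring of these parameters onto PyTorch's ReduceLROnPlateau, assuming minimize_validation_metric: 0 means the metric is maximized; the actual callback implementation is not shown in this diff:

    import torch.nn as nn
    import torch.optim as optim

    model = nn.Conv2d(3, 1, kernel_size=3, padding=1)   # stand-in network (assumption)
    optimizer = optim.SGD(model.parameters(), lr=0.0006, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',       # minimize_validation_metric: 0 -> maximize 'f2'
        factor=0.5,       # reduce_factor: 0.5
        patience=20,      # patience: 20
    )
    # Called once per epoch with the validation metric:
    # scheduler.step(validation_f2)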
