@@ -36,15 +36,19 @@ def forward(self, module_inputs, target):
         output = self.module(*module_inputs)
         output = output.view(-1, output.size(2))
         target = target.view(-1)
+        output = nn.functional.log_softmax(output, -1)
+        # make sure criterion is not from_logits
         loss = self.criterion(output, target).view(1, 1)
+        nll = nn.functional.nll_loss(
+            output, target, ignore_index=self.ignore_index, reduction='sum')
         if self.get_accuracy:
             _, argmax = output.max(-1)
             invalid_targets = target.eq(self.ignore_index)
             accuracy = argmax.eq(target).masked_fill_(
                 invalid_targets, 0).long().sum()
-            return loss, accuracy.view(1, 1)
+            return loss, nll, accuracy.view(1, 1)
         else:
-            return loss
+            return loss, nll


 def _chunk_tuple(seq_tuple, num_chunks, batch_first=True):
@@ -116,7 +120,7 @@ def __init__(self, model, regime=None,
         super(Seq2SeqTrainer, self).__init__()
         self.model = model
         self.criterion = criterion or CrossEntropyLoss(
-            ignore_index=PAD, smooth_eps=label_smoothing, reduction='sum')
+            ignore_index=PAD, smooth_eps=label_smoothing, reduction='sum', from_logits=False)

         self.optimizer = OptimRegime(self.model, regime=regime)
         self.grad_clip = grad_clip
@@ -161,6 +165,7 @@ def batch_first(self):
     def iterate(self, src_tuple_batch, target_tuple_batch, training=True, chunk_batch=1):
         loss_measure = 0
         accuracy_measure = 0
+        nll_measure = 0
         num_words = 0
         if training:
             self.optimizer.zero_grad()
@@ -194,7 +199,7 @@ def iterate(self, src_tuple_batch, target_tuple_batch, training=True, chunk_batc
             if training:
                 self.optimizer.pre_forward()
             # compute output
-            loss, accuracy = self.model_with_loss(inputs, target_labels)
+            loss, nll, accuracy = self.model_with_loss(inputs, target_labels)

             loss = loss.sum()
             loss_measure += float(loss / num_words)
@@ -203,6 +208,7 @@ def iterate(self, src_tuple_batch, target_tuple_batch, training=True, chunk_batc
             else:
                 loss /= target.size(batch_dim)
             accuracy_measure += float(accuracy.sum().float() / num_words)
+            nll_measure += float(nll.sum() / num_words)

             if training:
                 self.optimizer.pre_backward()
@@ -231,7 +237,7 @@ def iterate(self, src_tuple_batch, target_tuple_batch, training=True, chunk_batc
                 clip_grad_norm_(self.model.decoder.embedder.parameters(),
                                 self.embedding_grad_clip)
             self.optimizer.step()
-        return loss_measure, accuracy_measure, num_words
+        return loss_measure, nll_measure, accuracy_measure, num_words

     def _feed_data(self, data_loader, num_iterations=None, training=True, chunk_batch=1):
         if training:
@@ -261,13 +267,13 @@ def _feed_data(self, data_loader, num_iterations=None, training=True, chunk_batc
             # update optimizer according to epoch and steps
             self.optimizer.update(self.epoch, self.training_steps)
             # do a train/evaluate iteration
-            loss, acc, num_words = self.iterate(src, target,
-                                                training=training,
-                                                chunk_batch=chunk_batch)
+            loss, nll, acc, num_words = self.iterate(src, target,
+                                                     training=training,
+                                                     chunk_batch=chunk_batch)

             # measure accuracy and record loss
             losses.update(loss, num_words)
-            perplexity.update(math.exp(loss), num_words)
+            perplexity.update(math.exp(nll), num_words)
             accuracy.update(acc, num_words)

             # measure elapsed time
@@ -470,8 +476,6 @@ def __init__(self, *kargs, **kwargs):
         _, target_tok = self.save_info['tokenizers'].values()
         target_words = target_tok.common_words(8188)
         self.contrast_batch = batch_nested_sequences(target_words)
-        import pdb
-        pdb.set_trace()

     def iterate(self, src_tuple, target_tuple, training=True):
         # limit number of tokens to avoid gpu overload
@@ -499,7 +503,7 @@ def iterate(self, src_tuple, target_tuple, training=True):
         target_labels = target[1:]

         # compute output
-        loss, accuracy = self.model_with_loss(inputs, target_labels)
+        loss, nll, accuracy = self.model_with_loss(inputs, target_labels)

         loss = loss.sum()
         loss_measure = float(loss / num_words)
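
For context (a sketch, not code from this commit): the change keeps the label-smoothed criterion for training while also computing an unsmoothed NLL from the same log-probabilities, so perplexity is reported as exp(NLL per word) rather than exp(smoothed loss per word). The snippet below only illustrates that pattern; PAD and smoothed_loss_and_nll are illustrative stand-ins for the repo's PAD constant and its CrossEntropyLoss with smooth_eps.

# Sketch only: stand-in for the repo's smoothed criterion + separate NLL.
import math
import torch
import torch.nn as nn

PAD = 0  # assumed padding index (the trainer passes ignore_index=PAD)

def smoothed_loss_and_nll(logits, target, smooth_eps=0.1):
    # logits: (num_tokens, vocab_size), target: (num_tokens,)
    log_probs = nn.functional.log_softmax(logits, -1)
    # true negative log-likelihood, padding excluded -- this is what perplexity uses
    nll = nn.functional.nll_loss(log_probs, target,
                                 ignore_index=PAD, reduction='sum')
    # simplified label-smoothed cross entropy over non-padding tokens
    num_classes = log_probs.size(-1)
    smooth_target = torch.full_like(log_probs, smooth_eps / (num_classes - 1))
    smooth_target.scatter_(1, target.unsqueeze(1), 1.0 - smooth_eps)
    non_pad = target.ne(PAD)
    loss = -(smooth_target * log_probs).sum(-1)[non_pad].sum()
    return loss, nll

# usage: optimize loss, but report perplexity from nll
logits = torch.randn(6, 10)
target = torch.tensor([1, 4, 2, PAD, 3, 5])
loss, nll = smoothed_loss_and_nll(logits, target)
num_words = int(target.ne(PAD).sum())
perplexity = math.exp(nll.item() / num_words)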