File tree | Expand file tree | Collapse file tree | 1 file changed: +2 -4 lines changed
Expand file tree | Collapse file tree | 1 file changed: +2 -4 lines changed
Original file line number | Diff line number | Diff line change
@@ -51,8 +51,7 @@ class ConditionalGPT2Model(nn.Module):
5151 # Use hyperparameter dict for model configuration
5252 self .embedder = tx.modules.WordEmbedder(vocab_size, hparams = emb_hparams)
5353 self .encoder = tx.modules.TransformerEncoder(hparams = enc_hparams)
54- # GPT-2 module with pre-trained weights
55- self .decoder = tx.modules.GPT2Decoder(" gpt2-small" )
54+ self .decoder = tx.modules.GPT2Decoder(" gpt2-small" ) # With pre-trained weights
5655
5756 def _get_decoder_output (self , batch , train = True ):
5857 """ Perform model inference, i.e., decoding."""
@@ -71,8 +70,7 @@ class ConditionalGPT2Model(nn.Module):
7170 def forward (self , batch ):
7271 """ Compute training loss."""
7372 outputs = self ._get_decoder_output(batch)
74- # Loss for maximum likelihood learning
75- loss = tx.losses.sequence_sparse_softmax_cross_entropy(
73+ loss = tx.losses.sequence_sparse_softmax_cross_entropy( # Sequence loss
7674 labels = batch[' target_text_ids' ][:, 1 :], logits = outputs.logits,
7775 sequence_length = batch[' target_length' ] - 1 ) # Automatic masking
7876 return {" loss" : loss}
You can’t perform that action at this time.
0 commit comments