Skip to content

Commit 719eaa6

Browse files
authored
Update README example
1 parent 5b2d8ba commit 719eaa6

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

README.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ class ConditionalGPT2Model(nn.Module):
         # Use hyperparameter dict for model configuration
         self.embedder = tx.modules.WordEmbedder(vocab_size, hparams=emb_hparams)
         self.encoder = tx.modules.TransformerEncoder(hparams=enc_hparams)
-        # GPT-2 module with pre-trained weights
-        self.decoder = tx.modules.GPT2Decoder("gpt2-small")
+        self.decoder = tx.modules.GPT2Decoder("gpt2-small")  # With pre-trained weights

     def _get_decoder_output(self, batch, train=True):
         """Perform model inference, i.e., decoding."""
@@ -71,8 +70,7 @@ class ConditionalGPT2Model(nn.Module):
     def forward(self, batch):
         """Compute training loss."""
         outputs = self._get_decoder_output(batch)
-        # Loss for maximum likelihood learning
-        loss = tx.losses.sequence_sparse_softmax_cross_entropy(
+        loss = tx.losses.sequence_sparse_softmax_cross_entropy(  # Sequence loss
             labels=batch['target_text_ids'][:, 1:], logits=outputs.logits,
             sequence_length=batch['target_length'] - 1)  # Automatic masking
         return {"loss": loss}

0 commit comments

Comments
 (0)