
Commit 675b957

pszemraj and Peter Szemraj authored
Config fixes (#13)
This PR validates important config attributes during training/export:
- token_ids during training and export
- dropout settings during export

---------

Signed-off-by: peter szemraj <peterszemraj@gmail.com>
Signed-off-by: Peter Szemraj <peterszemraj+dev@gmail.com>
Co-authored-by: Peter Szemraj <peterszemraj+dev@gmail.com>
1 parent 25604f4 commit 675b957
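Note (not part of the commit): a minimal sketch of how an exported directory could be sanity-checked after running the converter below, assuming the tokenizer was saved alongside the model; the export path is a placeholder.

# Hypothetical sanity check (not from this commit): confirm the exported config's
# token IDs match the tokenizer saved next to it. "path/to/exported_mpnet" is a
# placeholder for the converter's output directory.
from transformers import AutoConfig, AutoTokenizer

export_dir = "path/to/exported_mpnet"  # placeholder
config = AutoConfig.from_pretrained(export_dir)
tokenizer = AutoTokenizer.from_pretrained(export_dir)

for name in ("pad_token_id", "bos_token_id", "eos_token_id"):
    cfg_id, tok_id = getattr(config, name), getattr(tokenizer, name)
    assert cfg_id == tok_id, f"{name} mismatch: config={cfg_id}, tokenizer={tok_id}"
print("config and tokenizer token IDs agree")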

File tree: 3 files changed (+42, -16 lines changed)

annotated_mpnet/transformer_modules/positional_embedding.py

Lines changed: 2 additions & 4 deletions
@@ -32,9 +32,7 @@ def PositionalEmbedding(
 
     # If we specified "learned" to be True, we want to create a learned positional embedding module
     if learned:
-        # If we specify a padding index, we need to update the total number of embeddings
-        if padding_idx is not None:
-            num_embeddings = num_embeddings + padding_idx + 1
+        num_embeddings = num_embeddings + 2  # Add 2 for CLS and SEP
 
         # Instantiate the learned positional embeddings
         m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
@@ -48,7 +46,7 @@ def PositionalEmbedding(
     # Branch to create sinusoidal embeddings if "learned" is False
     else:
         m = SinusoidalPositionalEmbedding(
-            embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1
+            embedding_dim, padding_idx, init_size=num_embeddings + 2  # Add 2 for CLS and SEP
         )
 
     return m
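A quick arithmetic illustration (not from the commit) of why the explicit "+ 2" matches the removed branch in the common case: MPNet-style tokenizers conventionally use padding index 1, so the old learned-embedding formula already produced the same table size; the new code simply makes the CLS/SEP offset unconditional.

# Illustration only: positional-embedding table size under the old and new code
# paths, assuming the conventional MPNet padding_idx of 1.
num_embeddings, padding_idx = 512, 1
old_size = num_embeddings + padding_idx + 1  # removed learned-embedding branch
new_size = num_embeddings + 2                # Add 2 for CLS and SEP
assert old_size == new_size == 514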

cli_tools/convert_pretrained_mpnet_to_hf_model.py

Lines changed: 26 additions & 5 deletions
@@ -50,7 +50,7 @@ def convert_mpnet_checkpoint_to_pytorch(
     # Load up the state dicts (one for the weights and one for the args) from the provided
     # serialization path
     with safe_globals([Namespace]):
-        state_dicts = torch.load(mpnet_checkpoint_path)
+        state_dicts = torch.load(mpnet_checkpoint_path, map_location="cpu")
 
     # Extract the model args so that we can properly set the config later on
     # Extract the weights so we can set them within the constructs of the model
@@ -80,9 +80,20 @@ def convert_mpnet_checkpoint_to_pytorch(
         max_position_embeddings=mpnet_args.max_positions + 2,
         relative_attention_num_buckets=mpnet_args.relative_attention_num_buckets,
         hidden_act=mpnet_args.activation_fn,
+        # Note: there are three dropouts in MPNetForPretraining, but only two in MPNetForMaskedLM
+        hidden_dropout_prob=mpnet_args.activation_dropout,
+        attention_probs_dropout_prob=mpnet_args.attention_dropout,
         layer_norm_eps=1e-5,
     )
 
+    # if the mpnet_args contain token_ids, ensure model config matches
+    if hasattr(mpnet_args, "pad_token_id"):
+        config.pad_token_id = mpnet_args.pad_token_id
+    if hasattr(mpnet_args, "bos_token_id"):
+        config.bos_token_id = mpnet_args.bos_token_id
+    if hasattr(mpnet_args, "eos_token_id"):
+        config.eos_token_id = mpnet_args.eos_token_id
+
     # Now load the model with randomized weights
     model = MPNetForMaskedLM(config)
 
@@ -210,17 +221,27 @@ def convert_mpnet_checkpoint_to_pytorch(
     pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
     LOGGER.info(f"Saving model to {pytorch_dump_folder_path}")
 
-    # Now that the config and weights are loaded into the model class, we can use HF's builtin
-    # save_pretrained function to dump the appropriate contents to the provided dir path
-    model.save_pretrained(pytorch_dump_folder_path)
-
     if save_tokenizer and hasattr(mpnet_args, "tokenizer_name"):
         LOGGER.info(f"Saving tokenizer to {pytorch_dump_folder_path}")
         tokenizer = AutoTokenizer.from_pretrained(
             mpnet_args.tokenizer_name, model_max_length=mpnet_args.max_positions
         )
+
+        # Synchronize token IDs between tokenizer and model config
+        model.config.bos_token_id = tokenizer.bos_token_id
+        model.config.eos_token_id = tokenizer.eos_token_id
+        model.config.pad_token_id = tokenizer.pad_token_id
+        LOGGER.info(
+            f"Updated config with tokenizer IDs: BOS={tokenizer.bos_token_id}, "
+            f"EOS={tokenizer.eos_token_id}, PAD={tokenizer.pad_token_id}"
+        )
+
         tokenizer.save_pretrained(pytorch_dump_folder_path)
 
+    # Now that the config and weights are loaded into the model class, we can use HF's builtin
+    # save_pretrained function to dump the appropriate contents to the provided dir path
+    model.save_pretrained(pytorch_dump_folder_path)
+
     LOGGER.info("Done!")
 
 
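A possible post-export check (a sketch, not code from this commit) that the dropout and token-ID fields written above actually land in the saved config; the export path and the example training args are placeholders.

# Sketch: compare the exported HF config against the training args it was built
# from. The path and the Namespace values below are placeholders for illustration.
from argparse import Namespace
from transformers import MPNetConfig

mpnet_args = Namespace(activation_dropout=0.1, attention_dropout=0.1, pad_token_id=1)
exported = MPNetConfig.from_pretrained("path/to/exported_mpnet")  # placeholder path

assert exported.hidden_dropout_prob == mpnet_args.activation_dropout
assert exported.attention_probs_dropout_prob == mpnet_args.attention_dropout
assert exported.pad_token_id == mpnet_args.pad_token_id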

cli_tools/pretrain_mpnet.py

Lines changed: 14 additions & 7 deletions
@@ -30,13 +30,15 @@
 
 import wandb
 from annotated_mpnet.data import (
-    DataCollatorForMaskedPermutedLanguageModeling, HFStreamingDataset,
-    MPNetDataset, RandomSamplerWithSeed)
+    DataCollatorForMaskedPermutedLanguageModeling,
+    HFStreamingDataset,
+    MPNetDataset,
+    RandomSamplerWithSeed,
+)
 from annotated_mpnet.modeling import MPNetForPretraining
 from annotated_mpnet.scheduler import PolynomialDecayLRScheduler
 from annotated_mpnet.tracking import AverageMeter
-from annotated_mpnet.utils.utils import (SUPPORTED_ACTIVATIONS,
-                                         validate_tokenizer)
+from annotated_mpnet.utils.utils import SUPPORTED_ACTIVATIONS, validate_tokenizer
 
 
 def accuracy(output: torch.Tensor, target: torch.Tensor) -> int:
@@ -160,9 +162,9 @@ def main(args) -> None:
         args.tokenizer_name, model_max_length=args.max_tokens
     )
     is_valid, details = validate_tokenizer(tokenizer)
-    assert is_valid and details["whole_word_mask"], (
-        f"Invalid tokenizer: {args.tokenizer_name}. Debug w/ verbose output from validate_tokenizer()"
-    )
+    assert (
+        is_valid and details["whole_word_mask"]
+    ), f"Invalid tokenizer: {args.tokenizer_name}. Debug w/ verbose output from validate_tokenizer()"
 
     # Check and adjust model vocab_size for better GPU performance
     original_vocab_size = tokenizer.vocab_size
@@ -182,6 +184,11 @@ def main(args) -> None:
     args.original_vocab_size = original_vocab_size
     args.padded_vocab_size = original_vocab_size
 
+    # Explicitly store token IDs in args for consistent usage
+    args.pad_token_id = tokenizer.pad_token_id
+    args.bos_token_id = tokenizer.bos_token_id
+    args.eos_token_id = tokenizer.eos_token_id
+
     # -----------------------------------
 
     # Instantiate the tensorboard writers
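Taken together with the converter change above, the new args fields create a simple hand-off: pretraining records the tokenizer's IDs on the args Namespace that gets checkpointed, and export copies them onto the HF config only when present, so older checkpoints without these fields still convert. A self-contained sketch of that hand-off (the token ID values shown are the usual MPNet defaults, for illustration only):

from argparse import Namespace

# pretrain_mpnet.py side: store the tokenizer's IDs on args (illustrative values)
args = Namespace()
args.pad_token_id, args.bos_token_id, args.eos_token_id = 1, 0, 2

# convert_pretrained_mpnet_to_hf_model.py side: copy them onto the config if present
config = Namespace(pad_token_id=None, bos_token_id=None, eos_token_id=None)
for name in ("pad_token_id", "bos_token_id", "eos_token_id"):
    if hasattr(args, name):
        setattr(config, name, getattr(args, name))

assert (config.pad_token_id, config.bos_token_id, config.eos_token_id) == (1, 0, 2)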

0 commit comments
