
Commit 4a0fcfa

Authored by Peter Szemraj (pszemraj)
Save load fix (#8)
Fixes serialization issues (save/load), both in terms of safe globals and when working with a compiled model.

Signed-off-by: Peter Szemraj <peterszemraj+dev@gmail.com>
Co-authored-by: Peter Szemraj <peterszemraj+dev@gmail.com>
1 parent acab08f · commit 4a0fcfa

File tree

3 files changed: +39 −15 lines changed
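For context on the fix (not part of the diff): PyTorch 2.6 changed the default of torch.load to weights_only=True, which refuses to unpickle arbitrary classes such as argparse.Namespace unless they are explicitly allowlisted. A minimal sketch of the failure mode and the allowlist pattern this commit adopts, assuming a toy checkpoint written to ckpt.pt:

```python
from argparse import Namespace

import torch
from torch.serialization import safe_globals

# A checkpoint that pickles a Namespace, like the pre-fix training script did
torch.save({"args": Namespace(lr=1e-4), "model_states": {}}, "ckpt.pt")

# Under PyTorch >= 2.6, the default weights_only=True load rejects Namespace
try:
    ckpt = torch.load("ckpt.pt")
except Exception as err:
    print(f"plain load failed: {err}")

# Allowlisting the class, as this commit does, makes the load succeed
with safe_globals([Namespace]):
    ckpt = torch.load("ckpt.pt")
print(ckpt["args"].lr)
```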

cli_tools/convert_pretrained_mpnet_to_hf_model.py

Lines changed: 17 additions & 5 deletions
```diff
@@ -6,7 +6,10 @@
 doing here.
 """

+import argparse
 import logging
+import pathlib
+from argparse import Namespace

 from rich.logging import RichHandler

@@ -16,10 +19,9 @@
 )
 LOGGER = logging.getLogger(__name__)

-import argparse
-import pathlib

 import torch
+from torch.serialization import safe_globals
 from transformers import MPNetConfig, MPNetForMaskedLM
 from transformers.models.mpnet import MPNetLayer
 from transformers.utils import logging as hf_logging
```
```diff
@@ -45,12 +47,21 @@ def convert_mpnet_checkpoint_to_pytorch(

     # Load up the state dicts (one for the weights and one for the args) from the provided
     # serialization path
-    state_dicts = torch.load(mpnet_checkpoint_path)
+    with safe_globals([Namespace]):
+        state_dicts = torch.load(mpnet_checkpoint_path)

     # Extract the model args so that we can properly set the config later on
     # Extract the weights so we can set them within the constructs of the model
     mpnet_args = state_dicts["args"]
+    if isinstance(mpnet_args, dict):
+        mpnet_args = Namespace(**mpnet_args)
+
     mpnet_weight = state_dicts["model_states"]
+    # Fix for torch.compile() _orig_mod prefix
+    mpnet_weight = {k.replace("_orig_mod.", ""): v for k, v in mpnet_weight.items()}
+
+    print("Keys after removing _orig_mod prefix (if present):")
+    print(list(mpnet_weight.keys())[:5])  # Print first few keys to verify

     # Now we use the args (and one component of the weight to get the vocab size) to set the
     # MPNetConfig object, which will properly instantiate the MPNetForMaskedLM model to the specs
```
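For context: torch.compile(model) wraps the module in an OptimizedModule, so every key in its state_dict gains an `_orig_mod.` prefix; stripping it lets the weights load into an uncompiled MPNetForMaskedLM. A minimal sketch with a toy module (names assumed for illustration):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)

# Keys from the compiled wrapper carry the prefix:
print(list(compiled.state_dict().keys()))  # ['_orig_mod.weight', '_orig_mod.bias']

# The same stripping idiom the commit uses restores plain keys:
state = {k.replace("_orig_mod.", ""): v for k, v in compiled.state_dict().items()}
model.load_state_dict(state)  # loads cleanly into the uncompiled module
```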
```diff
@@ -205,8 +216,9 @@ def cli_main():
     """
     Wrapper function so we can define a CLI entrypoint when setting up this package
     """
-    parser = argparse.ArgumentParser()
-    # Required parameters
+    parser = argparse.ArgumentParser(
+        description="Convert MPNet .pt checkpoint to Huggingface model"
+    )
     parser.add_argument(
         "--mpnet-checkpoint-path",
         default=None,
```

cli_tools/pretrain_mpnet.py

Lines changed: 21 additions & 9 deletions
```diff
@@ -2,8 +2,13 @@
 Pretraining script for MPNet
 """

+import argparse
+import gc
 import logging
+import math
+import os
 import sys
+from argparse import Namespace

 from rich.logging import RichHandler

@@ -13,15 +18,12 @@
 )
 LOGGER = logging.getLogger(__name__)

-import argparse
-import gc
-import math
-import os

 import torch
 import torch.nn.functional as F
 from datasets import load_dataset
 from rich.progress import track
+from torch.serialization import safe_globals
 from torch.utils.tensorboard import SummaryWriter
 from transformers import AutoTokenizer

```

```diff
@@ -397,7 +399,7 @@ def main(args) -> None:
             and steps > 0
         ):
             torch.save(
-                {"args": args, "model_states": model.state_dict()},
+                {"args": vars(args), "model_states": model.state_dict()},
                 os.path.join(args.checkpoint_dir, f"checkpoint{steps + 1}.pt"),
             )

```
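For context: vars(args) converts the Namespace to a plain dict before saving, so the checkpoint contains only primitive containers and loads even under weights_only=True with no allowlist; the Namespace(**...) coercion added on the load side restores attribute access. A minimal round-trip sketch (toy args, assumed):

```python
from argparse import Namespace

import torch

args = Namespace(lr=1e-4, checkpoint_dir="ckpts")

# Saving vars(args) writes a plain dict instead of a pickled Namespace
torch.save({"args": vars(args), "model_states": {}}, "checkpoint.pt")

# A dict of primitives is accepted by the strictest load mode
ckpt = torch.load("checkpoint.pt", weights_only=True)
restored = Namespace(**ckpt["args"])  # same coercion the commit applies on load
print(restored.lr, restored.checkpoint_dir)
```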

```diff
@@ -595,7 +597,7 @@ def main(args) -> None:

     # Now let's go ahead and save this in the checkpoints directory
     torch.save(
-        {"args": args, "model_states": model.state_dict()},
+        {"args": vars(args), "model_states": model.state_dict()},
         os.path.join(args.checkpoint_dir, "best_checkpoint.pt"),
     )

```

```diff
@@ -630,13 +632,23 @@ def main(args) -> None:
     # use the test dataloader we built above to get a final test metric using the best checkpoint

     # Begin by loading the model states and args from the best checkpoint
-    dicts = torch.load(os.path.join(args.checkpoint_dir, "best_checkpoint.pt"))
+    with safe_globals([Namespace]):
+        dicts = torch.load(os.path.join(args.checkpoint_dir, "best_checkpoint.pt"))
+
+    # Handle args that might be dict or Namespace
+    loaded_args = dicts["args"]
+    if isinstance(loaded_args, dict):
+        loaded_args = Namespace(**loaded_args)
+
+    # Handle potential _orig_mod prefix in state dict from compiled models
+    model_states = dicts["model_states"]
+    model_states = {k.replace("_orig_mod.", ""): v for k, v in model_states.items()}

     # Load an empty shell of the model architecture using those args
-    test_model = MPNetForPretraining(dicts["args"], tokenizer)
+    test_model = MPNetForPretraining(loaded_args, tokenizer)

     # Now apply the model states to this newly instantiated model
-    test_model.load_state_dict(dicts["model_states"])
+    test_model.load_state_dict(model_states)

     # Finally make sure the model is in eval mode and is sent to the proper device
     test_model.to(device)
```
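A related pattern, not used in this commit but worth noting as a sketch (it assumes the training loop holds the compiled wrapper): compiled modules keep the original module at `_orig_mod`, so saving from the unwrapped module avoids writing the prefix into checkpoints in the first place.

```python
import torch
import torch.nn as nn

model = torch.compile(nn.Linear(4, 4))

# getattr falls back to the module itself if it was never compiled,
# so the same line works for both compiled and plain models
to_save = getattr(model, "_orig_mod", model)
torch.save({"model_states": to_save.state_dict()}, "clean_checkpoint.pt")
```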

setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -57,7 +57,7 @@ def include_dirs(self, dirs):

 setup(
     name="annotated_mpnet",
-    version="0.1.1",
+    version="0.1.2",
     description="Raw Torch, heavily annotated, pretrainable MPNet",
     url="https://github.com/pszemraj/annotated-mpnet",
     long_description=readme,
```
