|
2 | 2 | Pretraining script for MPNet |
3 | 3 | """ |
4 | 4 |
|
| 5 | +import argparse |
| 6 | +import gc |
5 | 7 | import logging |
| 8 | +import math |
| 9 | +import os |
6 | 10 | import sys |
| 11 | +from argparse import Namespace |
7 | 12 |
|
8 | 13 | from rich.logging import RichHandler |
9 | 14 |
|
|
13 | 18 | ) |
14 | 19 | LOGGER = logging.getLogger(__name__) |
15 | 20 |
|
16 | | -import argparse |
17 | | -import gc |
18 | | -import math |
19 | | -import os |
20 | 21 |
|
21 | 22 | import torch |
22 | 23 | import torch.nn.functional as F |
23 | 24 | from datasets import load_dataset |
24 | 25 | from rich.progress import track |
| 26 | +from torch.serialization import safe_globals |
25 | 27 | from torch.utils.tensorboard import SummaryWriter |
26 | 28 | from transformers import AutoTokenizer |
27 | 29 |
|
@@ -397,7 +399,7 @@ def main(args) -> None: |
397 | 399 | and steps > 0 |
398 | 400 | ): |
399 | 401 | torch.save( |
400 | | - {"args": args, "model_states": model.state_dict()}, |
| 402 | + {"args": vars(args), "model_states": model.state_dict()}, |
401 | 403 | os.path.join(args.checkpoint_dir, f"checkpoint{steps + 1}.pt"), |
402 | 404 | ) |
403 | 405 |
|
@@ -595,7 +597,7 @@ def main(args) -> None: |
595 | 597 |
|
596 | 598 | # Now let's go ahead and save this in the checkpoints directory |
597 | 599 | torch.save( |
598 | | - {"args": args, "model_states": model.state_dict()}, |
| 600 | + {"args": vars(args), "model_states": model.state_dict()}, |
599 | 601 | os.path.join(args.checkpoint_dir, "best_checkpoint.pt"), |
600 | 602 | ) |
601 | 603 |
|
@@ -630,13 +632,23 @@ def main(args) -> None: |
630 | 632 | # use the test dataloader we built above to get a final test metric using the best checkpoint |
631 | 633 |
|
632 | 634 | # Begin by loading the model states and args from the best checkpoint |
633 | | - dicts = torch.load(os.path.join(args.checkpoint_dir, "best_checkpoint.pt")) |
| 635 | + with safe_globals([Namespace]): |
| 636 | + dicts = torch.load(os.path.join(args.checkpoint_dir, "best_checkpoint.pt")) |
| 637 | + |
| 638 | + # args is a plain dict in checkpoints saved via vars(args) and a Namespace in older ones; normalize to a Namespace |
| 639 | + loaded_args = dicts["args"] |
| 640 | + if isinstance(loaded_args, dict): |
| 641 | + loaded_args = Namespace(**loaded_args) |
| 642 | + |
| 643 | + # Strip the "_orig_mod." prefix that compiled (torch.compile-wrapped) models add to state-dict keys |
| 644 | + model_states = dicts["model_states"] |
| 645 | + model_states = {k.replace("_orig_mod.", ""): v for k, v in model_states.items()} |
634 | 646 |
|
635 | 647 | # Load an empty shell of the model architecture using those args |
636 | | - test_model = MPNetForPretraining(dicts["args"], tokenizer) |
| 648 | + test_model = MPNetForPretraining(loaded_args, tokenizer) |
637 | 649 |
|
638 | 650 | # Now apply the model states to this newly instantiated model |
639 | | - test_model.load_state_dict(dicts["model_states"]) |
| 651 | + test_model.load_state_dict(model_states) |
640 | 652 |
|
641 | 653 | # Finally make sure the model is in eval mode and is sent to the proper device |
642 | 654 | test_model.to(device) |
|
0 commit comments