|
28 | 28 | from torch.utils.tensorboard import SummaryWriter |
29 | 29 | from transformers import AutoTokenizer |
30 | 30 |
|
| 31 | +import wandb |
31 | 32 | from annotated_mpnet.data import ( |
32 | | -    DataCollatorForMaskedPermutedLanguageModeling, |
33 | | -    HFStreamingDataset, |
34 | | -    MPNetDataset, |
35 | | -    RandomSamplerWithSeed, |
36 | | -) |
| 33 | +    DataCollatorForMaskedPermutedLanguageModeling, HFStreamingDataset, |
| 34 | +    MPNetDataset, RandomSamplerWithSeed) |
37 | 35 | from annotated_mpnet.modeling import MPNetForPretraining |
38 | 36 | from annotated_mpnet.scheduler import PolynomialDecayLRScheduler |
39 | 37 | from annotated_mpnet.tracking import AverageMeter |
40 | | -from annotated_mpnet.utils.utils import SUPPORTED_ACTIVATIONS, validate_tokenizer |
| 38 | +from annotated_mpnet.utils.utils import (SUPPORTED_ACTIVATIONS, |
| 39 | +    validate_tokenizer) |
41 | 40 |
|
42 | 41 |
|
43 | 42 | def accuracy(output: torch.Tensor, target: torch.Tensor) -> int: |
@@ -71,6 +70,22 @@ def write_to_tensorboard(writer: SummaryWriter, logging_dict: dict, step: int) - |
71 | 70 |         writer.add_scalar(stat_name, stat, step) |
72 | 71 |
|
73 | 72 |
|
| 73 | +def log_to_wandb(logging_dict: dict, step: int, split: str) -> None: |
| 74 | + """ |
| 75 | + Log metrics to Weights & Biases |
| 76 | +
|
| 77 | + Args: |
| 78 | + logging_dict: the dictionary containing the stats |
| 79 | + step: the current step |
| 80 | + split: the data split (train, valid, test) |
| 81 | + """ |
| 82 | + if wandb.run is not None: |
| 83 | + # Prefix metrics with split name for better organization in the dashboard |
| 84 | + wandb_dict = {f"{split}/{k}": v for k, v in logging_dict.items()} |
| 85 | + wandb_dict["step"] = step |
| 86 | + wandb.log(wandb_dict) |
| 87 | + |
| 88 | + |
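
A quick sketch of how the helper above is meant to be called (the stat names and values here are illustrative, not taken from the trainer):

    # Hypothetical metrics dict; the real keys come from the training loop
    logging_dict = {"loss": 2.31, "ppl": 10.07, "acc": 0.47}
    log_to_wandb(logging_dict, step=1200, split="train")
    # Appears in the dashboard as train/loss, train/ppl, train/acc

Logging the step as an ordinary field instead of passing step= to wandb.log is likely deliberate: wandb requires its internal step to be monotonically increasing, and the validation and test calls below reuse the training step counter.
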
74 | 89 | def check_and_activate_tf32(): |
75 | 90 | """ |
76 | 91 | Check if the GPU supports NVIDIA Ampere or later and enable FP32 in PyTorch if it does. |
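
The body of check_and_activate_tf32() falls outside this hunk. For orientation, a minimal sketch of what such a check usually looks like with the standard PyTorch backend switches (an assumption about the elided body, not the committed code; Ampere reports compute capability 8.x):

    import torch

    def check_and_activate_tf32() -> None:
        # TF32 is available on Ampere (SM 8.0) and newer GPUs
        if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
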
@@ -145,9 +160,9 @@ def main(args) -> None: |
145 | 160 |         args.tokenizer_name, model_max_length=args.max_tokens |
146 | 161 |     ) |
147 | 162 |     is_valid, details = validate_tokenizer(tokenizer) |
148 | | -    assert ( |
149 | | -        is_valid and details["whole_word_mask"] |
150 | | -    ), f"Invalid tokenizer: {args.tokenizer_name}. Debug w/ verbose output from validate_tokenizer()" |
| 163 | +    assert is_valid and details["whole_word_mask"], ( |
| 164 | +        f"Invalid tokenizer: {args.tokenizer_name}. Debug w/ verbose output from validate_tokenizer()" |
| 165 | +    ) |
151 | 166 |
|
152 | 167 |     # Check and adjust model vocab_size for better GPU performance |
153 | 168 |     original_vocab_size = tokenizer.vocab_size |
@@ -181,6 +196,19 @@ def main(args) -> None: |
181 | 196 |     model = MPNetForPretraining(args, tokenizer) |
182 | 197 |     mplm = DataCollatorForMaskedPermutedLanguageModeling(tokenizer=tokenizer) |
183 | 198 |
|
| 199 | +    # Initialize wandb if enabled (after model creation) |
| 200 | +    if args.wandb: |
| 201 | +        wandb.init( |
| 202 | +            project=args.wandb_project, |
| 203 | +            name=args.wandb_name, |
| 204 | +            config=vars(args), |
| 205 | +            resume="allow", |
| 206 | +            id=args.wandb_id, |
| 207 | +        ) |
| 208 | +        # Watch the model to log gradient histograms every `log_freq` steps |
| 209 | +        if args.wandb_watch: |
| 210 | +            wandb.watch(model, log_freq=100) |
| 211 | + |
184 | 212 |     # sync args for relative attention with model |
185 | 213 |     args.relative_attention_num_buckets = ( |
186 | 214 |         model.sentence_encoder.relative_attention_num_buckets |
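
Two notes on the init block above: resume="allow" only resumes when the same run ID is supplied, so restarting a crashed job with its previous --wandb-id continues that run's history, while omitting the ID starts a fresh run. And since wandb.watch defaults to log="gradients", the log_freq=100 call records gradient histograms every 100 steps rather than the model graph.
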
@@ -579,6 +607,10 @@ def main(args) -> None: |
579 | 607 |             else: |
580 | 608 |                 LOGGER.info(logging_dict) |
581 | 609 |
|
| 610 | +            # Log to wandb if enabled |
| 611 | +            if args.wandb: |
| 612 | +                log_to_wandb(logging_dict, steps, "train") |
| 613 | + |
582 | 614 |             # Reset accumulation counters here for the next set of accumulation steps |
583 | 615 |             accumulation_acc = 0 |
584 | 616 |             accumulation_loss = 0 |
@@ -660,6 +692,10 @@ def main(args) -> None: |
660 | 692 |         LOGGER.info("Validation stats:") |
661 | 693 |         LOGGER.info(logging_dict) |
662 | 694 |
|
| 695 | +        # Log to wandb if enabled |
| 696 | +        if args.wandb: |
| 697 | +            log_to_wandb(logging_dict, steps, "valid") |
| 698 | + |
663 | 699 |         # Now, before looping back, we increment the epoch counter and we delete the train data |
664 | 700 |         # loader and garbage collect it |
665 | 701 |         epoch += 1 |
@@ -756,11 +792,19 @@ def main(args) -> None: |
756 | 792 |     LOGGER.info("Test stats:") |
757 | 793 |     LOGGER.info(logging_dict) |
758 | 794 |
|
| 795 | +    # Log to wandb if enabled |
| 796 | +    if args.wandb: |
| 797 | +        log_to_wandb(logging_dict, steps, "test") |
| 798 | + |
759 | 799 |     LOGGER.info( |
760 | 800 |         f"Training is finished! See output in {args.checkpoint_dir} and " |
761 | 801 |         f"tensorboard logs in {args.tensorboard_log_dir}" |
762 | 802 |     ) |
763 | 803 |
|
| 804 | +    # Finish wandb run if active |
| 805 | +    if args.wandb and wandb.run is not None: |
| 806 | +        wandb.finish() |
| 807 | + |
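
wandb.finish() also runs automatically at interpreter exit, but calling it explicitly flushes any buffered metrics and marks the run finished in the dashboard as soon as training completes.
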
764 | 808 |
|
765 | 809 | def cli_main(): |
766 | 810 | """ |
@@ -1050,6 +1094,38 @@ def cli_main(): |
1050 | 1094 |         default=False, |
1051 | 1095 |     ) |
1052 | 1096 |
|
| 1097 | +    # Weights & Biases arguments |
| 1098 | +    parser.add_argument( |
| 1099 | +        "--wandb", |
| 1100 | +        help="Whether to use Weights & Biases for logging", |
| 1101 | +        action="store_true", |
| 1102 | +        default=False, |
| 1103 | +    ) |
| 1104 | +    parser.add_argument( |
| 1105 | +        "--wandb-project", |
| 1106 | +        help="Weights & Biases project name", |
| 1107 | +        default="annotated-mpnet", |
| 1108 | +        type=str, |
| 1109 | +    ) |
| 1110 | +    parser.add_argument( |
| 1111 | +        "--wandb-name", |
| 1112 | +        help="Weights & Biases run name", |
| 1113 | +        default=None, |
| 1114 | +        type=str, |
| 1115 | +    ) |
| 1116 | +    parser.add_argument( |
| 1117 | +        "--wandb-id", |
| 1118 | +        help="Weights & Biases run ID for resuming a run", |
| 1119 | +        default=None, |
| 1120 | +        type=str, |
| 1121 | +    ) |
| 1122 | +    parser.add_argument( |
| 1123 | +        "--wandb-watch", |
| 1124 | +        help="Whether to log model gradients in Weights & Biases", |
| 1125 | +        action="store_true", |
| 1126 | +        default=False, |
| 1127 | +    ) |
| 1128 | + |
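
A small parse-level sanity check for the new flags (a sketch that assumes it runs after the definitions above; the values are illustrative):

    # argparse maps "--wandb-project" to args.wandb_project, and so on
    args = parser.parse_args(
        ["--wandb", "--wandb-project", "mpnet-base", "--wandb-name", "run-001"]
    )
    assert args.wandb is True
    assert args.wandb_project == "mpnet-base" and args.wandb_id is None
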
1053 | 1129 |     args = parser.parse_args() |
1054 | 1130 |
|
1055 | 1131 |     # Check for validity of arguments |
|