
Commit 25604f4

wandb support
1 parent bc83804 commit 25604f4

2 files changed, +86 −9 lines

cli_tools/pretrain_mpnet.py

Lines changed: 85 additions & 9 deletions
@@ -28,16 +28,15 @@
 from torch.utils.tensorboard import SummaryWriter
 from transformers import AutoTokenizer
 
+import wandb
 from annotated_mpnet.data import (
-    DataCollatorForMaskedPermutedLanguageModeling,
-    HFStreamingDataset,
-    MPNetDataset,
-    RandomSamplerWithSeed,
-)
+    DataCollatorForMaskedPermutedLanguageModeling, HFStreamingDataset,
+    MPNetDataset, RandomSamplerWithSeed)
 from annotated_mpnet.modeling import MPNetForPretraining
 from annotated_mpnet.scheduler import PolynomialDecayLRScheduler
 from annotated_mpnet.tracking import AverageMeter
-from annotated_mpnet.utils.utils import SUPPORTED_ACTIVATIONS, validate_tokenizer
+from annotated_mpnet.utils.utils import (SUPPORTED_ACTIVATIONS,
+                                         validate_tokenizer)
 
 
 def accuracy(output: torch.Tensor, target: torch.Tensor) -> int:

@@ -71,6 +70,22 @@ def write_to_tensorboard(writer: SummaryWriter, logging_dict: dict, step: int) -
         writer.add_scalar(stat_name, stat, step)
 
 
+def log_to_wandb(logging_dict: dict, step: int, split: str) -> None:
+    """
+    Log metrics to Weights & Biases
+
+    Args:
+        logging_dict: the dictionary containing the stats
+        step: the current step
+        split: the data split (train, valid, test)
+    """
+    if wandb.run is not None:
+        # Prefix metrics with split name for better organization in the dashboard
+        wandb_dict = {f"{split}/{k}": v for k, v in logging_dict.items()}
+        wandb_dict["step"] = step
+        wandb.log(wandb_dict)
+
+
 def check_and_activate_tf32():
     """
     Check if the GPU supports NVIDIA Ampere or later and enable FP32 in PyTorch if it does.
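For reference, the new `log_to_wandb` helper only namespaces whatever the training loop already collects and hands it to `wandb.log`. A minimal sketch of the equivalent call, with made-up metric names (the real keys are whatever the script puts in `logging_dict`) and an offline run so it executes without W&B credentials:

    import wandb

    # Hypothetical stats; the actual keys come from pretrain_mpnet.py's logging_dict.
    logging_dict = {"loss": 2.31, "accuracy": 0.47, "lr": 5e-4}
    step, split = 100, "train"

    wandb.init(project="annotated-mpnet", mode="offline")  # offline: no credentials needed

    # Same transformation log_to_wandb applies before logging:
    wandb_dict = {f"{split}/{k}": v for k, v in logging_dict.items()}
    wandb_dict["step"] = step
    wandb.log(wandb_dict)  # the dashboard then shows train/loss, train/accuracy, train/lr vs. step

    wandb.finish()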
@@ -145,9 +160,9 @@ def main(args) -> None:
         args.tokenizer_name, model_max_length=args.max_tokens
     )
     is_valid, details = validate_tokenizer(tokenizer)
-    assert (
-        is_valid and details["whole_word_mask"]
-    ), f"Invalid tokenizer: {args.tokenizer_name}. Debug w/ verbose output from validate_tokenizer()"
+    assert is_valid and details["whole_word_mask"], (
+        f"Invalid tokenizer: {args.tokenizer_name}. Debug w/ verbose output from validate_tokenizer()"
+    )
 
     # Check and adjust model vocab_size for better GPU performance
     original_vocab_size = tokenizer.vocab_size

@@ -181,6 +196,19 @@ def main(args) -> None:
     model = MPNetForPretraining(args, tokenizer)
     mplm = DataCollatorForMaskedPermutedLanguageModeling(tokenizer=tokenizer)
 
+    # Initialize wandb if enabled (after model creation)
+    if args.wandb:
+        wandb.init(
+            project=args.wandb_project,
+            name=args.wandb_name,
+            config=vars(args),
+            resume="allow",
+            id=args.wandb_id,
+        )
+        # Log model architecture as a graph
+        if args.wandb_watch:
+            wandb.watch(model, log_freq=100)
+
     # sync args for relative attention with model
     args.relative_attention_num_buckets = (
         model.sentence_encoder.relative_attention_num_buckets
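Two details in this hunk are worth noting. `wandb.watch(model, log_freq=100)` registers hooks that periodically log gradient statistics for the model's parameters, and `resume="allow"` paired with an explicit `id` lets a relaunched job append to an existing run instead of starting a fresh one. A rough sketch of that resume behavior outside the trainer, assuming W&B credentials are already configured and using a placeholder run id:

    import wandb

    # First launch: pin the run id so a later restart can pick up the same run.
    wandb.init(project="annotated-mpnet", id="mpnet-pretrain-001", resume="allow")
    wandb.log({"train/loss": 2.0, "step": 10})
    wandb.finish()

    # Simulated restart: the same id plus resume="allow" continues the earlier run,
    # which is what passing --wandb-id to pretrain_mpnet.py is for.
    wandb.init(project="annotated-mpnet", id="mpnet-pretrain-001", resume="allow")
    wandb.log({"train/loss": 1.5, "step": 20})
    wandb.finish()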
@@ -579,6 +607,10 @@ def main(args) -> None:
             else:
                 LOGGER.info(logging_dict)
 
+            # Log to wandb if enabled
+            if args.wandb:
+                log_to_wandb(logging_dict, steps, "train")
+
             # Reset accumulation counters here for the next set of accumulation steps
             accumulation_acc = 0
             accumulation_loss = 0

@@ -660,6 +692,10 @@ def main(args) -> None:
         LOGGER.info("Validation stats:")
         LOGGER.info(logging_dict)
 
+        # Log to wandb if enabled
+        if args.wandb:
+            log_to_wandb(logging_dict, steps, "valid")
+
         # Now, before looping back, we increment the epoch counter and we delete the train data
         # loader and garbage collect it
         epoch += 1

@@ -756,11 +792,19 @@ def main(args) -> None:
     LOGGER.info("Test stats:")
     LOGGER.info(logging_dict)
 
+    # Log to wandb if enabled
+    if args.wandb:
+        log_to_wandb(logging_dict, steps, "test")
+
     LOGGER.info(
         f"Training is finished! See output in {args.checkpoint_dir} and "
         f"tensorboard logs in {args.tensorboard_log_dir}"
     )
 
+    # Finish wandb run if active
+    if args.wandb and wandb.run is not None:
+        wandb.finish()
+
 
 def cli_main():
     """

@@ -1050,6 +1094,38 @@ def cli_main():
         default=False,
     )
 
+    # Weights & Biases arguments
+    parser.add_argument(
+        "--wandb",
+        help="Whether to use Weights & Biases for logging",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--wandb-project",
+        help="Weights & Biases project name",
+        default="annotated-mpnet",
+        type=str,
+    )
+    parser.add_argument(
+        "--wandb-name",
+        help="Weights & Biases run name",
+        default=None,
+        type=str,
+    )
+    parser.add_argument(
+        "--wandb-id",
+        help="Weights & Biases run ID for resuming a run",
+        default=None,
+        type=str,
+    )
+    parser.add_argument(
+        "--wandb-watch",
+        help="Whether to log model gradients in Weights & Biases",
+        action="store_true",
+        default=False,
+    )
+
     args = parser.parse_args()
 
     # Check for validity of arguments

setup.py

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ def include_dirs(self, dirs):
         "tensorboard",
         "torch>=2.6.0",
         "transformers",
+        "wandb",
     ],
     packages=find_packages(exclude=["cli_tools", "tests"]),
     ext_modules=extensions,
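With `wandb` added to `install_requires`, reinstalling the package pulls the dependency in automatically; the only remaining setup is credentials. A purely illustrative preflight check one might run before launching training with `--wandb` (note that `wandb.login()` without stored credentials may prompt interactively depending on the environment):

    import importlib.util

    # Confirm the new dependency actually resolved after reinstalling the package.
    if importlib.util.find_spec("wandb") is None:
        raise SystemExit("wandb is not installed; reinstall the package or `pip install wandb`.")

    import wandb

    # wandb.login() returns True when credentials are already available
    # (from a previous `wandb login` or the WANDB_API_KEY environment variable).
    if not wandb.login():
        print("No W&B credentials found; run `wandb login` or set WANDB_MODE=offline.")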
