@@ -564,6 +564,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
564564
565565 total_iterations = total_loss_dict [advanced_iters_key ] + \
566566 total_loss_dict [skipped_iters_key ]
567+ mem_stats = None
567568
568569 # Tensorboard values.
569570 # Timer requires all the ranks to call.
@@ -665,6 +666,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
665666 log_string += ' number of nan iterations: {:3d} |' .format (
666667 total_loss_dict [nan_iters_key ])
667668 log_string += ' TFLOPs: {:.2f} |' .format (tflops )
669+ if args .log_memory_to_tensorboard and mem_stats is not None :
670+ log_string += ' mem-reserved (GB): {:.2f} |' .format (mem_stats ["reserved_bytes.all.current" ]* 1e-9 )
668671 total_loss_dict [advanced_iters_key ] = 0
669672 total_loss_dict [skipped_iters_key ] = 0
670673 total_loss_dict [nan_iters_key ] = 0
@@ -1023,14 +1026,14 @@ def build_train_valid_test_data_loaders(
10231026 mpu .get_tensor_model_parallel_src_rank (),
10241027 group = mpu .get_tensor_model_parallel_group ())
10251028 args .do_train = flags [0 ].item ()
1026- num_valid_ds = flags [1 ].item ()
1027- num_test_ds = flags [2 ].item ()
1028- assert num_test_ds >= 0
1029- assert num_valid_ds >= 0
1030- args .do_valid = num_valid_ds > 0
1031- args .do_test = num_test_ds > 0
1029+ args . num_valid_ds = flags [1 ].item ()
1030+ args . num_test_ds = flags [2 ].item ()
1031+ assert args . num_test_ds >= 0
1032+ assert args . num_valid_ds >= 0
1033+ args .do_valid = args . num_valid_ds > 0
1034+ args .do_test = args . num_test_ds > 0
10321035
1033- return train_dataloader , valid_dataloader , test_dataloader
1036+ return train_dataloader , valid_dataloaders , test_dataloaders
10341037
10351038
10361039def build_train_valid_test_data_iterators (
@@ -1039,7 +1042,7 @@ def build_train_valid_test_data_iterators(
10391042 args = get_args ()
10401043
10411044 # Build loaders.
1042- train_dataloader , valid_dataloader , test_dataloader = \
1045+ train_dataloader , valid_dataloaders , test_dataloaders = \
10431046 build_train_valid_test_data_loaders (
10441047 build_train_valid_test_datasets_provider )
10451048
@@ -1058,13 +1061,13 @@ def build_train_valid_test_data_iterators(
10581061 else iter (cyclic_iter (valid_dataloaders ))
10591062 for vdl in valid_dataloaders ]
10601063 else :
1061- valid_data_iterators = [None ] * num_valid_ds
1064+ valid_data_iterators = [None ] * args . num_valid_ds
10621065
10631066 if test_dataloaders is not None :
10641067 test_data_iterators = [iter (tdl ) if dl_type == 'single' \
10651068 else iter (cyclic_iter (test_dataloaders ))
10661069 for tdl in test_dataloaders ]
10671070 else :
1068- test_data_iterators = [None ] * num_test_ds
1071+ test_data_iterators = [None ] * args . num_test_ds
10691072
10701073 return train_data_iterator , valid_data_iterators , test_data_iterators
0 commit comments