
Commit 48c8046

Author: Raymond Li
Commit message: fix the merge
Parent: 203b071

2 files changed: +15, -10 lines

megatron/model/transformer.py (2 additions & 0 deletions)

@@ -577,6 +577,8 @@ def __init__(self, init_method,
         self.attention_head_type = args.attention_head_type
         self.sequence_parallel = args.sequence_parallel
 
+        self.use_flash_attn = args.use_flash_attn
+
         projection_size = args.kv_channels * args.num_attention_heads
 
         # Per attention head and per partition values.
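The hunk above only records the flag on the attention module; how it is consumed is outside this diff. Below is an illustrative sketch of the usual pattern, with names such as AttentionDispatch and flash_attention chosen here as hypothetical stand-ins rather than taken from the repository.

# Illustrative sketch only, not code from this commit: one common way a
# constructor-time flag like self.use_flash_attn ends up being used in the
# forward pass. The wrapper names here are hypothetical stand-ins.
import torch

class AttentionDispatch(torch.nn.Module):
    def __init__(self, core_attention, flash_attention, use_flash_attn):
        super().__init__()
        self.core_attention = core_attention    # baseline softmax attention path
        self.flash_attention = flash_attention  # fused-kernel wrapper (hypothetical)
        self.use_flash_attn = use_flash_attn    # mirrors args.use_flash_attn

    def forward(self, query, key, value):
        # Route to the fused kernel only when the flag was set at build time.
        if self.use_flash_attn:
            return self.flash_attention(query, key, value)
        return self.core_attention(query, key, value)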

megatron/training.py (13 additions & 10 deletions)

@@ -564,6 +564,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
 
     total_iterations = total_loss_dict[advanced_iters_key] + \
                        total_loss_dict[skipped_iters_key]
+    mem_stats = None
 
     # Tensorboard values.
     # Timer requires all the ranks to call.

@@ -665,6 +666,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         log_string += ' number of nan iterations: {:3d} |'.format(
             total_loss_dict[nan_iters_key])
         log_string += ' TFLOPs: {:.2f} |'.format(tflops)
+        if args.log_memory_to_tensorboard and mem_stats is not None:
+            log_string += ' mem-reserved (GB): {:.2f} |'.format(mem_stats["reserved_bytes.all.current"]*1e-9)
         total_loss_dict[advanced_iters_key] = 0
         total_loss_dict[skipped_iters_key] = 0
         total_loss_dict[nan_iters_key] = 0
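These two hunks initialize mem_stats to None and later index mem_stats["reserved_bytes.all.current"]; the assignment that actually fills it in is not part of this diff. Assuming it comes from PyTorch's allocator statistics, a minimal sketch of how that value is typically obtained and converted to GB:

# Sketch under the assumption that mem_stats is populated elsewhere in
# training_log from torch.cuda.memory_stats(), whose flat dict includes the
# key "reserved_bytes.all.current" (bytes currently reserved by the allocator).
import torch

mem_stats = None
if torch.cuda.is_available():
    mem_stats = torch.cuda.memory_stats()

if mem_stats is not None:
    reserved_gb = mem_stats["reserved_bytes.all.current"] * 1e-9
    print(' mem-reserved (GB): {:.2f} |'.format(reserved_gb))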
@@ -1023,14 +1026,14 @@ def build_train_valid_test_data_loaders(
                                     mpu.get_tensor_model_parallel_src_rank(),
                                     group=mpu.get_tensor_model_parallel_group())
     args.do_train = flags[0].item()
-    num_valid_ds = flags[1].item()
-    num_test_ds = flags[2].item()
-    assert num_test_ds >= 0
-    assert num_valid_ds >= 0
-    args.do_valid = num_valid_ds > 0
-    args.do_test = num_test_ds > 0
+    args.num_valid_ds = flags[1].item()
+    args.num_test_ds = flags[2].item()
+    assert args.num_test_ds >= 0
+    assert args.num_valid_ds >= 0
+    args.do_valid = args.num_valid_ds > 0
+    args.do_test = args.num_test_ds > 0
 
-    return train_dataloader, valid_dataloader, test_dataloader
+    return train_dataloader, valid_dataloaders, test_dataloaders
 
 
 def build_train_valid_test_data_iterators(

@@ -1039,7 +1042,7 @@ def build_train_valid_test_data_iterators(
     args = get_args()
 
     # Build loaders.
-    train_dataloader, valid_dataloader, test_dataloader = \
+    train_dataloader, valid_dataloaders, test_dataloaders = \
         build_train_valid_test_data_loaders(
             build_train_valid_test_datasets_provider)
 

@@ -1058,13 +1061,13 @@
                                else iter(cyclic_iter(valid_dataloaders))
                                for vdl in valid_dataloaders]
     else:
-        valid_data_iterators = [None] * num_valid_ds
+        valid_data_iterators = [None] * args.num_valid_ds
 
     if test_dataloaders is not None:
         test_data_iterators = [iter(tdl) if dl_type == 'single' \
                                else iter(cyclic_iter(test_dataloaders))
                                for tdl in test_dataloaders]
     else:
-        test_data_iterators = [None] * num_test_ds
+        test_data_iterators = [None] * args.num_test_ds
 
     return train_data_iterator, valid_data_iterators, test_data_iterators
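Storing the dataset counts on args matters for the else branches above: ranks whose dataloaders are None still need to build placeholder lists of the right length. A compact sketch of that pattern, with cyclic_iter and make_iterators written here as minimal stand-ins rather than copied from the repository:

# Minimal sketch of the per-dataset iterator construction used above. The
# cyclic_iter below is a stand-in for the helper referenced in the diff.
def cyclic_iter(iterable):
    # Repeat the dataloader forever (used when dl_type != 'single').
    while True:
        for x in iterable:
            yield x

def make_iterators(dataloaders, num_ds, dl_type='single'):
    if dataloaders is not None:
        return [iter(dl) if dl_type == 'single' else iter(cyclic_iter(dl))
                for dl in dataloaders]
    # Ranks without dataloaders can only size this list because the count
    # (args.num_valid_ds / args.num_test_ds) is now carried on args.
    return [None] * num_ds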
