Commit 659295a

Kv grad allreduce v2 (#39)
Co-authored-by: thomasw21 <24695242+thomasw21@users.noreply.github.com>
1 parent e969456 commit 659295a

3 files changed: 25 additions and 6 deletions


megatron/model/transformer.py

Lines changed: 13 additions & 3 deletions
@@ -611,12 +611,22 @@ def forward(self, hidden_states, attention_mask,
              value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
         elif self.attention_type == AttnType.self_attn and self.attention_head_type == 'multiquery':
             kv_input=hidden_states
-            if get_args().sequence_parallel:
-                # The linear layer doesn't gather the sequence-parallel.
-                kv_input = mpu.gather_from_sequence_parallel_region(kv_input, tensor_parallel_output_grad=False)
             # Attention heads [sq, b, h] --> [sq, b, (2 * hn)]
             mixed_kv_layer = self.key_value(kv_input)

+            # Reduce the KV gradients in the tensor-parallel direction.
+            # This is different from multi-head attention which reduces the KV input,
+            # because the sum over attn heads happens in the attn weight gradient instead of the KV layer:
+            # A [b, n * sq, sk] = Q [b, n * sq, hn] x K^T [b, hn, sk]
+            # G_K [b, sk, hn] = G_A [b, sk, n * sq] x Q [b, n * sq, hn]
+            #                 = sum_p (G_Ap [b, sk, np * sq] x Q_p [b, np * sq, hn])
+            if get_args().sequence_parallel:
+                # We switch to the tensor parallel regime here instead of at the KV input
+                # so that the KV layer is done in parallel instead of just duplicated.
+                mixed_kv_layer = mpu.gather_from_sequence_parallel_region(mixed_kv_layer, tensor_parallel_output_grad=True)
+            else:
+                mixed_kv_layer = mpu.copy_to_tensor_model_parallel_region(mixed_kv_layer)
+
             # [sq, b, (2 * hn)] --> [sq, b, np (expanded), 2 * hn]
             # new_tensor_shape = mixed_kv_layer.size()[:-1] + \
             #     (self.num_attention_heads_per_partition,
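The shape comments in the added block carry the core argument: in multi-query attention the sum over query heads lands in the K/V gradient rather than in the K/V forward pass, so each tensor-parallel rank only computes a partial G_K and the partials must be summed across ranks. Below is a minimal numerical check of that identity in plain PyTorch (not the Megatron mpu API); the partition count p and all tensor names are illustrative.

```python
import torch

# Dimensions follow the comment's notation; the values are arbitrary.
b, p, np_, sq, sk, hn = 2, 4, 2, 5, 7, 16   # p tensor-parallel ranks, np_ heads per rank
n = p * np_                                  # total number of query heads

Q = torch.randn(b, n * sq, hn)               # queries, heads folded into the sequence axis
G_A = torch.randn(b, sk, n * sq)             # upstream grad of the scores A = Q x K^T

# Full key gradient: G_K [b, sk, hn] = G_A [b, sk, n*sq] x Q [b, n*sq, hn]
G_K_full = G_A @ Q

# Each "rank" only holds its own np_ heads of Q and G_A; its local G_K is a partial sum.
G_K_summed = sum(
    G_A[:, :, i * np_ * sq:(i + 1) * np_ * sq] @ Q[:, i * np_ * sq:(i + 1) * np_ * sq, :]
    for i in range(p)
)

# The per-partition gradients add up to the full one, so summing across ranks recovers G_K.
torch.testing.assert_close(G_K_full, G_K_summed)
```

This is the sum the added comment refers to with "Reduce the KV gradients in the tensor-parallel direction": the backward of the sequence-parallel gather (with tensor_parallel_output_grad=True) or of copy_to_tensor_model_parallel_region is what performs it, which is why the commit moves the parallel-region switch from kv_input to mixed_kv_layer.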

megatron/optimizer/distrib_optimizer.py

Lines changed: 5 additions & 1 deletion
@@ -542,7 +542,11 @@ def reduce_model_grads(self, args, timers):
         timers('backward-embedding-all-reduce').stop()

         # All-reduce key-value grads if needed.
-        if args.attention_head_type == "multiquery":
+        if (
+            args.attention_head_type == "multiquery"
+            and mpu.get_tensor_model_parallel_world_size() > 1
+            and args.sequence_parallel
+        ):
             timers('backward-key-value-all-reduce').start()
             self.allreduce_key_value_grads(args)
             timers('backward-key-value-all-reduce').stop()
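The strengthened guard only fires when the extra all-reduce can change anything: with a single tensor-parallel rank there is nothing to reduce, and without sequence parallelism the copy_to_tensor_model_parallel_region branch in the forward pass should already all-reduce the output gradient in its backward, leaving identical key_value grads on every rank. A stand-alone sketch of the same guard using plain torch.distributed rather than Megatron's mpu/optimizer plumbing; allreduce_kv_grads_if_needed, tp_group, and key_value are hypothetical stand-ins.

```python
import torch
import torch.distributed as dist

def allreduce_kv_grads_if_needed(key_value: torch.nn.Linear,
                                 tp_group: dist.ProcessGroup,
                                 attention_head_type: str,
                                 sequence_parallel: bool) -> None:
    # Same guard as the diff: skip when there is only one tensor-parallel rank
    # or when the forward pass already reduced the gradient for us.
    if (
        attention_head_type == "multiquery"
        and dist.get_world_size(group=tp_group) > 1
        and sequence_parallel
    ):
        for param in key_value.parameters():
            if param.grad is not None:
                # Sum the per-shard grads so every replica ends up with the full gradient.
                dist.all_reduce(param.grad, group=tp_group)
```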

megatron/optimizer/optimizer.py

Lines changed: 7 additions & 2 deletions
@@ -268,7 +268,8 @@ def allreduce_embedding_grads(self, args):

     def allreduce_key_value_grads(self, args):
         """
-        Reduce the gradients for the key_value weights and biases for multi-query attention.
+        Reduce the gradients for the key_value weights and biases for multi-query attention
+        with sequence parallelism.
         Coalesce the bias grads to avoid too many small reductions,
         but not the weight grads since it could cause memory issues.
         """
@@ -334,7 +335,11 @@ def reduce_model_grads(self, args, timers):
         timers('backward-embedding-all-reduce').stop()

         # All-reduce key-value grads if needed.
-        if args.attention_head_type == "multiquery":
+        if (
+            args.attention_head_type == "multiquery"
+            and mpu.get_tensor_model_parallel_world_size() > 1
+            and args.sequence_parallel
+        ):
             timers('backward-key-value-all-reduce').start()
             self.allreduce_key_value_grads(args)
             timers('backward-key-value-all-reduce').stop()
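The updated docstring keeps the earlier trade-off: the many small bias grads are coalesced into one buffer so they cost a single all-reduce, while the weight grads are reduced one by one because a flattened copy of them could cause memory issues. A hedged sketch of that coalescing pattern using torch's flatten/unflatten helpers; the function signature and tp_group are illustrative, not the actual Megatron implementation.

```python
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_key_value_grads_sketch(bias_grads, weight_grads, tp_group):
    # Fuse the small bias grads into one contiguous buffer and reduce them in one call.
    if bias_grads:
        coalesced = _flatten_dense_tensors(bias_grads)
        dist.all_reduce(coalesced, group=tp_group)
        for grad, synced in zip(bias_grads, _unflatten_dense_tensors(coalesced, bias_grads)):
            grad.copy_(synced)
    # Reduce the weight grads individually: flattening them too would allocate a second
    # full-size copy of every weight gradient.
    for grad in weight_grads:
        dist.all_reduce(grad, group=tp_group)
```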
