@@ -611,12 +611,22 @@ def forward(self, hidden_states, attention_mask,
              value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
         elif self.attention_type == AttnType.self_attn and self.attention_head_type == 'multiquery':
             kv_input = hidden_states
-            if get_args().sequence_parallel:
-                # The linear layer doesn't gather the sequence-parallel.
-                kv_input = mpu.gather_from_sequence_parallel_region(kv_input, tensor_parallel_output_grad=False)
             # Attention heads [sq, b, h] --> [sq, b, (2 * hn)]
             mixed_kv_layer = self.key_value(kv_input)
 
+            # Reduce the KV gradients in the tensor-parallel direction.
+            # This is different from multi-head attention which reduces the KV input,
+            # because the sum over attn heads happens in the attn weight gradient instead of the KV layer:
+            #   A [b, n * sq, sk] = Q [b, n * sq, hn] x K^T [b, hn, sk]
+            #   G_K [b, sk, hn] = G_A [b, sk, n * sq] x Q [b, n * sq, hn]
+            #                   = sum_p (G_Ap [b, sk, np * sq] x Q_p [b, np * sq, hn])
+            if get_args().sequence_parallel:
+                # We switch to the tensor parallel regime here instead of at the KV input
+                # so that the KV layer is done in parallel instead of just duplicated.
+                mixed_kv_layer = mpu.gather_from_sequence_parallel_region(mixed_kv_layer, tensor_parallel_output_grad=True)
+            else:
+                mixed_kv_layer = mpu.copy_to_tensor_model_parallel_region(mixed_kv_layer)
+
             # [sq, b, (2 * hn)] --> [sq, b, np (expanded), 2 * hn]
             # new_tensor_shape = mixed_kv_layer.size()[:-1] + \
             #     (self.num_attention_heads_per_partition,
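
For intuition on the new comment about reducing the KV gradients in the tensor-parallel direction, here is a minimal single-process PyTorch sketch (not part of the patch; the shapes `b`, `n`, `sq`, `sk`, `hn` and the group count `p` are made up). It checks the identity from the comment: with a single shared key head, the full K gradient equals the sum of per-head-group contributions `G_Ap^T x Q_p`, which is exactly the sum an all-reduce over tensor-parallel ranks would perform.

```python
# Minimal single-process check, not Megatron code. Shapes are hypothetical.
import torch

b, n, sq, sk, hn = 2, 4, 5, 5, 8                 # batch, heads, query len, key len, head dim
q = torch.randn(b, n, sq, hn)                    # per-head queries
k = torch.randn(b, sk, hn, requires_grad=True)   # single shared key head (multi-query attention)

# A[b, n, sq, sk] = Q x K^T, summed into a scalar loss for a simple gradient check.
a = torch.einsum('bnqh,bkh->bnqk', q, k)
a.sum().backward()
full_grad = k.grad.clone()

# Recompute the K gradient as a sum over head groups, mimicking p tensor-parallel
# ranks that each own n/p heads: G_K = sum_p (G_Ap^T x Q_p).
p = 2
partial = torch.zeros_like(full_grad)
for rank in range(p):
    q_p = q[:, rank * (n // p):(rank + 1) * (n // p)]      # this "rank's" heads
    g_a_p = torch.ones(b, n // p, sq, sk)                   # dL/dA for those heads (loss is a.sum())
    partial += torch.einsum('bnqk,bnqh->bkh', g_a_p, q_p)   # G_Ap^T x Q_p

assert torch.allclose(full_grad, partial, atol=1e-5)
```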
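And a back-of-the-envelope sketch (plain Python, hypothetical sizes, not Megatron code) of why the gather is moved after the KV projection when `sequence_parallel` is on: each rank then projects only its `sq/p` slice of the sequence, and the all-gather moves the narrow `2*hn`-wide projection output instead of the `h`-wide hidden states.

```python
# Rough per-rank cost comparison for the two gather placements. All sizes are made up.
sq, b, h, hn, p = 4096, 8, 6144, 128, 8   # seq len, batch, hidden size, head dim, TP size

def per_rank_cost(gather_input_first: bool):
    tokens = sq * b if gather_input_first else sq * b // p  # tokens seen by the KV linear on each rank
    flops = 2 * tokens * h * (2 * hn)                        # matmul cost of the KV projection
    gathered_width = h if gather_input_first else 2 * hn     # per-token width that gets all-gathered
    comm_elems = sq * b * gathered_width * (p - 1) // p      # elements received per rank in the all-gather
    return flops, comm_elems

for mode, label in [(True, 'gather KV input (old)'), (False, 'gather KV output (new)')]:
    flops, comm = per_rank_cost(mode)
    print(f'{label:24s} per-rank GFLOPs={flops / 1e9:8.1f}  gathered elems={comm / 1e6:6.1f}M')
```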