@@ -22,6 +22,7 @@
 
 from megatron import get_timers, get_args, get_global_memory_buffer
 from megatron import mpu
+from megatron.utils import print_rank_0
 from .module import MegatronModule
 from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType, PositionEmbeddingType
 from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
@@ -38,6 +39,16 @@
 torch._C._jit_override_can_fuse_on_cpu(True)
 torch._C._jit_override_can_fuse_on_gpu(True)
 
+try:
+    from einops import rearrange
+except ImportError:
+    rearrange = None
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+except ImportError:
+    flash_attn_unpadded_func = None
+
 
 """ We use the following notation throughout this file:
     h: hidden size
@@ -459,6 +470,48 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask, alibi):
         return context_layer
 
 
+class FlashSelfAttention(torch.nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                 device=None, dtype=None):
+        super().__init__()
+        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
+                                                      'e.g., with pip install flash-attn')
+        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+
+    def forward(self, q, k, v):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
+        """
+        assert q.dtype in [torch.float16, torch.bfloat16]
+        assert q.is_cuda
+        batch_size, seqlen = q.shape[0], q.shape[1]
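+        # flash_attn_unpadded_func works on packed tokens of shape
+        # (total_tokens, num_heads, head_dim), so fold the batch dim into the sequence dim.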
+        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
+        max_s = seqlen
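+        # All sequences in the batch have the same length here, so the cumulative
+        # sequence lengths are simply multiples of seqlen: [0, seqlen, ..., batch_size * seqlen].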
+        cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+                                  device=q.device)
+        output = flash_attn_unpadded_func(
+            q, k, v, cu_seqlens, cu_seqlens, max_s, max_s,
+            self.dropout_p if self.training else 0.0,
+            softmax_scale=self.softmax_scale, causal=self.causal
+        )
+        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+        return output
+
+
 class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.
 
@@ -477,6 +530,9 @@ def __init__(self, init_method,
         self.attn_mask_type = attn_mask_type
         self.params_dtype = args.params_dtype
         self.attention_head_type = args.attention_head_type
+        self.sequence_parallel = args.sequence_parallel
+
+        self.use_flash_attn = args.use_flash_attn
 
         projection_size = args.kv_channels * args.num_attention_heads
 
@@ -533,6 +589,26 @@ def __init__(self, init_method,
         else:
             self.core_attention = MultiQueryCoreAttention(self.layer_number, self.attn_mask_type)
         self.checkpoint_core_attention = args.recompute_granularity == 'selective'
+
+        if self.use_flash_attn:
+            if flash_attn_unpadded_func is None:
+                raise ImportError('FlashAttention is not installed, please install with '
+                                  'pip install flash-attn')
+            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
+                                                          'self-attention for now')
+            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
+                                                                'supports causal mask for now')
+            assert args.position_embedding_type != PositionEmbeddingType.alibi, \
+                'FlashAttention does not support alibi positional embeddings yet'
+            if rearrange is None:
+                raise ImportError('einops is not installed, please install with pip install einops')
+
+            if self.checkpoint_core_attention:
+                print_rank_0("Warning: using selective recomputation with flash-attn has no effect; "
+                             "recomputation is already handled inside the flash-attn library.")
+            self.core_attention_flash = FlashSelfAttention(
+                causal=True, attention_dropout=args.attention_dropout
+            )
 
         # Output.
         self.dense = mpu.RowParallelLinear(
@@ -699,13 +775,30 @@ def forward(self, hidden_states, attention_mask,
         # ==================================
         # core attention computation
         # ==================================
+        if self.use_flash_attn:
+            if self.attention_head_type == "multiquery":
+                sq, b, np, hn = query_layer.size()
+                # Expand kv to be compatible with flash-attn implementation
+                # [sq, b, 1, hn] -> [sq, b, np, hn]
+                key_layer = key_layer.expand((sq, b, np, hn))
+                value_layer = value_layer.expand((sq, b, np, hn))
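+                # Note: expand() only returns a broadcast view; the .contiguous() call in the
+                # rearrange below materializes the repeated key/value heads for the kernel.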
+            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
+                       for x in (query_layer, key_layer, value_layer)]
+            if self.sequence_parallel:
+                context_layer = self.core_attention_flash(q, k, v)
+            else:
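+                # Run attention dropout under the tensor-model-parallel RNG tracker so the
+                # random state matches the one used by the regular core-attention path.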
+                with mpu.get_cuda_rng_tracker().fork():
+                    context_layer = self.core_attention_flash(q, k, v)
+            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
 
-        if self.checkpoint_core_attention:
-            context_layer = self._checkpointed_attention_forward(
-                query_layer, key_layer, value_layer, attention_mask, alibi)
         else:
-            context_layer = self.core_attention(
-                query_layer, key_layer, value_layer, attention_mask, alibi)
+            if self.checkpoint_core_attention:
+                context_layer = self._checkpointed_attention_forward(
+                    query_layer, key_layer, value_layer, attention_mask, alibi)
+            else:
+                context_layer = self.core_attention(
+                    query_layer, key_layer, value_layer, attention_mask, alibi)
+
 
         # =================
         # Output. [sq, b, h]