
Commit e0b644b

Merge pull request #41 from bigcode-project/flash-attention
Add flash-attn
2 parents (8b38744 + 0ff5746) · commit e0b644b

File tree
3 files changed: +113 -5 lines
README.md

Lines changed: 12 additions & 0 deletions
````diff
@@ -333,6 +333,18 @@ Theoretical memory savings vary depending on the combination of the model's para
 | bf16 param, fp32 grads | 18 | 6 + 12/d |
 | fp32 param, fp32 grads | 16 | 8 + 8/d |
 
+## FlashAttention
+
+Usage: `--use-flash-attn`. Support attention head dimensions at most 128.
+
+[FlashAttention](https://github.com/HazyResearch/flash-attention) is a fast and
+memory-efficient algorithm to compute exact attention. It speeds up model
+training and reduces memory requirement.
+
+To install FlashAttention:
+```sh
+pip install flash-attn
+```
 
 ## GPT-3 Example
````
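As a quick post-install check, both imports below must succeed for `--use-flash-attn` to work; note that the new code path also relies on `einops` (installable with `pip install einops`), not just `flash-attn`. This snippet only mirrors the imports added to `megatron/model/transformer.py` in this commit:

```python
# Optional post-install sanity check: these are the same imports the
# transformer.py changes below fall back on when the packages are missing.
from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_unpadded_func

print("flash-attn and einops are available")
```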

megatron/arguments.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -648,6 +648,9 @@ def _add_training_args(parser):
     group.add_argument('--no-bias-dropout-fusion', action='store_false',
                        help='Disable bias and dropout fusion.',
                        dest='bias_dropout_fusion')
+    group.add_argument('--use-flash-attn', action='store_true',
+                       help='use FlashAttention implementation of attention. '
+                       'https://arxiv.org/abs/2205.14135')
     group.add_argument('--optimizer', type=str, default='adam',
                        choices=['adam', 'sgd'],
                        help='Optimizer function')
```
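The new option is a plain boolean switch: `action='store_true'` means it defaults to `False` and flips to `True` only when the flag is passed. A standalone sketch of that behaviour (the parser here is hypothetical, not Megatron's `_add_training_args`):

```python
import argparse

# Minimal illustration of the flag semantics; only mirrors the single
# add_argument call added in this commit.
parser = argparse.ArgumentParser()
parser.add_argument('--use-flash-attn', action='store_true',
                    help='use FlashAttention implementation of attention.')

print(parser.parse_args([]).use_flash_attn)                    # False
print(parser.parse_args(['--use-flash-attn']).use_flash_attn)  # True
```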

megatron/model/transformer.py

Lines changed: 98 additions & 5 deletions
```diff
@@ -22,6 +22,7 @@
 
 from megatron import get_timers, get_args, get_global_memory_buffer
 from megatron import mpu
+from megatron.utils import print_rank_0
 from .module import MegatronModule
 from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType, PositionEmbeddingType
 from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
```
```diff
@@ -38,6 +39,16 @@
 torch._C._jit_override_can_fuse_on_cpu(True)
 torch._C._jit_override_can_fuse_on_gpu(True)
 
+try:
+    from einops import rearrange
+except ImportError:
+    rearrange = None
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+except ImportError:
+    flash_attn_unpadded_func = None
+
 
 """ We use the following notation throughout this file:
     h: hidden size
```
```diff
@@ -459,6 +470,48 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask, alibi):
         return context_layer
 
 
+class FlashSelfAttention(torch.nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                      (default: 1/sqrt(d_keys) where d_keys is computed at
+                      runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                 device=None, dtype=None):
+        super().__init__()
+        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
+                                                      'e.g., with pip install flash-attn')
+        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+
+    def forward(self, q, k, v):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
+        """
+        assert q.dtype in [torch.float16, torch.bfloat16]
+        assert q.is_cuda
+        batch_size, seqlen = q.shape[0], q.shape[1]
+        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
+        max_s = seqlen
+        cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+                                  device=q.device)
+        output = flash_attn_unpadded_func(
+            q, k, v, cu_seqlens, cu_seqlens, max_s, max_s,
+            self.dropout_p if self.training else 0.0,
+            softmax_scale=self.softmax_scale, causal=self.causal
+        )
+        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+        return output
+
+
 class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.
```
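A rough usage sketch of the new module, assuming flash-attn and einops are installed, a CUDA device is available, and the surrounding `megatron` package imports cleanly in your environment; the shapes below are illustrative only:

```python
import torch
from megatron.model.transformer import FlashSelfAttention  # added in this commit

# Illustrative shapes: batch B=2, sequence S=4, heads H=8, head dim D=64 (<= 128).
B, S, H, D = 2, 4, 8, 64
q = torch.randn(B, S, H, D, dtype=torch.float16, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)

attn = FlashSelfAttention(causal=True, attention_dropout=0.0)
out = attn(q, k, v)   # output keeps the (B, S, H, D) layout of the inputs
print(out.shape)      # torch.Size([2, 4, 8, 64])

# Internally the module packs the batch into one sequence of B*S tokens and
# passes cumulative sequence boundaries to flash_attn_unpadded_func; for
# B=2, S=4 that is cu_seqlens = [0, 4, 8].
```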
```diff
@@ -477,6 +530,9 @@ def __init__(self, init_method,
         self.attn_mask_type = attn_mask_type
         self.params_dtype = args.params_dtype
         self.attention_head_type = args.attention_head_type
+        self.sequence_parallel = args.sequence_parallel
+
+        self.use_flash_attn = args.use_flash_attn
 
         projection_size = args.kv_channels * args.num_attention_heads
 
```
```diff
@@ -533,6 +589,26 @@ def __init__(self, init_method,
         else:
             self.core_attention = MultiQueryCoreAttention(self.layer_number, self.attn_mask_type)
         self.checkpoint_core_attention = args.recompute_granularity == 'selective'
+
+        if self.use_flash_attn:
+            if flash_attn_unpadded_func is None:
+                raise ImportError('FlashAttention is not installed, please install with '
+                                  'pip install flash-attn')
+            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
+                                                          'self-attention for now')
+            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
+                                                                'supports causal mask for now')
+            assert args.position_embedding_type != PositionEmbeddingType.alibi, \
+                ('FlashAttention does not support alibi positional embeddings yet')
+            if rearrange is None:
+                raise ImportError('einops is not installed, please install with pip install einops')
+
+            if self.checkpoint_core_attention:
+                print_rank_0(" Warning, using selective recomputation with flash-attn: this is already handled in the "
+                             "flash-attn library and has no effect.")
+            self.core_attention_flash = FlashSelfAttention(
+                causal=True, attention_dropout=args.attention_dropout
+            )
 
         # Output.
         self.dense = mpu.RowParallelLinear(
```
```diff
@@ -699,13 +775,30 @@ def forward(self, hidden_states, attention_mask,
         # ==================================
         # core attention computation
         # ==================================
+        if self.use_flash_attn:
+            if self.attention_head_type == "multiquery":
+                sq, b, np, hn = query_layer.size()
+                # Expand kv to be compatible with flash-attn implementation
+                # [sq, b, 1, hn] -> [sq, b, np, hn]
+                key_layer = key_layer.expand((sq, b, np, hn))
+                value_layer = value_layer.expand((sq, b, np, hn))
+            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
+                       for x in (query_layer, key_layer, value_layer)]
+            if self.sequence_parallel:
+                context_layer = self.core_attention_flash(q, k, v)
+            else:
+                with mpu.get_cuda_rng_tracker().fork():
+                    context_layer = self.core_attention_flash(q, k, v)
+            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
 
-        if self.checkpoint_core_attention:
-            context_layer = self._checkpointed_attention_forward(
-                query_layer, key_layer, value_layer, attention_mask, alibi)
         else:
-            context_layer = self.core_attention(
-                query_layer, key_layer, value_layer, attention_mask, alibi)
+            if self.checkpoint_core_attention:
+                context_layer = self._checkpointed_attention_forward(
+                    query_layer, key_layer, value_layer, attention_mask, alibi)
+            else:
+                context_layer = self.core_attention(
+                    query_layer, key_layer, value_layer, attention_mask, alibi)
+
 
         # =================
         # Output. [sq, b, h]
```
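The tensor plumbing around the flash path can be seen in isolation: Megatron keeps activations as `[sq, b, np, hn]`, flash-attn expects `[b, sq, np, hn]`, and multi-query attention stores a single shared key/value head that is broadcast to all query heads. A small CPU-only sketch with made-up sizes (no flash-attn call, just the reshapes used above):

```python
import torch
from einops import rearrange

# Made-up sizes: sq = sequence length, b = micro-batch, np = query heads, hn = head dim.
sq, b, np_, hn = 16, 2, 8, 64

query_layer = torch.randn(sq, b, np_, hn)
key_layer = torch.randn(sq, b, 1, hn)      # multi-query: one shared key head
value_layer = torch.randn(sq, b, 1, hn)    # multi-query: one shared value head

# expand() broadcasts the singleton head dimension without copying memory,
# matching the [sq, b, 1, hn] -> [sq, b, np, hn] comment in the diff above.
key_layer = key_layer.expand((sq, b, np_, hn))
value_layer = value_layer.expand((sq, b, np_, hn))

# [sq, b, np, hn] -> [b, sq, np, hn], the layout FlashSelfAttention consumes.
q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
           for x in (query_layer, key_layer, value_layer)]

# Stand-in for the attention output, which keeps the same [b, sq, np, hn] layout as q.
context_layer = torch.randn_like(q)

# Back to Megatron's [sq, b, h] layout, folding heads into the hidden dimension.
context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
print(context_layer.shape)  # torch.Size([16, 2, 512])
```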
