Commit 43adc40

Enables compilation for flex attention forward
Sets `_compile=True` (previously `False`) in the flex attention forward path, enabling compilation to improve performance through kernel optimization during flex attention computation.
1 parent 6d60c3a commit 43adc40

1 file changed: +1 -1 lines changed

flash_dmattn/flash_dmattn_flex.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
         Q_LEN=query.shape[2],
         KV_LEN=key.shape[2],
         device=query.device,
-        _compile=False,
+        _compile=True,
     )
 
     kernel_options = {
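
For context, the argument names in this hunk (`Q_LEN`, `KV_LEN`, `device`, `_compile`) match PyTorch's `create_block_mask` from `torch.nn.attention.flex_attention`, but the enclosing call is not visible in the diff, so that is an assumption. The pure-Python sketch below shows the kind of per-block classification such a call derives from a mask function. It is an illustration only, not PyTorch's implementation: `block_summary` is a hypothetical helper, and the body of `causal_mask_mod` is assumed to be the plain causal rule, since the commit shows only its signature. Compiling this step is expected to change how the mask is evaluated, not the result.

```python
# Pure-Python illustration of block-level mask classification.
# Not PyTorch's implementation; block_summary is a hypothetical helper.


def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
    # Assumed body: the diff only shows this function's signature.
    # A query position may attend to itself and to earlier key positions.
    return q_idx >= kv_idx


def block_summary(mask_mod, q_len, kv_len, block_size):
    """Classify each (q_block, kv_block) tile as 'full', 'partial', or 'empty'."""
    n_q_blocks = (q_len + block_size - 1) // block_size
    n_kv_blocks = (kv_len + block_size - 1) // block_size
    summary = []
    for qb in range(n_q_blocks):
        row = []
        for kb in range(n_kv_blocks):
            q_range = range(qb * block_size, min((qb + 1) * block_size, q_len))
            kv_range = range(kb * block_size, min((kb + 1) * block_size, kv_len))
            # Count how many (q, kv) pairs inside this tile are kept by the mask.
            kept = sum(bool(mask_mod(0, 0, q, k)) for q in q_range for k in kv_range)
            total = len(q_range) * len(kv_range)
            if kept == total:
                row.append("full")
            elif kept == 0:
                row.append("empty")
            else:
                row.append("partial")
        summary.append(row)
    return summary


if __name__ == "__main__":
    for row in block_summary(causal_mask_mod, q_len=8, kv_len=8, block_size=4):
        print(row)
```

Tiles classified as empty can be skipped entirely and full tiles need no per-element masking, which is why a block-level summary of this kind is worth precomputing once per forward pass.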
