Commit 5fdb0b9

Adds backward preprocessing kernels for FlashAttention
Implements specialized kernels for preprocessing backward pass computations in FlashAttention when parallelizing across sequence dimensions. Provides functions to compute dot products between gradient outputs and outputs, clear accumulator buffers, and convert between precision formats during the backward pass. Enables efficient memory management and computation distribution for large sequence length scenarios by handling accumulator initialization and type conversions separately from main backward kernels.
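The dot-product step described above is the standard FlashAttention backward preprocessing: for each query row, a scalar D = rowsum(dO ∘ O) is computed in fp32, and the dQ accumulator is zero-initialized so that sequence-parallel blocks can accumulate partial gradients into it before a final precision conversion. A minimal NumPy sketch of that math (function names are hypothetical; the commit implements these as CUDA kernels, not Python):

```python
import numpy as np

def flash_attn_bwd_preprocess(out, dout):
    """Compute D[..., i] = sum_d dO[..., i, d] * O[..., i, d].

    out, dout: shape (batch, heads, seq_len, head_dim), typically
    fp16/bf16; the reduction is carried out in fp32 for accuracy.
    """
    return np.sum(out.astype(np.float32) * dout.astype(np.float32), axis=-1)

def clear_dq_accum(batch, heads, seq_len, head_dim):
    """Zero-initialize the fp32 dQ accumulator that sequence-parallel
    blocks accumulate their partial gradients into."""
    return np.zeros((batch, heads, seq_len, head_dim), dtype=np.float32)

def convert_dq(dq_accum, dtype=np.float16):
    """Down-convert the fp32 accumulator to the output precision."""
    return dq_accum.astype(dtype)
```

Keeping the accumulator in fp32 and converting only at the end avoids the precision loss that repeated fp16 additions would cause when many sequence blocks contribute to the same dQ rows.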
1 parent 3ad89fc commit 5fdb0b9

File tree

1 file changed: +376 −0 lines


0 commit comments
