Commit 5fdb0b9
Adds backward preprocessing kernels for FlashAttention
Implements specialized kernels for the preprocessing stage of the FlashAttention backward pass when parallelizing across the sequence dimension.
Provides functions to compute per-row dot products between gradient outputs and outputs, clear accumulator buffers, and convert between precision formats during the backward pass.
Enables efficient memory management and computation distribution for large sequence lengths by handling accumulator initialization and type conversion separately from the main backward kernels.

1 parent 3ad89fc
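The dot-product preprocessing described above can be illustrated with a minimal NumPy sketch. The function names and shapes here are hypothetical, not taken from the commit; the sketch only shows the shape of the computation: a per-row reduction D[i] = Σ_j dO[i, j] · O[i, j], with half-precision inputs upcast to fp32 for accumulation, plus a separate helper that zeroes accumulator buffers.

```python
import numpy as np

def preprocess_backward(dO: np.ndarray, O: np.ndarray) -> np.ndarray:
    """Compute D[i] = sum_j dO[i, j] * O[i, j] for each row i.

    Inputs may be fp16; the reduction is carried out in fp32,
    mirroring the precision conversion the kernels perform.
    """
    return np.sum(dO.astype(np.float32) * O.astype(np.float32), axis=-1)

def zero_accumulators(shape) -> np.ndarray:
    """Clear an fp32 accumulator buffer before partial results are added."""
    return np.zeros(shape, dtype=np.float32)
```

In the real kernels these steps run on the GPU before the main backward pass, so each sequence-parallel worker can read the precomputed row sums and write into already-cleared accumulators.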
1 file changed: +376 −0 lines