Commit 974451e
Adds Triton and Flex Attention backward pass implementations
Extends backward equivalence testing to include Triton and Flex Attention implementations alongside existing Python and CUDA versions.
Updates function signatures to return attention bias gradients and removes softmax log-sum-exp calculations for consistency across implementations.
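For illustration, a backward entry point matching this description might look roughly like the sketch below: it recomputes the softmax rather than reusing a cached log-sum-exp and returns a gradient for the attention bias. The function name, arguments, and shapes are assumptions, not the repository's actual API.

```python
# A minimal sketch; attention_backward and its arguments are hypothetical.
import torch

def attention_backward(grad_out, q, k, v, bias, scale):
    """Recompute the softmax (no cached log-sum-exp) and return gradients
    for q, k, v and the additive attention bias."""
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale + bias
    probs = torch.softmax(scores, dim=-1)

    grad_v = torch.matmul(probs.transpose(-2, -1), grad_out)
    grad_probs = torch.matmul(grad_out, v.transpose(-2, -1))
    # Softmax backward: dS = P * (dP - sum(dP * P, dim=-1))
    grad_scores = probs * (grad_probs - (grad_probs * probs).sum(-1, keepdim=True))

    grad_q = torch.matmul(grad_scores, k) * scale
    grad_k = torch.matmul(grad_scores.transpose(-2, -1), q) * scale
    grad_bias = grad_scores  # bias enters the scores additively
    return grad_q, grad_k, grad_v, grad_bias
```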
Fixes attention bias application in the Python reference implementation and improves gradient retention handling for proper backward pass computation.
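As a rough illustration of both points, the snippet below applies an additive bias in a naive reference attention and retains the gradient of a non-leaf bias tensor; `reference_attention` and the tensor shapes are hypothetical, not the repository's code.

```python
# A minimal sketch, assuming an additive attention bias; names and shapes
# are assumptions, not the repository's implementation.
import torch

def reference_attention(q, k, v, bias, scale):
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    scores = scores + bias              # apply the attention bias before softmax
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)

q, k, v = (torch.randn(2, 4, 128, 64, requires_grad=True) for _ in range(3))
bias = torch.randn(2, 4, 128, 128, dtype=torch.float16, requires_grad=True)

# When the bias actually fed to the kernel is a non-leaf tensor (here, after a
# dtype cast), retain_grad() is needed for its .grad to survive backward().
bias_fp32 = bias.float()
bias_fp32.retain_grad()

out = reference_attention(q, k, v, bias_fp32, scale=64 ** -0.5)
out.sum().backward()
assert bias_fp32.grad is not None
```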
Enhances test configurations with comprehensive parameter combinations and better error handling for missing implementations.
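One way such a test setup could look is sketched below: implementations are looked up by name, missing ones are skipped rather than failing, and gradients are compared against the Python reference with torch.testing.assert_close. The registry, parameter grid, and tolerances are assumptions, not the actual test suite.

```python
# A hedged sketch of a backward-equivalence test; names and values are assumptions.
import pytest
import torch

def reference_attention(q, k, v, bias):
    scale = q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale + bias
    return torch.matmul(torch.softmax(scores, dim=-1), v)

# A real suite would register the CUDA, Triton, and Flex Attention kernels here.
IMPLEMENTATIONS = {"python": reference_attention}

@pytest.mark.parametrize("impl", ["python", "cuda", "triton", "flex"])
@pytest.mark.parametrize("seq_len", [128, 256])
@pytest.mark.parametrize("head_dim", [64, 128])
def test_backward_equivalence(impl, seq_len, head_dim):
    fn = IMPLEMENTATIONS.get(impl)
    if fn is None:
        pytest.skip(f"{impl} implementation not available")

    base = [torch.randn(2, 4, seq_len, head_dim) for _ in range(3)]
    base.append(torch.randn(2, 4, seq_len, seq_len))  # attention bias

    def grads(attention):
        inputs = [t.clone().requires_grad_(True) for t in base]
        attention(*inputs).sum().backward()
        return [t.grad for t in inputs]

    for g_test, g_ref in zip(grads(fn), grads(reference_attention)):
        torch.testing.assert_close(g_test, g_ref, rtol=1e-4, atol=1e-4)
```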
1 file changed: +639 −35 lines