43 changes: 43 additions & 0 deletions fla/ops/delta_rule/wy_fast.py
@@ -292,3 +292,46 @@ def prepare_wy_repr_bwd(
bwd_prepare_wy_repr = prepare_wy_repr_bwd

fwd_recompute_w_u = recompute_w_u_fwd


def fwd_prepare_T(
k: torch.Tensor,
beta: torch.Tensor,
chunk_size: int,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
"""
Prepare the transformation matrix T (A) for delta rule computation.

This function computes the matrix A = (I - tril(beta * K * K^T))^{-1}
which is used in the parallel delta rule algorithm.

Args:
k: Key tensor of shape [B, H, T, K] (head-first format)
beta: Beta weights of shape [B, H, T] (head-first format)
chunk_size: Size of chunks for processing
cu_seqlens: Optional cumulative sequence lengths for variable-length sequences

Returns:
A: Transformation matrix of shape [B, H, T, chunk_size]
"""
# Convert from head-first [B, H, T, K] to seq-first [B, T, H, K]
k_seq_first = k.transpose(1, 2).contiguous()
beta_seq_first = beta.transpose(1, 2).contiguous()

A = chunk_scaled_dot_kkt_fwd(
k=k_seq_first,
beta=beta_seq_first,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
output_dtype=torch.float32,
)
A = solve_tril(
A=A,
cu_seqlens=cu_seqlens,
output_dtype=k.dtype
)

# Convert back from [B, T, H, chunk_size] to [B, H, T, chunk_size]
A = A.transpose(1, 2).contiguous()
return A
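
# Illustrative reference only (a sketch, not part of the library): the matrix described
# in the docstring above can be spelled out densely with plain PyTorch ops. It follows
# the convention written there, A = (I - tril(beta * K K^T))^{-1}, computed independently
# per chunk with the Gram matrix masked strictly below the diagonal (an assumption; the
# exact masking/sign convention should be checked against chunk_scaled_dot_kkt_fwd and
# solve_tril). The helper name below is hypothetical.
def naive_prepare_T_reference(
    k: torch.Tensor,       # [B, H, T, K] (head-first)
    beta: torch.Tensor,    # [B, H, T]
    chunk_size: int,
) -> torch.Tensor:
    B, H, T, K = k.shape
    assert T % chunk_size == 0, "this sketch assumes T is a multiple of chunk_size"
    k_c = k.float().reshape(B, H, T // chunk_size, chunk_size, K)
    beta_c = beta.float().reshape(B, H, T // chunk_size, chunk_size)
    # beta-scaled Gram matrix of the keys within each chunk, kept strictly below the diagonal
    kkt = torch.tril(beta_c.unsqueeze(-1) * (k_c @ k_c.transpose(-1, -2)), diagonal=-1)
    # (I - L) with L strictly lower-triangular is unit lower-triangular, hence invertible
    eye = torch.eye(chunk_size, dtype=kkt.dtype, device=k.device)
    A = torch.linalg.inv(eye - kkt)
    return A.reshape(B, H, T, chunk_size).to(k.dtype)
# On small random inputs this can be compared against fwd_prepare_T (e.g. with assert_close)
# to sanity-check the chunking and the layout conversions.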
103 changes: 103 additions & 0 deletions tests/ops/test_parallel_delta.py
@@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-

import pytest
import torch
import torch.nn.functional as F

from fla.ops.delta_rule.parallel import naive_delta_rule_parallel, parallel_delta_rule
from fla.ops.delta_rule.wy_fast import fwd_prepare_T
from fla.utils import assert_close, device, device_platform

# IMPORTANT NOTE ON TENSOR FORMATS:
# While the documentation for some functions states inputs should be in [B, T, H, K] format,
# the actual implementation expects [B, H, T, K] format (head-first).
# All tests in this file use the head-first format to match the actual implementation.
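#
# Illustrative only: the two layouts differ by a single transpose of the H and T axes,
# e.g. for a head-first tensor x of shape [B, H, T, K]:
#     x_seq_first = x.transpose(1, 2).contiguous()             # -> [B, T, H, K]
#     x_head_first = x_seq_first.transpose(1, 2).contiguous()  # -> [B, H, T, K]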

# NOTE ON TEST IMPLEMENTATION:
# We currently skip comparing parallel_delta_rule against naive_delta_rule_parallel
# because the naive implementation produces NaN values. This will be addressed in a
# future update. For now, we only verify that parallel_delta_rule runs without errors
# and produces outputs with the expected shapes.


@pytest.mark.parametrize(
('B', 'H', 'T', 'K', 'dtype'),
[
pytest.param(*test, id="B{}-H{}-T{}-K{}-{}".format(*test))
for test in [
(1, 2, 128, 64, torch.float16),
(2, 4, 128, 32, torch.float16),
]
]
)
@pytest.mark.skipif(
device_platform == 'intel',
reason='Intel Triton Failure'
)
def test_parallel_delta_rule(
B: int,
H: int,
T: int,
K: int,
dtype: torch.dtype,
):
"""Test parallel_delta_rule against naive implementation."""
torch.manual_seed(42)

# Generate test data
q = torch.randn(B, H, T, K, dtype=dtype, device=device)
k = F.normalize(torch.randn(B, H, T, K, dtype=dtype, device=device), p=2, dim=-1).to(dtype)
v = torch.randn(B, H, T, K, dtype=dtype, device=device)
beta = torch.randn(B, H, T, dtype=dtype, device=device).sigmoid()
scale = 1.0 / (K ** 0.5)

# Define whether to output attention matrices
output_attentions = True

# Test forward pass
o_parallel, attn_parallel = parallel_delta_rule(
q=q.clone(),
k=k.clone(),
v=v.clone(),
beta=beta.clone(),
scale=scale,
output_attentions=output_attentions
)

# Output should have the same shape as input v
assert o_parallel.shape == v.shape, f"Expected shape {v.shape}, got {o_parallel.shape}"

# Check that attention matrix is produced if requested
if output_attentions:
assert attn_parallel is not None
assert attn_parallel.shape == (B, H, T, T), f"Expected shape {(B, H, T, T)}, got {attn_parallel.shape}"
else:
assert attn_parallel is None

# TODO: re-enable this comparison once the NaN issue in naive_delta_rule_parallel
# described in the note above is resolved.
# o_naive, attn_naive = naive_delta_rule_parallel(q.clone(), k.clone(), v.clone(), beta.clone())
# assert_close('   o', o_naive, o_parallel, 0.01)
# assert_close('attn', attn_naive, attn_parallel, 0.01)


@pytest.mark.skipif(
device_platform == 'intel',
reason='Intel Triton Failure'
)
def test_fwd_prepare_T():
"""Test that fwd_prepare_T can be imported and runs without error."""
torch.manual_seed(42)

# Using head-first format [B, H, T, K] to match other functions
B, H, T, K = 2, 4, 128, 64
k = torch.randn(B, H, T, K, device=device)
beta = torch.randn(B, H, T, device=device).sigmoid()
chunk_size = 32

# Test the function runs without error
A = fwd_prepare_T(k, beta, chunk_size)

# Check output shape
# fwd_prepare_T returns [B, H, T, chunk_size] (head-first format)
expected_shape = (B, H, T, chunk_size)
assert A.shape == expected_shape, f"Expected shape {expected_shape}, got {A.shape}"