fla/ops/simple_gla/chunk.py: 47 changes (33 additions, 14 deletions)
```diff
@@ -9,7 +9,9 @@

 from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h
 from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv, chunk_fwd_o
+from fla.ops.common.fused_chunk import fused_chunk_fwd
 from fla.ops.utils import chunk_local_cumsum
+from fla.ops.utils.index import prepare_split_cu_seqlens
 from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
```
```diff
@@ -23,7 +25,8 @@ def chunk_simple_gla_fwd(
     initial_state: Optional[torch.Tensor] = None,
     output_final_state: bool = False,
     cu_seqlens: Optional[torch.LongTensor] = None,
-    chunk_size: int = 64
+    chunk_size: int = 64,
+    split_size: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     h, ht = chunk_fwd_h(
         k=k,
```
```diff
@@ -36,19 +39,34 @@ def chunk_simple_gla_fwd(
         output_final_state=output_final_state,
         states_in_fp32=False,
         cu_seqlens=cu_seqlens,
-        chunk_size=chunk_size
-    )
-    o = chunk_fwd_o(
-        q=q,
-        k=k,
-        v=v,
-        g=g,
-        g_gamma=g_gamma,
-        h=h,
-        scale=scale,
-        cu_seqlens=cu_seqlens,
-        chunk_size=chunk_size
+        chunk_size=chunk_size,
+        split_size=split_size
     )
+    if split_size is not None:
+        cu_seqlens = prepare_split_cu_seqlens(*q.shape[:2], split_size, cu_seqlens, device=q.device)
+        o, _ = fused_chunk_fwd(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            g_gamma=g_gamma,
+            scale=scale,
+            initial_state=h,
+            cu_seqlens=cu_seqlens,
+            chunk_size=chunk_size
+        )
+    else:
+        o = chunk_fwd_o(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            g_gamma=g_gamma,
+            h=h,
+            scale=scale,
+            cu_seqlens=cu_seqlens,
+            chunk_size=chunk_size
+        )
```
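For intuition about the new fused path: judging from its name and call site, prepare_split_cu_seqlens appears to turn the (batch_size, seq_len) pair and any existing cu_seqlens into boundaries where every segment is at most split_size tokens, so that fused_chunk_fwd can process each split as its own sequence seeded by the precomputed state h. This is an assumption, not the actual FLA implementation; a toy stand-in for illustration:

```python
import torch

def toy_split_cu_seqlens(batch_size, seq_len, split_size, cu_seqlens=None, device=None):
    """Toy stand-in for prepare_split_cu_seqlens (assumed semantics): cut every
    sequence into pieces of at most `split_size` tokens and return the new
    cumulative boundaries as a LongTensor."""
    if cu_seqlens is None:
        # equal-length batch: original boundaries at 0, seq_len, 2*seq_len, ...
        bounds = list(range(0, batch_size * seq_len + 1, seq_len))
    else:
        bounds = cu_seqlens.tolist()
    out = [0]
    for start, end in zip(bounds[:-1], bounds[1:]):
        pos = start
        while pos < end:
            pos = min(pos + split_size, end)
            out.append(pos)
    return torch.tensor(out, dtype=torch.long, device=device)

# A single 300-token sequence with split_size=128 becomes three segments.
print(toy_split_cu_seqlens(1, 300, 128))  # tensor([  0, 128, 256, 300])
```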
Comment on lines +45 to +69 (from a Contributor):
⚠️ Potential issue

Possible mismatch between the fused forward path and the existing backward pass.

The conditional logic correctly implements the fused vs. original forward paths. However, the backward method in ChunkSimpleGLAFunction (lines 186-206) only uses chunk_simple_gla_bwd, which does not account for the fused forward path. This mismatch could produce incorrect gradients when split_size is provided.

Consider one of these solutions:

1. Update the backward method to handle both forward paths.
2. Ensure chunk_simple_gla_bwd is compatible with both forward implementations.
3. Add a flag to track which forward path was used (a toy sketch of this approach follows below).
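A minimal, self-contained sketch of the flag-plus-dispatch idea (options 1 and 3 combined), using a toy autograd.Function in place of the real FLA kernels; the attribute name used_fused_path and the toy arithmetic are illustrative assumptions, not existing FLA code:

```python
import torch

class ToyTwoPathFunction(torch.autograd.Function):
    """Toy example of recording which forward path ran so that backward
    can dispatch to the matching gradient computation."""

    @staticmethod
    def forward(ctx, x, split_size=None):
        # Record the chosen path on ctx; the real code would branch between
        # fused_chunk_fwd and chunk_fwd_o here.
        ctx.used_fused_path = split_size is not None
        return 2.0 * x

    @staticmethod
    def backward(ctx, grad_out):
        if ctx.used_fused_path:
            # Dispatch to the backward matching the fused forward path
            # (both paths share the same toy gradient here).
            grad_x = 2.0 * grad_out
        else:
            grad_x = 2.0 * grad_out
        return grad_x, None  # no gradient for split_size

x = torch.randn(4, requires_grad=True)
ToyTwoPathFunction.apply(x, 128).sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2.])
```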
To check whether a corresponding fused backward implementation already exists:

```sh
#!/bin/bash
# Description: Check if there's a corresponding fused backward implementation
# Expected: Look for fused_chunk_bwd or similar backward functions

rg -A 5 "fused.*bwd|fused.*backward" --type py
```
🤖 Prompt for AI Agents
In fla/ops/simple_gla/chunk.py around lines 45 to 69, the forward method
conditionally calls either fused_chunk_fwd or chunk_fwd_o based on split_size,
but the backward method only calls chunk_simple_gla_bwd, which does not handle
the fused forward path. To fix this, modify the backward method to detect which
forward path was used (e.g., by adding a flag or checking inputs) and call the
corresponding backward function for fused_chunk_fwd or chunk_fwd_o accordingly,
ensuring correct gradient computation for both cases.

```diff
     return o, ht
```


```diff
@@ -153,7 +171,8 @@ def forward(
             initial_state=initial_state,
             output_final_state=output_final_state,
             cu_seqlens=cu_seqlens,
-            chunk_size=chunk_size
+            chunk_size=chunk_size,
+            split_size=128
         )
         ctx.save_for_backward(q, k, v, g, g_gamma, initial_state)
         ctx.chunk_size = chunk_size
```