diff --git a/fla/ops/simple_gla/chunk.py b/fla/ops/simple_gla/chunk.py
index b9f9ed97c..8d5feb82d 100644
--- a/fla/ops/simple_gla/chunk.py
+++ b/fla/ops/simple_gla/chunk.py
@@ -9,7 +9,9 @@
 
 from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h
 from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv, chunk_fwd_o
+from fla.ops.common.fused_chunk import fused_chunk_fwd
 from fla.ops.utils import chunk_local_cumsum
+from fla.ops.utils.index import prepare_split_cu_seqlens
 from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
 
 
@@ -23,7 +25,8 @@ def chunk_simple_gla_fwd(
     initial_state: Optional[torch.Tensor] = None,
     output_final_state: bool = False,
     cu_seqlens: Optional[torch.LongTensor] = None,
-    chunk_size: int = 64
+    chunk_size: int = 64,
+    split_size: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     h, ht = chunk_fwd_h(
         k=k,
@@ -36,19 +39,34 @@
         output_final_state=output_final_state,
         states_in_fp32=False,
         cu_seqlens=cu_seqlens,
-        chunk_size=chunk_size
-    )
-    o = chunk_fwd_o(
-        q=q,
-        k=k,
-        v=v,
-        g=g,
-        g_gamma=g_gamma,
-        h=h,
-        scale=scale,
-        cu_seqlens=cu_seqlens,
-        chunk_size=chunk_size
+        chunk_size=chunk_size,
+        split_size=split_size
     )
+    if split_size is not None:
+        cu_seqlens = prepare_split_cu_seqlens(*q.shape[:2], split_size, cu_seqlens, device=q.device)
+        o, _ = fused_chunk_fwd(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            g_gamma=g_gamma,
+            scale=scale,
+            initial_state=h,
+            cu_seqlens=cu_seqlens,
+            chunk_size=chunk_size
+        )
+    else:
+        o = chunk_fwd_o(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            g_gamma=g_gamma,
+            h=h,
+            scale=scale,
+            cu_seqlens=cu_seqlens,
+            chunk_size=chunk_size
+        )
     return o, ht
 
 
@@ -153,7 +171,8 @@ def forward(
             initial_state=initial_state,
             output_final_state=output_final_state,
             cu_seqlens=cu_seqlens,
-            chunk_size=chunk_size
+            chunk_size=chunk_size,
+            split_size=128
         )
         ctx.save_for_backward(q, k, v, g, g_gamma, initial_state)
         ctx.chunk_size = chunk_size
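
For context, a minimal sketch of what the new split path appears to do. `prepare_split_cu_seqlens` is invoked as `prepare_split_cu_seqlens(batch_size, seq_len, split_size, cu_seqlens, device=...)` (via `*q.shape[:2]`); judging from the call site alone, since its source is not part of this diff, it plausibly subdivides each sequence into pieces of at most `split_size` tokens and returns the cumulative boundaries, so that `fused_chunk_fwd` can compute the output over short, independent splits seeded from the precomputed chunk states `h` (passed as `initial_state=h`). The helper below is a hypothetical stand-in for illustration, not the library's implementation:

```python
import torch

def split_cu_seqlens_sketch(
    batch_size: int,
    seq_len: int,
    split_size: int,
    cu_seqlens: torch.LongTensor = None,
    device: torch.device = None,
) -> torch.LongTensor:
    # Hypothetical sketch mirroring prepare_split_cu_seqlens(*q.shape[:2], ...).
    if cu_seqlens is None:
        # Equal-length batch: sequence i spans [i * seq_len, (i + 1) * seq_len).
        bounds = (torch.arange(batch_size + 1) * seq_len).tolist()
    else:
        # Variable-length input: cu_seqlens already holds the sequence boundaries.
        bounds = cu_seqlens.tolist()
    out = [bounds[0]]
    for start, end in zip(bounds[:-1], bounds[1:]):
        # Cut each sequence into splits of at most `split_size` tokens.
        out.extend(range(start + split_size, end, split_size))
        out.append(end)
    return torch.tensor(out, dtype=torch.long, device=device)
```

For example, `split_cu_seqlens_sketch(1, 300, 128)` would yield `[0, 128, 256, 300]`: three splits that `fused_chunk_fwd` could then process as separate short sequences. Note that `forward` hard-codes `split_size=128`, twice the default `chunk_size` of 64, so under this reading each split spans a pair of chunks whose states are already available from `chunk_fwd_h`.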