# When use_gate_in_kernel is set, route the raw gate through kda_gate_fwd
# (parameterized by A_log and dt_bias) and keep the untransformed input in
# g_org.
# NOTE(review): g_org is only bound on this path; confirm any later use of
# g_org is guarded by the same flag or initialized earlier.
if use_gate_in_kernel:
    g_org = g
    g = kda_gate_fwd(g=g_org, A_log=A_log, dt_bias=dt_bias)

# Both paths finish with the same chunk_local_cumsum call on the (possibly
# transformed) gate. Going by the name, this is presumably a cumulative sum
# within chunks of chunk_size, respecting cu_seqlens / chunk_indices for
# variable-length batches — confirm against the helper's definition.
g = chunk_local_cumsum(
    g,
    chunk_size=chunk_size,
    cu_seqlens=cu_seqlens,
    chunk_indices=chunk_indices,
)