12 changes: 8 additions & 4 deletions fla/ops/kda/chunk.py
@@ -8,7 +8,7 @@
 from fla.ops.gla.chunk import chunk_gla_bwd_dA, chunk_gla_fwd_o_gk
 from fla.ops.kda.chunk_inter import chunk_kda_bwd_dqkwg
 from fla.ops.kda.chunk_intra import chunk_kda_bwd_intra, chunk_kda_fwd_intra
-from fla.ops.kda.gate import kda_gate_bwd, kda_gate_fwd
+from fla.ops.kda.gate import kda_gate_bwd, kda_gate_chunk_cumsum, kda_gate_fwd
 from fla.ops.kda.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd
 from fla.ops.utils import chunk_local_cumsum
 from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
@@ -209,21 +209,25 @@ def forward(
         cu_seqlens: torch.LongTensor | None = None,
         chunk_indices: torch.LongTensor | None = None,
     ):
+        chunk_size = 64
         g_org = None
         if use_gate_in_kernel:
             g_org = g
-            g = kda_gate_fwd(
+            g = kda_gate_chunk_cumsum(
                 g=g_org,
                 A_log=A_log,
                 dt_bias=dt_bias,
+                chunk_size=chunk_size,
+                cu_seqlens=cu_seqlens,
+                chunk_indices=chunk_indices
             )
+        else:
+            g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices)
         q_rstd, k_rstd = None, None
         if use_qk_l2norm_in_kernel:
             q, q_rstd = l2norm_fwd(q)
             k, k_rstd = l2norm_fwd(k)
 
-        chunk_size = 64
-        g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices)
         o, Aqk, Akk, final_state = chunk_kda_fwd(
             q=q,
             k=k,
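Editorial note (not part of the diff): the change above folds the chunk-local cumsum into the gate kernel. Previously forward() applied kda_gate_fwd and then always ran chunk_local_cumsum; now a single kda_gate_chunk_cumsum call covers both when use_gate_in_kernel is set, and chunk_local_cumsum alone handles the remaining case. A condensed sketch of the new gating path, with the ctx/autograd plumbing omitted and the helper name apply_gate_and_cumsum invented for illustration:

from fla.ops.kda.gate import kda_gate_chunk_cumsum
from fla.ops.utils import chunk_local_cumsum

def apply_gate_and_cumsum(g, A_log, dt_bias, use_gate_in_kernel, cu_seqlens=None, chunk_indices=None):
    chunk_size = 64
    if use_gate_in_kernel:
        # fused path: softplus gate and chunk-local cumsum in one Triton launch
        return kda_gate_chunk_cumsum(
            g=g, A_log=A_log, dt_bias=dt_bias,
            chunk_size=chunk_size, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices,
        )
    # gate already applied outside the kernel: only the chunk-local cumsum remains
    return chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices)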
129 changes: 128 additions & 1 deletion fla/ops/kda/gate.py
@@ -5,9 +5,11 @@
 import triton
 import triton.language as tl
 
+from fla.ops.utils.index import prepare_chunk_indices
 from fla.ops.utils.softplus import softplus
-from fla.utils import IS_AMD, autocast_custom_bwd, autocast_custom_fwd, autotune_cache_kwargs, input_guard
+from fla.utils import IS_AMD, autocast_custom_bwd, autocast_custom_fwd, autotune_cache_kwargs, check_shared_mem, input_guard
 
+BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
 BT_LIST_AUTOTUNE = [32, 64, 128]
 NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if IS_AMD else [4, 8, 16, 32]
 
@@ -288,3 +290,128 @@ def fused_kda_gate(
         Output tensor of shape `[..., H, K]`.
     """
     return KDAGateFunction.apply(g, A_log, dt_bias, output_dtype)
+
+
+@triton.heuristics({
+    "HAS_BIAS": lambda args: args["dt_bias"] is not None,
+    'HAS_SCALE': lambda args: args['scale'] is not None,
+    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
+})
+@triton.autotune(
+    configs=[
+        triton.Config({'BS': BS}, num_warps=num_warps)
+        for BS in BS_LIST
+        for num_warps in [2, 4, 8]
+    ],
+    key=['B', 'H', 'S', 'BT', 'IS_VARLEN', 'REVERSE'],
+    **autotune_cache_kwargs,
+)
+@triton.jit(do_not_specialize=['T'])
+def kda_gate_chunk_cumsum_vector_kernel(
+    s,
+    A_log,
+    dt_bias,
+    o,
+    scale,
+    cu_seqlens,
+    chunk_indices,
+    T,
+    B: tl.constexpr,
+    H: tl.constexpr,
+    S: tl.constexpr,
+    BT: tl.constexpr,
+    BS: tl.constexpr,
+    REVERSE: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+    HAS_SCALE: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+    HEAD_FIRST: tl.constexpr,
+):
+    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_b, i_h = i_bh // H, i_bh % H
+    if IS_VARLEN:
+        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+        T = eos - bos
+    else:
+        bos, eos = i_b * T, i_b * T + T
+
+    if HEAD_FIRST:
+        p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
+        p_o = tl.make_block_ptr(o + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
+    else:
+        p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
+        p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
+    # [BT, BS]
+    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
+
+    # Apply dt_bias if exists
+    if HAS_BIAS:
+        p_b = tl.make_block_ptr(dt_bias + i_h * S, (S,), (1,), (i_s * BS,), (BS,), (0,))
+        b_bias = tl.load(p_b, boundary_check=(0,)).to(tl.float32)
+        b_s = b_s + b_bias[None, :]
+
+    # Apply gate: -exp(A_log) * softplus(g + bias)
+    b_A = tl.load(A_log + i_h).to(tl.float32)
+    b_gate = -tl.exp(b_A) * softplus(b_s)
+
+    # Apply chunk local cumsum
+    if REVERSE:
+        b_o = tl.cumsum(b_gate, axis=0, reverse=True)
+    else:
+        b_o = tl.cumsum(b_gate, axis=0)
+
+    if HAS_SCALE:
+        b_o *= scale
+    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+
+@input_guard
+def kda_gate_chunk_cumsum(
+    g: torch.Tensor,
+    A_log: torch.Tensor,
+    chunk_size: int,
+    reverse: bool = False,
+    scale: float = None,
+    dt_bias: torch.Tensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
+    head_first: bool = False,
+    output_dtype: torch.dtype | None = torch.float,
+    chunk_indices: torch.LongTensor | None = None,
+    **kwargs,
+) -> torch.Tensor:
+    if cu_seqlens is not None:
+        assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
+    assert len(g.shape) == 4
+
+    if head_first:
+        B, H, T, S = g.shape
+    else:
+        B, T, H, S = g.shape
+    BT = chunk_size
+    if chunk_indices is None and cu_seqlens is not None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
+    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+    assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
Review comment (Contributor, severity: medium):
This assertion to check if chunk_size is a power of two is a bit obscure and might be hard to understand for future maintainers. A more conventional and clearer way to perform this check is by using bitwise operations.

Suggested change:
-    assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
+    assert (chunk_size > 0) and (chunk_size & (chunk_size - 1) == 0), "chunk_size must be a power of 2"
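Editorial note on the suggestion: for a positive integer n, n & (n - 1) clears the lowest set bit, so the expression is zero exactly when n has a single set bit, i.e. is a power of two; the n > 0 guard rules out zero. A quick illustration:

def is_power_of_two(n: int) -> bool:
    # n & (n - 1) drops the lowest set bit; only powers of two (one set bit) yield 0
    return n > 0 and (n & (n - 1)) == 0

assert is_power_of_two(64)       # a typical chunk_size
assert not is_power_of_two(48)   # multiple set bits
assert not is_power_of_two(0)    # excluded by the n > 0 guard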


+    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
+    def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)
+    kda_gate_chunk_cumsum_vector_kernel[grid](
+        s=g_org,
+        A_log=A_log,
+        dt_bias=dt_bias,
+        o=g,
+        scale=scale,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        T=T,
+        B=B,
+        H=H,
+        S=S,
+        BT=BT,
+        HEAD_FIRST=head_first,
+        REVERSE=reverse,
+    )
+    return g
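Editorial note (not part of the PR): as a sanity-check aid, a plain-PyTorch reference of what kda_gate_chunk_cumsum computes in the common non-varlen, non-reverse, unscaled case. It assumes g has shape [B, T, H, S] with T a multiple of chunk_size, A_log has shape [H], and dt_bias has shape [H, S]; varlen (cu_seqlens), reverse, scale, and head_first handling are omitted, and the name kda_gate_chunk_cumsum_ref is hypothetical.

import torch
import torch.nn.functional as F

def kda_gate_chunk_cumsum_ref(g, A_log, dt_bias=None, chunk_size=64):
    # gate: -exp(A_log[h]) * softplus(g + dt_bias[h]), then cumsum over time inside each chunk
    x = g.float()
    if dt_bias is not None:
        x = x + dt_bias.float()[None, None, :, :]
    gate = -A_log.float().exp()[None, None, :, None] * F.softplus(x)
    B, T, H, S = gate.shape
    NT = T // chunk_size  # assumes T is a multiple of chunk_size
    out = gate.reshape(B, NT, chunk_size, H, S).cumsum(dim=2)
    return out.reshape(B, T, H, S)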