Skip to content

Commit 1c403c3

Browse files
authored
Fix: Correct K/V dimension mismatch in path_attn bwd kernels by changing K/BK to V/BV for v and dv operations (#633)
* Change K/BK to V/BV for v and dv operations
* Correct the v tensor's loading layout and block size, and apply the necessary transpose in the dot product
1 parent db55f88 commit 1c403c3

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

fla/ops/path_attn/parallel_path_bwd_inter_dkv.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def parallel_path_bwd_dkv_kernel(
7777
# load query
7878
p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
7979
b_k = tl.load(p_k, boundary_check=(0, 1))
80-
p_v = tl.make_block_ptr(v, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
80+
p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, 0), (BT, BV), (1, 0))
8181
b_v = tl.load(p_v, boundary_check=(0, 1))
8282

8383
if USE_GATE:
@@ -90,7 +90,7 @@ def parallel_path_bwd_dkv_kernel(
9090
b_dg_cumsum_k = None
9191

9292
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
93-
b_dv = tl.zeros([BT, BK], dtype=tl.float32)
93+
b_dv = tl.zeros([BT, BV], dtype=tl.float32)
9494

9595
last_chunk_start = tl.floor(i_t*BT / S).to(tl.int32) * S
9696
idx_j = (tl.floor(i_t * BT / S).to(tl.int32) + 1).to(tl.int32)
@@ -127,7 +127,7 @@ def parallel_path_bwd_dkv_kernel(
127127
tl.store(p_dk, b_dk.to(dk.dtype.element_ty), boundary_check=(0, 1))
128128
mask = i_t * BT + tl.arange(0, BT) < T
129129
tl.atomic_add(
130-
dv + (i_t * BT + tl.arange(0, BT))[:, None] * HQ * K + tl.arange(0, BK)[None, :],
130+
dv + (i_t * BT + tl.arange(0, BT))[:, None] * HQ * V + tl.arange(0, BV)[None, :],
131131
b_dv,
132132
mask=mask[:, None],
133133
sem='relaxed',

fla/ops/path_attn/parallel_path_bwd_inter_dqh.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ def parallel_path_bwd_dq_kernel(
118118
b_A = b_A + b_g_cumsum_q[:, None] - b_g_cumsum_k[None, :]
119119
b_A = exp2(b_A * sm_scale - b_l[:, None])
120120
b_A = tl.where(m_t[:, None], b_A, 0)
121-
p_v = tl.make_block_ptr(v, (V, T), (1, V*H), (0, offset), (BK, BS), (0, 1))
121+
p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (offset, 0), (BS, BV), (1, 0))
122122
b_v = tl.load(p_v, boundary_check=(0, 1))
123-
b_dp = tl.dot(b_do, b_v.to(b_do.dtype))
123+
b_dp = tl.dot(b_do, tl.trans(b_v).to(b_do.dtype))
124124
b_dA = (b_dp - b_delta[:, None]) * b_A * scale
125125
b_dq += tl.dot(b_dA.to(b_k.dtype), b_k)
126126
if USE_GATE:

0 commit comments

Comments
 (0)