From 8c63c1961f6f9dc1bc08a0fa8f0c396aadea8b3f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=83=AC=E6=BA=90?=
Date: Fri, 28 Nov 2025 11:56:12 +0800
Subject: [PATCH 1/3] remove unnecessary tl.constexpr signature for l2_norm kernel

---
 fla/modules/l2norm.py | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/fla/modules/l2norm.py b/fla/modules/l2norm.py
index bd2662de0..bf9c89bac 100644
--- a/fla/modules/l2norm.py
+++ b/fla/modules/l2norm.py
@@ -80,7 +80,7 @@ def l2norm_bwd_kernel1(
         for num_warps in [1, 2, 4, 8, 16]
         for BT in BT_LIST
     ],
-    key=['D', 'NB'],
+    key=['D'],
     **autotune_cache_kwargs,
 )
 @triton.jit
@@ -89,10 +89,9 @@ def l2norm_fwd_kernel(
     y,
     rstd,
     eps,
-    T: tl.constexpr,
+    T,
     D: tl.constexpr,
     BD: tl.constexpr,
-    NB: tl.constexpr,
     BT: tl.constexpr,
 ):
     i_t = tl.program_id(0)
@@ -114,7 +113,7 @@ def l2norm_fwd_kernel(
         for num_warps in [1, 2, 4, 8, 16]
         for BT in BT_LIST
     ],
-    key=['D', 'NB'],
+    key=['D'],
     **autotune_cache_kwargs,
 )
 @triton.jit
@@ -124,10 +123,9 @@ def l2norm_bwd_kernel(
     dy,
     dx,
     eps,
-    T: tl.constexpr,
+    T,
     D: tl.constexpr,
     BD: tl.constexpr,
-    NB: tl.constexpr,
     BT: tl.constexpr,
 ):
     i_t = tl.program_id(0)
@@ -165,7 +163,6 @@ def l2norm_fwd(
     rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
 
     if D <= 512:
-        NB = triton.cdiv(T, 2048)
         def grid(meta): return (triton.cdiv(T, meta['BT']), )
         l2norm_fwd_kernel[grid](
             x=x,
@@ -175,7 +172,6 @@ def grid(meta): return (triton.cdiv(T, meta['BT']), )
             T=T,
             D=D,
             BD=BD,
-            NB=NB,
         )
     else:
         l2norm_fwd_kernel1[(T,)](
@@ -209,7 +205,6 @@ def l2norm_bwd(
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
 
     if D <= 512:
-        NB = triton.cdiv(T, 2048)
         def grid(meta): return (triton.cdiv(T, meta['BT']), )
         l2norm_bwd_kernel[grid](
             y=y,
@@ -220,7 +215,6 @@ def grid(meta): return (triton.cdiv(T, meta['BT']), )
             T=T,
             D=D,
             BD=BD,
-            NB=NB,
         )
     else:
         l2norm_bwd_kernel1[(T,)](

From 1165df27f495f5c377afb3e7999d542dad975b04 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=83=AC=E6=BA=90?=
Date: Mon, 1 Dec 2025 13:42:07 +0800
Subject: [PATCH 2/3] restore NB as autotune key

---
 fla/modules/l2norm.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/fla/modules/l2norm.py b/fla/modules/l2norm.py
index bf9c89bac..11fc3b83e 100644
--- a/fla/modules/l2norm.py
+++ b/fla/modules/l2norm.py
@@ -80,7 +80,7 @@ def l2norm_bwd_kernel1(
         for num_warps in [1, 2, 4, 8, 16]
         for BT in BT_LIST
     ],
-    key=['D'],
+    key=['D', 'NB'],
     **autotune_cache_kwargs,
 )
 @triton.jit
@@ -92,6 +92,7 @@ def l2norm_fwd_kernel(
     T,
     D: tl.constexpr,
     BD: tl.constexpr,
+    NB: tl.constexpr,
     BT: tl.constexpr,
 ):
     i_t = tl.program_id(0)
@@ -113,7 +114,7 @@ def l2norm_fwd_kernel(
         for num_warps in [1, 2, 4, 8, 16]
         for BT in BT_LIST
     ],
-    key=['D'],
+    key=['D', 'NB'],
     **autotune_cache_kwargs,
 )
 @triton.jit
@@ -124,6 +126,7 @@ def l2norm_bwd_kernel(
     T,
     D: tl.constexpr,
     BD: tl.constexpr,
+    NB: tl.constexpr,
     BT: tl.constexpr,
 ):
     i_t = tl.program_id(0)
@@ -163,6 +165,7 @@ def l2norm_fwd(
     rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
 
     if D <= 512:
+        NB = triton.cdiv(T, 2048)
         def grid(meta): return (triton.cdiv(T, meta['BT']), )
         l2norm_fwd_kernel[grid](
             x=x,
@@ -172,6 +175,7 @@ def grid(meta): return (triton.cdiv(T, meta['BT']), )
             T=T,
             D=D,
             BD=BD,
+            NB=NB,
         )
     else:
         l2norm_fwd_kernel1[(T,)](
@@ -205,6 +209,7 @@ def l2norm_bwd(
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
 
     if D <= 512:
+        NB = triton.cdiv(T, 2048)
         def grid(meta): return (triton.cdiv(T, meta['BT']), )
         l2norm_bwd_kernel[grid](
             y=y,
@@ -215,6 +220,7 @@ def grid(meta): return (triton.cdiv(T, meta['BT']), )
             T=T,
             D=D,
             BD=BD,
+            NB=NB,
         )
     else:
         l2norm_bwd_kernel1[(T,)](
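Background on why NB returns in PATCH 2/3: triton.autotune re-runs its benchmark sweep whenever the values bound to its key arguments change. Keying on D alone lets a block size tuned for one sequence length be silently reused for very different ones, while keying on the raw length T would retune for every new length. NB = triton.cdiv(T, 2048) buckets the length, so retuning (and, since NB is also tl.constexpr, recompilation) happens only when T crosses a 2048-row boundary; that appears to be the motivation here. Below is a minimal sketch of the same pattern with a toy kernel and hypothetical names, not the fla kernels themselves:

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for num_warps in [1, 2]
        for BT in [16, 32]
    ],
    # NB in the key: calls whose T falls into the same 2048-row bucket share
    # one tuned config; a new bucket triggers a fresh autotune sweep.
    key=['NB'],
)
@triton.jit
def scale_kernel(x, y, scale, T, NB: tl.constexpr, BT: tl.constexpr):
    o_t = tl.program_id(0) * BT + tl.arange(0, BT)
    m_t = o_t < T
    b_x = tl.load(x + o_t, mask=m_t)
    tl.store(y + o_t, b_x * scale, mask=m_t)


x = torch.randn(5000, device='cuda')
y = torch.empty_like(x)
# T=5000 and T=5500 both give NB=3, so a second call at T=5500 would reuse
# the config tuned for this call instead of benchmarking again.
scale_kernel[lambda meta: (triton.cdiv(x.numel(), meta['BT']),)](
    x, y, 2.0, T=x.numel(), NB=triton.cdiv(x.numel(), 2048),
)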
From 0b941793c91eb5d02fc0a13d1610201877e77621 Mon Sep 17 00:00:00 2001
From: Yu Zhang
Date: Mon, 1 Dec 2025 14:04:28 +0800
Subject: [PATCH 3/3] Add do_not_specialize option to Triton kernels

---
 fla/modules/l2norm.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/fla/modules/l2norm.py b/fla/modules/l2norm.py
index 11fc3b83e..3e2a77bfb 100644
--- a/fla/modules/l2norm.py
+++ b/fla/modules/l2norm.py
@@ -83,7 +83,7 @@ def l2norm_bwd_kernel1(
     key=['D', 'NB'],
     **autotune_cache_kwargs,
 )
-@triton.jit
+@triton.jit(do_not_specialize=['T'])
 def l2norm_fwd_kernel(
     x,
     y,
@@ -117,7 +117,7 @@ def l2norm_fwd_kernel(
     key=['D', 'NB'],
     **autotune_cache_kwargs,
 )
-@triton.jit
+@triton.jit(do_not_specialize=['T'])
 def l2norm_bwd_kernel(
     y,
     rstd,
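Background on PATCH 3/3: by default the Triton JIT specializes integer arguments, compiling distinct kernel variants for values with special properties (equal to 1, divisible by 16), so a stream of varying sequence lengths T can keep triggering recompilation even after T stops being tl.constexpr. do_not_specialize=['T'] tells the JIT to treat T as an ordinary runtime value and keep a single compiled variant. A minimal sketch with a toy kernel and hypothetical names, not the fla code:

import torch
import triton
import triton.language as tl


@triton.jit(do_not_specialize=['T'])
def fill_kernel(x, val, T, BT: tl.constexpr):
    o_t = tl.program_id(0) * BT + tl.arange(0, BT)
    tl.store(x + o_t, val, mask=o_t < T)


x = torch.empty(4096, device='cuda')
# Without do_not_specialize, T=4096 (divisible by 16), T=4095 (no special
# property) and T=1 (the equal-to-1 specialization) would each JIT a
# separate variant; with it, all three launches hit the same compiled kernel.
for T in (4096, 4095, 1):
    fill_kernel[(triton.cdiv(T, 256),)](x, 1.0, T, BT=256)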