Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions fla/modules/l2norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,13 @@ def l2norm_bwd_kernel1(
key=['D', 'NB'],
**autotune_cache_kwargs,
)
@triton.jit
@triton.jit(do_not_specialize=['T'])
def l2norm_fwd_kernel(
x,
y,
rstd,
eps,
T: tl.constexpr,
T,
D: tl.constexpr,
BD: tl.constexpr,
NB: tl.constexpr,
Expand Down Expand Up @@ -117,14 +117,14 @@ def l2norm_fwd_kernel(
key=['D', 'NB'],
**autotune_cache_kwargs,
)
@triton.jit
@triton.jit(do_not_specialize=['T'])
def l2norm_bwd_kernel(
y,
rstd,
dy,
dx,
eps,
T: tl.constexpr,
T,
D: tl.constexpr,
BD: tl.constexpr,
NB: tl.constexpr,
Expand Down
Loading