Skip to content

Commit 3a28e93

Browse files
apaszke authored and Google-ML-Automation committed
[Pallas:MGPU] Add a lowering rule for lax.clamp
PiperOrigin-RevId: 842643534
1 parent d9a7388 commit 3a28e93

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2217,6 +2217,14 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y):
22172217
return res
22182218

22192219

2220+
@register_lowering_rule(lax.clamp_p, mgpu.LoweringSemantics.Lane)
@register_lowering_rule(lax.clamp_p, mgpu.LoweringSemantics.Warpgroup)
def _clamp_lowering_rule(ctx: LoweringRuleContext, l, x, u):
  # clamp_p has no dedicated lowering here; it is expressed as a max against
  # the lower bound `l` followed by a min against the upper bound `u`, and
  # that composition is handed to `_lower_fun`. The rule is registered for
  # both the Lane and Warpgroup lowering semantics.
  def clamp_via_min_max(lower, operand, upper):
    return lax.min(lax.max(operand, lower), upper)

  lowered = _lower_fun(clamp_via_min_max, multiple_results=False)
  return lowered(ctx, l, x, u)
2226+
2227+
22202228
@register_lowering_rule(lax.square_p, mgpu.LoweringSemantics.Lane)
22212229
@register_lowering_rule(lax.square_p, mgpu.LoweringSemantics.Warpgroup)
22222230
def _square_lowering_rule(ctx: LoweringRuleContext, x):

tests/pallas/ops_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,29 @@ def kernel(x_ref, o_ref):
10171017
expected = lax.is_finite(x)
10181018
self.assertArraysEqual(out, expected)
10191019

1020+
@parameterized.parameters(jnp.float32, jnp.bfloat16, jnp.int32, jnp.int16)
def test_clamp(self, dtype):
  # Compares lax.clamp evaluated inside a pallas_call kernel against
  # lax.clamp evaluated directly on the same inputs.
  if dtype == jnp.int16 and jtu.test_device_matches(["tpu"]):
    self.skipTest("int16 is not supported on TPU")

  shape = (8, 128)
  key_a, key_b, key_x = random.split(jax.random.key(0), num=3)

  # Floating dtypes draw from a normal distribution; integer dtypes draw
  # uniformly from [-100, 100).
  if jnp.issubdtype(dtype, jnp.floating):
    def sample(key):
      return random.normal(key, shape, dtype=dtype)
  else:
    def sample(key):
      return random.randint(key, shape, -100, 100, dtype=dtype)

  bound_a = sample(key_a)
  bound_b = sample(key_b)
  x = sample(key_x)

  # Order the two random bounds elementwise so that lo <= hi everywhere.
  lo = jnp.minimum(bound_a, bound_b)
  hi = jnp.maximum(bound_a, bound_b)

  @functools.partial(
      self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, dtype),
  )
  def kernel(lo_ref, x_ref, hi_ref, o_ref):
    o_ref[...] = lax.clamp(lo_ref[...], x_ref[...], hi_ref[...])

  actual = kernel(lo, x, hi)
  expected = lax.clamp(lo, x, hi)
  np.testing.assert_array_equal(actual, expected)
1042+
10201043
@parameterized.named_parameters(
10211044
(dtype.__name__, dtype)
10221045
for dtype in (jnp.float32, jnp.float16, jnp.bfloat16)

0 commit comments

Comments
 (0)