Skip to content

Commit 128b4e9

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Pallas:MGPU] Add support for jnp.sin
PiperOrigin-RevId: 841730048
1 parent 052ce57 commit 128b4e9

File tree

4 files changed

+17
-1
lines changed

4 files changed

+17
-1
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2302,6 +2302,19 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
23022302
)
23032303
return math_dialect.exp2(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
23042304

2305+
@register_lowering_rule(lax.sin_p, mgpu.LoweringSemantics.Lane)
2306+
@register_lowering_rule(lax.sin_p, mgpu.LoweringSemantics.Warpgroup)
2307+
def _sin_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
2308+
if accuracy is not None:
2309+
raise NotImplementedError("Not implemented: accuracy")
2310+
[x_aval] = ctx.avals_in
2311+
if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane:
2312+
return _ensure_fa(x, x_aval.dtype).sin(approx=ctx.module_ctx.approx_math)
2313+
fastmath = (
2314+
arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None
2315+
)
2316+
return math_dialect.sin(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
2317+
23052318

23062319
@register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Lane)
23072320
@register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Warpgroup)

jax/experimental/mosaic/gpu/dialect_lowering.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,6 +1099,7 @@ def _unary_op_lowering_rule(
10991099
(mlir_math.RsqrtOp, fa.FragmentedArray.rsqrt, None),
11001100
(mlir_math.ExpOp, fa.FragmentedArray.exp, None),
11011101
(mlir_math.Exp2Op, fa.FragmentedArray.exp2, None),
1102+
(mlir_math.SinOp, fa.FragmentedArray.sin, None),
11021103
(mlir_math.LogOp, fa.FragmentedArray.log, None),
11031104
(mlir_math.TanhOp, fa.FragmentedArray.tanh, None),
11041105
]:

jax/experimental/mosaic/gpu/layout_inference.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,7 @@ def _pointwise_op_constraint_system(
604604
arith.XOrIOp,
605605
mlir_math.ExpOp,
606606
mlir_math.Exp2Op,
607+
mlir_math.SinOp,
607608
mlir_math.LogOp,
608609
mlir_math.RsqrtOp,
609610
mlir_math.TanhOp,

tests/pallas/ops_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1083,7 +1083,8 @@ def kernel(x_ref, o_ref):
10831083
for fn, dtype in itertools.product(*args)
10841084
)
10851085
def test_elementwise(self, fn, dtype):
1086-
self.skip_if_mosaic_gpu()
1086+
if fn is not jnp.sin or dtype == "float64":
1087+
self.skip_if_mosaic_gpu()
10871088

10881089
if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
10891090
self.skipTest("64-bit types require x64_enabled")

0 commit comments

Comments
 (0)