Skip to content

Commit c5c8af2

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Pallas:MGPU] Add support for jnp.cos
PiperOrigin-RevId: 841864222
1 parent 32d830f commit c5c8af2

File tree

4 files changed

+15
-1
lines changed

4 files changed

+15
-1
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2315,6 +2315,18 @@ def _sin_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
23152315
)
23162316
return math_dialect.sin(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
23172317

2318+
@register_lowering_rule(lax.cos_p, mgpu.LoweringSemantics.Lane)
2319+
@register_lowering_rule(lax.cos_p, mgpu.LoweringSemantics.Warpgroup)
2320+
def _cos_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
2321+
if accuracy is not None:
2322+
raise NotImplementedError("Not implemented: accuracy")
2323+
[x_aval] = ctx.avals_in
2324+
if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane:
2325+
return _ensure_fa(x, x_aval.dtype).cos(approx=ctx.module_ctx.approx_math)
2326+
fastmath = (
2327+
arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None
2328+
)
2329+
return math_dialect.cos(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
23182330

23192331
@register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Lane)
23202332
@register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Warpgroup)

jax/experimental/mosaic/gpu/dialect_lowering.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,7 @@ def _unary_op_lowering_rule(
11001100
(mlir_math.ExpOp, fa.FragmentedArray.exp, None),
11011101
(mlir_math.Exp2Op, fa.FragmentedArray.exp2, None),
11021102
(mlir_math.SinOp, fa.FragmentedArray.sin, None),
1103+
(mlir_math.CosOp, fa.FragmentedArray.cos, None),
11031104
(mlir_math.LogOp, fa.FragmentedArray.log, None),
11041105
(mlir_math.TanhOp, fa.FragmentedArray.tanh, None),
11051106
]:

jax/experimental/mosaic/gpu/layout_inference.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,7 @@ def _pointwise_op_constraint_system(
605605
mlir_math.ExpOp,
606606
mlir_math.Exp2Op,
607607
mlir_math.SinOp,
608+
mlir_math.CosOp,
608609
mlir_math.LogOp,
609610
mlir_math.RsqrtOp,
610611
mlir_math.TanhOp,

tests/pallas/ops_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1083,7 +1083,7 @@ def kernel(x_ref, o_ref):
10831083
for fn, dtype in itertools.product(*args)
10841084
)
10851085
def test_elementwise(self, fn, dtype):
1086-
if fn is not jnp.sin or dtype == "float64":
1086+
if fn not in (jnp.sin, jnp.cos) or dtype == "float64":
10871087
self.skip_if_mosaic_gpu()
10881088

10891089
if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:

0 commit comments

Comments
 (0)