Skip to content

Commit 41e4c25

Browse files
committed
[ROCm] Add float8_e4m3fnuz and float8_e5m2fnuz support for Rocm
1 parent 3f1b059 commit 41e4c25

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

jax/_src/lax/lax.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2913,8 +2913,9 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
29132913
precision, preferred_element_type: np.dtype | None,
29142914
platform: str = "default"):
29152915
def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
2916-
fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2)
2917-
return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes
2916+
fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
2917+
dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
2918+
return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes
29182919
del preferred_element_type # Implied by the output aval
29192920
lhs_aval, rhs_aval = ctx.avals_in
29202921
lhs_dtype, rhs_dtype = lhs_aval.dtype, rhs_aval.dtype

tests/lax_test.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1079,11 +1079,18 @@ def testDotPreferredElement(self, lhs_shape, rhs_shape, dtype,
10791079
for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]],
10801080
[dict(dtype_lhs=dtype_lhs, dtype_rhs=dtype_rhs)
10811081
for dtype_lhs, dtype_rhs in [(dtypes.float8_e4m3fn, dtypes.float8_e5m2),
1082-
(dtypes.float8_e5m2, dtypes.float8_e4m3fn)]],
1082+
(dtypes.float8_e5m2, dtypes.float8_e4m3fn),
1083+
(dtypes.float8_e4m3fnuz, dtypes.float8_e5m2fnuz),
1084+
(dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)]],
10831085
)
10841086
def test_mixed_fp8_dot_general(self, lhs_shape, rhs_shape, dtype_lhs, dtype_rhs):
10851087
if jtu.test_device_matches(["tpu"]):
10861088
raise SkipTest("Mixed fp8 precision matmul is not yet supported on TPU")
1089+
if not jtu.is_device_rocm() and (
1090+
dtype_lhs in [dtypes.float8_e4m3fnuz, dtypes.float8_e5m2fnuz] or
1091+
dtype_rhs in [dtypes.float8_e4m3fnuz, dtypes.float8_e5m2fnuz]
1092+
):
1093+
raise SkipTest("float8_e4m3fnuz and float8_e5m2fnuz types are only supported on ROCm")
10871094
rng = jtu.rand_default(self.rng())
10881095
lhs = rng(lhs_shape, dtype=dtype_lhs)
10891096
rhs = rng(rhs_shape, dtype=dtype_rhs)

0 commit comments

Comments
 (0)