@@ -1079,11 +1079,18 @@ def testDotPreferredElement(self, lhs_shape, rhs_shape, dtype,
10791079 for lhs_shape in [(3 ,), (4 , 3 )] for rhs_shape in [(3 ,), (3 , 6 )]],
10801080 [dict (dtype_lhs = dtype_lhs , dtype_rhs = dtype_rhs )
10811081 for dtype_lhs , dtype_rhs in [(dtypes .float8_e4m3fn , dtypes .float8_e5m2 ),
1082- (dtypes .float8_e5m2 , dtypes .float8_e4m3fn )]],
1082+ (dtypes .float8_e5m2 , dtypes .float8_e4m3fn ),
1083+ (dtypes .float8_e4m3fnuz , dtypes .float8_e5m2fnuz ),
1084+ (dtypes .float8_e5m2fnuz , dtypes .float8_e4m3fnuz )]],
10831085 )
10841086 def test_mixed_fp8_dot_general (self , lhs_shape , rhs_shape , dtype_lhs , dtype_rhs ):
10851087 if jtu .test_device_matches (["tpu" ]):
10861088 raise SkipTest ("Mixed fp8 precision matmul is not yet supported on TPU" )
1089+ if not jtu .is_device_rocm () and (
1090+ dtype_lhs in [dtypes .float8_e4m3fnuz , dtypes .float8_e5m2fnuz ] or
1091+ dtype_rhs in [dtypes .float8_e4m3fnuz , dtypes .float8_e5m2fnuz ]
1092+ ):
1093+ raise SkipTest ("float8_e4m3fnuz and float8_e5m2fnuz types are only supported on ROCm" )
10871094 rng = jtu .rand_default (self .rng ())
10881095 lhs = rng (lhs_shape , dtype = dtype_lhs )
10891096 rhs = rng (rhs_shape , dtype = dtype_rhs )
0 commit comments