Skip to content

Commit b6aa12a

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Make convert_element_type_p a standard primitive and remove the usage of utils.standard_abstract_eval.
PiperOrigin-RevId: 815067054
1 parent ed07ee7 commit b6aa12a

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

jax/_src/lax/lax.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@
6161
from jax._src.lax import utils as lax_utils
6262
from jax._src.mesh import get_abstract_mesh, get_concrete_mesh
6363
from jax._src.lax.utils import (
64-
input_dtype, dtype_to_string, standard_abstract_eval,
65-
standard_multi_result_abstract_eval, standard_primitive)
64+
input_dtype, dtype_to_string, standard_multi_result_abstract_eval,
65+
standard_primitive)
6666
from jax._src.lib.mlir import ir
6767
from jax._src.lib.mlir.dialects import chlo
6868
from jax._src.lib.mlir.dialects import hlo
@@ -4972,7 +4972,11 @@ def _convert_elt_type_pp_rule(eqn, context, settings):
49724972
del params['sharding'] # don't show trivial case
49734973
return core._pp_eqn(eqn.replace(params=params), context, settings)
49744974

4975-
convert_element_type_p = Primitive('convert_element_type')
4975+
convert_element_type_p = standard_primitive(
4976+
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
4977+
'convert_element_type', weak_type_rule=_convert_element_type_weak_type_rule,
4978+
sharding_rule=_convert_element_type_sharding_rule,
4979+
vma_rule=partial(core.standard_vma_rule, 'convert_element_type'))
49764980

49774981
# TODO(dougalm): I'm overriding bind_with_trace here because that's the closest thing to
49784982
# the old "custom bind" but it might not be the best way to do this.
@@ -4986,13 +4990,6 @@ def _convert_element_type_bind_with_trace(trace, args, params):
49864990
convert_element_type_p.def_bind_with_trace(_convert_element_type_bind_with_trace)
49874991

49884992
convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_element_type_p))
4989-
convert_element_type_p.def_abstract_eval(
4990-
partial(standard_abstract_eval, convert_element_type_p,
4991-
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
4992-
_convert_element_type_weak_type_rule,
4993-
_convert_element_type_sharding_rule,
4994-
partial(core.standard_vma_rule, convert_element_type_p.name),
4995-
None, None))
49964993
ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule)
49974994
ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule
49984995

0 commit comments

Comments
 (0)