6161from jax ._src .lax import utils as lax_utils
6262from jax ._src .mesh import get_abstract_mesh , get_concrete_mesh
6363from jax ._src .lax .utils import (
64- input_dtype , dtype_to_string , standard_abstract_eval ,
65- standard_multi_result_abstract_eval , standard_primitive )
64+ input_dtype , dtype_to_string , standard_multi_result_abstract_eval ,
65+ standard_primitive )
6666from jax ._src .lib .mlir import ir
6767from jax ._src .lib .mlir .dialects import chlo
6868from jax ._src .lib .mlir .dialects import hlo
@@ -4972,7 +4972,11 @@ def _convert_elt_type_pp_rule(eqn, context, settings):
49724972 del params ['sharding' ] # don't show trivial case
49734973 return core ._pp_eqn (eqn .replace (params = params ), context , settings )
49744974
4975- convert_element_type_p = Primitive ('convert_element_type' )
4975+ convert_element_type_p = standard_primitive (
4976+ _convert_element_type_shape_rule , _convert_element_type_dtype_rule ,
4977+ 'convert_element_type' , weak_type_rule = _convert_element_type_weak_type_rule ,
4978+ sharding_rule = _convert_element_type_sharding_rule ,
4979+ vma_rule = partial (core .standard_vma_rule , 'convert_element_type' ))
49764980
49774981# TODO(dougalm): I'm overriding bind_with_trace here because that's the closest thing to
49784982# the old "custom bind" but it might not be the best way to do this.
@@ -4986,13 +4990,6 @@ def _convert_element_type_bind_with_trace(trace, args, params):
49864990convert_element_type_p .def_bind_with_trace (_convert_element_type_bind_with_trace )
49874991
49884992convert_element_type_p .def_impl (partial (dispatch .apply_primitive , convert_element_type_p ))
4989- convert_element_type_p .def_abstract_eval (
4990- partial (standard_abstract_eval , convert_element_type_p ,
4991- _convert_element_type_shape_rule , _convert_element_type_dtype_rule ,
4992- _convert_element_type_weak_type_rule ,
4993- _convert_element_type_sharding_rule ,
4994- partial (core .standard_vma_rule , convert_element_type_p .name ),
4995- None , None ))
49964993ad .defjvp2 (convert_element_type_p , _convert_element_type_jvp_rule )
49974994ad .primitive_transposes [convert_element_type_p ] = _convert_element_type_transpose_rule
49984995
0 commit comments