@@ -8184,12 +8184,17 @@ def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
 }
 
 
-def _sort_abstract_eval(*args, **kwargs):
-  args = tuple(args)
-  if any(arg.shape != args[0].shape for arg in args[1:]):
-    shapes = " ".join(str(a.shape) for a in args)
+def _sort_abstract_eval(*avals, **kwargs):
+  avals = tuple(avals)
+  if any(arg.shape != avals[0].shape for arg in avals[1:]):
+    shapes = " ".join(str(a.shape) for a in avals)
     raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}")
-  return args
+  non_empty_s = [a.sharding for a in avals if not a.sharding.mesh.empty]
+  if any(s != non_empty_s[0] for s in non_empty_s[1:]):
+    shardings = " ".join(str(s) for s in non_empty_s)
+    raise core.ShardingTypeError(
+        f'Arguments to sort must have equal shardings, got: {shardings}')
+  return avals
 
 
 def _canonicalize_float_for_sort(x):
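This first hunk makes sort's abstract evaluation reject operands whose explicit shardings disagree (shardings on an empty mesh are skipped). A minimal sketch of the behavior this enables, assuming JAX's explicit-sharding APIs (jax.make_mesh, jax.sharding.use_mesh, AxisType.Explicit) and a machine with at least two devices; the mesh axis name "x" and the array values are illustrative only:

    # Sketch: sort operands with equal shapes but unequal shardings.
    import jax
    import jax.numpy as jnp
    from jax.sharding import AxisType, PartitionSpec as P

    mesh = jax.make_mesh((2,), ("x",), axis_types=(AxisType.Explicit,))
    with jax.sharding.use_mesh(mesh):
      keys = jax.device_put(jnp.arange(8.0), P("x"))  # sharded over "x"
      vals = jax.device_put(jnp.arange(8.0), P())     # replicated
      # Equal shapes, unequal shardings: with this change, abstract
      # evaluation of the sort primitive should raise
      # core.ShardingTypeError at trace time.
      jax.lax.sort((keys, vals), num_keys=1)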
@@ -8287,7 +8292,9 @@ def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys
   for arg, bdim in zip(batched_args, batch_dims):
     if bdim is None:
       dims = np.delete(np.arange(prototype_arg.ndim), new_bdim)
-      new_args.append(broadcast_in_dim(arg, prototype_arg.shape, dims))
+      new_args.append(broadcast_in_dim(
+          arg, prototype_arg.shape, dims,
+          out_sharding=core.typeof(prototype_arg).sharding))
     else:
       new_args.append(batching.moveaxis(arg, bdim, new_bdim))
   new_dimension = dimension + (new_bdim <= dimension)
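The second hunk covers the vmap path: an operand with no batch dimension is broadcast to the batched prototype's shape, and now also adopts the prototype's sharding via out_sharding, so the equal-shardings check above still passes. A small sketch of the code path this touches (no mesh required; the values are illustrative):

    # Under vmap, `vals` is closed over, so the batch rule sees it with
    # batch dim None and broadcasts it to the batched prototype's shape --
    # and, with this change, to its sharding as well.
    import jax
    import jax.numpy as jnp

    keys = jnp.array([[3.0, 1.0, 2.0],
                      [6.0, 5.0, 4.0]])   # batched along axis 0
    vals = jnp.array([10.0, 20.0, 30.0])  # unbatched, shared across batch

    sorted_keys, sorted_vals = jax.vmap(
        lambda k: jax.lax.sort((k, vals), num_keys=1))(keys)
    # sorted_vals[0] == [20., 30., 10.]  (vals permuted by row 0's key order)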