Skip to content

Commit 32d830f

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Make sure all args passed to lax.sort have the same sharding just like shapes.
PiperOrigin-RevId: 841855619
1 parent 599eb7b commit 32d830f

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

jax/_src/lax/lax.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8184,12 +8184,17 @@ def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
81848184
}
81858185

81868186

8187-
def _sort_abstract_eval(*args, **kwargs):
8188-
args = tuple(args)
8189-
if any(arg.shape != args[0].shape for arg in args[1:]):
8190-
shapes = " ".join(str(a.shape) for a in args)
8187+
def _sort_abstract_eval(*avals, **kwargs):
8188+
avals = tuple(avals)
8189+
if any(arg.shape != avals[0].shape for arg in avals[1:]):
8190+
shapes = " ".join(str(a.shape) for a in avals)
81918191
raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}")
8192-
return args
8192+
non_empty_s = [a.sharding for a in avals if not a.sharding.mesh.empty]
8193+
if any(s != non_empty_s[0] for s in non_empty_s[1:]):
8194+
shardings = " ".join(str(s) for s in non_empty_s)
8195+
raise core.ShardingTypeError(
8196+
f'Arguments to sort must have equal shardings, got: {shardings}')
8197+
return avals
81938198

81948199

81958200
def _canonicalize_float_for_sort(x):
@@ -8287,7 +8292,9 @@ def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys
82878292
for arg, bdim in zip(batched_args, batch_dims):
82888293
if bdim is None:
82898294
dims = np.delete(np.arange(prototype_arg.ndim), new_bdim)
8290-
new_args.append(broadcast_in_dim(arg, prototype_arg.shape, dims))
8295+
new_args.append(broadcast_in_dim(
8296+
arg, prototype_arg.shape, dims,
8297+
out_sharding=core.typeof(prototype_arg).sharding))
82918298
else:
82928299
new_args.append(batching.moveaxis(arg, bdim, new_bdim))
82938300
new_dimension = dimension + (new_bdim <= dimension)

0 commit comments

Comments
 (0)