Skip to content

Commit 1b94fa5

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Assign a mesh to ShapedArray on out_shape in pallas_call so that vma's make sense.
PiperOrigin-RevId: 815222246
1 parent 4e909f0 commit 1b94fa5

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

jax/_src/pallas/pallas_call.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1372,7 +1372,8 @@ def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue:
13721372
" argument of `jax.ShapeDtypeStruct` or set `check_vma=False` on"
13731373
" `jax.shard_map`.")
13741374
return jax_core.ShapedArray(
1375-
shape=out_shape.shape, dtype=out_shape.dtype, vma=out_shape.vma)
1375+
shape=out_shape.shape, dtype=out_shape.dtype,
1376+
sharding=jax_core.get_cur_mesh_sharding(), vma=out_shape.vma)
13761377
return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype)
13771378
case pallas_core.MemoryRef():
13781379
return out_shape.get_array_aval()

0 commit comments

Comments
 (0)