We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4e909f0 commit 1b94fa5Copy full SHA for 1b94fa5
jax/_src/pallas/pallas_call.py
@@ -1372,7 +1372,8 @@ def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue:
1372
" argument of `jax.ShapeDtypeStruct` or set `check_vma=False` on"
1373
" `jax.shard_map`.")
1374
return jax_core.ShapedArray(
1375
- shape=out_shape.shape, dtype=out_shape.dtype, vma=out_shape.vma)
+ shape=out_shape.shape, dtype=out_shape.dtype,
1376
+ sharding=jax_core.get_cur_mesh_sharding(), vma=out_shape.vma)
1377
return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype)
1378
case pallas_core.MemoryRef():
1379
return out_shape.get_array_aval()
0 commit comments