Skip to content

Commit 9c97b57

Browse files
[Pallas] Device ID dict-to-mesh fast path for powers of two
PiperOrigin-RevId: 841957791
1 parent 249b7a7 commit 9c97b57

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

jax/_src/pallas/primitives.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1350,11 +1350,26 @@ def _device_id_dict_to_mesh(mesh_context: pallas_utils.MeshInfo, device_id_dict,
13501350
)
13511351
axes_dimensions = [mesh_axis_sizes[name] for name in axis]
13521352
for axis_index, axis_name in enumerate(axis):
1353-
axis_size = arith.constant(i32, mesh_axis_sizes[axis_name])
1354-
minor_divisor = arith.constant(
1355-
i32, math.prod(axes_dimensions[axis_index + 1 :])
1356-
)
1357-
device_idx = arith.remsi(arith.divsi(idx, minor_divisor), axis_size)
1353+
axis_size = mesh_axis_sizes[axis_name]
1354+
inner_mesh_size = math.prod(axes_dimensions[axis_index + 1 :])
1355+
minor_divisor = arith.constant(i32, inner_mesh_size)
1356+
1357+
# Fast path for powers of 2: use a shift instead of a divide and a mask instead of a remainder.
1358+
if inner_mesh_size & (inner_mesh_size - 1) == 0:
1359+
shift_len = (inner_mesh_size & -inner_mesh_size).bit_length() - 1
1360+
partial_device_idx = arith.shrui(idx, arith.constant(i32, shift_len))
1361+
else:
1362+
partial_device_idx = arith.divsi(idx, minor_divisor)
1363+
1364+
if axis_size & (axis_size - 1) == 0:
1365+
device_idx = arith.andi(
1366+
partial_device_idx,
1367+
arith.constant(i32, mesh_axis_sizes[axis_name] - 1),
1368+
)
1369+
else:
1370+
device_idx = arith.remsi(
1371+
partial_device_idx, arith.constant(i32, axis_size)
1372+
)
13581373
physical_axis_dict[axis_name] = device_idx
13591374
else:
13601375
physical_axis_dict[axis] = idx

0 commit comments

Comments
 (0)