@@ -1350,11 +1350,26 @@ def _device_id_dict_to_mesh(mesh_context: pallas_utils.MeshInfo, device_id_dict,
13501350 )
13511351 axes_dimensions = [mesh_axis_sizes [name ] for name in axis ]
13521352 for axis_index , axis_name in enumerate (axis ):
1353- axis_size = arith .constant (i32 , mesh_axis_sizes [axis_name ])
1354- minor_divisor = arith .constant (
1355- i32 , math .prod (axes_dimensions [axis_index + 1 :])
1356- )
1357- device_idx = arith .remsi (arith .divsi (idx , minor_divisor ), axis_size )
1353+ axis_size = mesh_axis_sizes [axis_name ]
1354+ inner_mesh_size = math .prod (axes_dimensions [axis_index + 1 :])
1355+ minor_divisor = arith .constant (i32 , inner_mesh_size )
1356+
1357+ # Fast path for powers of 2
1358+ if inner_mesh_size & (inner_mesh_size - 1 ) == 0 :
1359+ shift_len = (inner_mesh_size & - inner_mesh_size ).bit_length () - 1
1360+ partial_device_idx = arith .shrui (idx , arith .constant (i32 , shift_len ))
1361+ else :
1362+ partial_device_idx = arith .divsi (idx , minor_divisor )
1363+
1364+ if axis_size & (axis_size - 1 ) == 0 :
1365+ device_idx = arith .andi (
1366+ partial_device_idx ,
1367+ arith .constant (i32 , mesh_axis_sizes [axis_name ] - 1 ),
1368+ )
1369+ else :
1370+ device_idx = arith .remsi (
1371+ partial_device_idx , arith .constant (i32 , axis_size )
1372+ )
13581373 physical_axis_dict [axis_name ] = device_idx
13591374 else :
13601375 physical_axis_dict [axis ] = idx
0 commit comments