We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 2b6bcb5 + af6970e commit a7bce47Copy full SHA for a7bce47
jax/_src/lax/parallel.py
@@ -983,8 +983,10 @@ def source_to_front(group):
983
return [group[source]] + list(group[:source]) + list(group[source + 1:])
984
replica_groups = [source_to_front(group) for group in replica_groups]
985
channel = ctx.module_context.new_channel()
986
+ channel_handle = hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE)
987
return hlo.CollectiveBroadcastOp(
- x, replica_groups=_replica_groups_hlo(replica_groups)).results
988
+ x, replica_groups=_replica_groups_hlo(replica_groups),
989
+ channel_handle=channel_handle).results
990
991
pbroadcast_p = core.AxisPrimitive('pbroadcast')
992
pbroadcast_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
0 commit comments