Skip to content

Commit a7bce47

Browse files
author
jax authors
committed
Merge pull request #20705 from chaserileyroberts:chase/pbroadcast_channel_fix
PiperOrigin-RevId: 637986186
2 parents 2b6bcb5 + af6970e commit a7bce47

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

jax/_src/lax/parallel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -983,8 +983,10 @@ def source_to_front(group):
983983
return [group[source]] + list(group[:source]) + list(group[source + 1:])
984984
replica_groups = [source_to_front(group) for group in replica_groups]
985985
channel = ctx.module_context.new_channel()
986+
channel_handle = hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE)
986987
return hlo.CollectiveBroadcastOp(
987-
x, replica_groups=_replica_groups_hlo(replica_groups)).results
988+
x, replica_groups=_replica_groups_hlo(replica_groups),
989+
channel_handle=channel_handle).results
988990

989991
pbroadcast_p = core.AxisPrimitive('pbroadcast')
990992
pbroadcast_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))

0 commit comments

Comments
 (0)