Skip to content

Commit af6970e

Browse files
Pipe channel handle
1 parent 08b1cef commit af6970e

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

jax/_src/lax/parallel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -983,8 +983,10 @@ def source_to_front(group):
983983
return [group[source]] + list(group[:source]) + list(group[source + 1:])
984984
replica_groups = [source_to_front(group) for group in replica_groups]
985985
channel = ctx.module_context.new_channel()
986+
channel_handle = hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE)
986987
return hlo.CollectiveBroadcastOp(
987-
x, replica_groups=_replica_groups_hlo(replica_groups)).results
988+
x, replica_groups=_replica_groups_hlo(replica_groups),
989+
channel_handle=channel_handle).results
988990

989991
pbroadcast_p = core.AxisPrimitive('pbroadcast')
990992
pbroadcast_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))

0 commit comments

Comments
 (0)