@@ -69,14 +69,18 @@ def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh:
6969 m = a .sharding .mesh
7070 return mesh_lib .empty_abstract_mesh if m is None else m
7171
72- def call_unreduced_rule (prim , unreduced_rule , out_s , * avals , ** kwargs ):
72+ def call_unreduced_rule (prim , unreduced_rule , out_s , num_out , * avals , ** kwargs ):
7373 if unreduced_rule is not None :
7474 return unreduced_rule (out_s , * avals , ** kwargs )
7575
7676 if any (a .sharding .spec .unreduced for a in avals ):
7777 raise NotImplementedError (
7878 f'unreduced rule for { prim .name } is not implemented. Please file an'
7979 ' issue at https://github.com/jax-ml/jax/issues' )
80+ if any (s .spec .unreduced for s in ([out_s ] if num_out is None else out_s )):
81+ raise NotImplementedError (
82+ f'unreduced rule for { prim .name } is not implemented. Please file an'
83+ ' issue at https://github.com/jax-ml/jax/issues' )
8084 return out_s
8185
8286def call_sharding_rule (prim , sh_rule , unreduced_rule , num_out , * avals , ** kwargs ):
@@ -85,9 +89,11 @@ def call_sharding_rule(prim, sh_rule, unreduced_rule, num_out, *avals, **kwargs)
8589 if ((cur_mesh .empty or cur_mesh ._are_all_axes_auto_or_manual ) and
8690 (aval_mesh .empty or aval_mesh ._are_all_axes_auto_or_manual )):
8791 aval_mesh = cur_mesh if aval_mesh .empty else aval_mesh
88- s = NamedSharding (aval_mesh , P ())
89- s = call_unreduced_rule (prim , unreduced_rule , s , * avals , ** kwargs )
90- return s if num_out is None else [s ] * num_out
92+ out_s = NamedSharding (aval_mesh , P ())
93+ out_s = out_s if num_out is None else [out_s ] * num_out
94+ out_s = call_unreduced_rule (prim , unreduced_rule , out_s , num_out ,
95+ * avals , ** kwargs )
96+ return out_s
9197 if sh_rule is None :
9298 raise core .ShardingTypeError (
9399 f'sharding rule for { prim .name } is not implemented. Please file an'
@@ -96,7 +102,7 @@ def call_sharding_rule(prim, sh_rule, unreduced_rule, num_out, *avals, **kwargs)
96102 ' mode via: `jax.sharding.auto_axes(fun, out_shardings=...)`' )
97103 out_sharding = sh_rule (* avals , ** kwargs )
98104 out_sharding = call_unreduced_rule (prim , unreduced_rule , out_sharding ,
99- * avals , ** kwargs )
105+ num_out , * avals , ** kwargs )
100106 return out_sharding
101107
102108def call_shape_dtype_sharding_rule (prim , shape_rule , dtype_rule , sharding_rule ,
0 commit comments