@@ -848,6 +848,23 @@ def f(x): # x: f32[b]
848848 a = exp2 .in_avals [0 ].shape [0 ]
849849 self .assertEqual (exp2 .out_avals [0 ].shape , output_shape (a ))
850850
851+ def test_poly_call_pmap (self ):
852+ if len (jax .devices ()) < 2 :
853+ self .skipTest ("Need at least 2 devices" )
854+ def f (x ): # x: f32[a, 4]
855+ return x + jnp .arange (x .shape [0 ], dtype = x .dtype ).reshape ((x .shape [0 ], 1 ))
856+
857+ a , = export .symbolic_shape ("a" )
858+ exp = export .export (f )(
859+ jax .ShapeDtypeStruct ((a , 4 ), np .float32 ))
860+ f_exp = export .call_exported (exp )
861+ x_jit = np .arange (12 , dtype = np .float32 ).reshape ((3 , 4 ))
862+ res_jit = jax .jit (f_exp )(x_jit )
863+ self .assertAllClose (res_jit , f (x_jit ))
864+ x_pmap = np .arange (24 , dtype = np .float32 ).reshape ((2 , 3 , 4 ))
865+ res_pmap = jax .pmap (f_exp )(x_pmap )
866+ self .assertAllClose (res_pmap , jnp .stack ([f (x ) for x in x_pmap ]))
867+
851868 def test_with_sharding (self ):
852869 nr_devices = 2
853870 if len (jax .devices ()) < nr_devices :
@@ -1204,7 +1221,6 @@ def f(x):
12041221 g_rev = jax .grad (export .call (exp ))(input )
12051222 self .assertAllClose (g , g_rev )
12061223
1207-
12081224 def test_multi_platform (self ):
12091225 x = np .arange (8 , dtype = np .float32 )
12101226 exp = get_exported (_testing_multi_platform_func ,
0 commit comments