@@ -1345,7 +1345,13 @@ def pmap(
     donate_argnums: int | Iterable[int] = (),
     global_arg_shapes: tuple[tuple[int, ...], ...] | None = None,
 ) -> Any:
-  """Parallel map with support for collective operations.
+  """Old way of doing parallel map. Use :py:func:`jax.shard_map` instead.
+
+  .. note::
+    While :py:func:`jax.pmap` works, you should probably use
+    :py:func:`jax.shard_map` or ``jax.smap`` instead. shard_map supports more
+    efficient autodiff, and is more composable in the multi-controller setting.
+    See https://docs.jax.dev/en/latest/notebooks/shard_map.html for examples.
 
   .. note::
     :py:func:`pmap` is now implemented in terms of :py:func:`jit` and
@@ -1510,26 +1516,6 @@ def pmap(
   are important particularly in the case of nested :py:func:`pmap` functions,
   where collective operations can operate over distinct axes:
 
-  >>> from functools import partial
-  >>> import jax
-  >>>
-  >>> @partial(pmap, axis_name='rows')
-  ... @partial(pmap, axis_name='cols')
-  ... def normalize(x):
-  ...   row_normed = x / jax.lax.psum(x, 'rows')
-  ...   col_normed = x / jax.lax.psum(x, 'cols')
-  ...   doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
-  ...   return row_normed, col_normed, doubly_normed
-  >>>
-  >>> x = jnp.arange(8.).reshape((4, 2))
-  >>> row_normed, col_normed, doubly_normed = normalize(x)  # doctest: +SKIP
-  >>> print(row_normed.sum(0))  # doctest: +SKIP
-  [ 1. 1.]
-  >>> print(col_normed.sum(1))  # doctest: +SKIP
-  [ 1. 1. 1. 1.]
-  >>> print(doubly_normed.sum((0, 1)))  # doctest: +SKIP
-  1.0
-
   On multi-process platforms, collective operations operate over all devices,
   including those on other processes. For example, assuming the following code
   runs on two processes with 4 XLA devices each:
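
For reference, the nested-axes example removed above can be expressed with the recommended jax.shard_map API roughly as in the sketch below. This is an illustrative sketch rather than code from the JAX docs: it assumes a recent JAX release that exposes jax.shard_map and jax.make_mesh at the top level (older releases provide jax.experimental.shard_map), and a machine with 8 addressable devices.

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

# Assumption: 8 devices arranged as a 4x2 mesh; adjust the shape to your hardware.
mesh = jax.make_mesh((4, 2), ('rows', 'cols'))

def normalize(x):
  # Inside shard_map, collectives refer to mesh axis names directly.
  row_normed = x / jax.lax.psum(x, 'rows')
  col_normed = x / jax.lax.psum(x, 'cols')
  doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
  return row_normed, col_normed, doubly_normed

# Each device sees a (1, 1) block of the (4, 2) input, mirroring the
# per-device scalars of the nested pmap version.
normalize_sharded = jax.shard_map(
    normalize, mesh=mesh,
    in_specs=P('rows', 'cols'), out_specs=P('rows', 'cols'))

x = jnp.arange(8.).reshape((4, 2))
row_normed, col_normed, doubly_normed = normalize_sharded(x)
print(row_normed.sum(0))          # expected: [1. 1.]
print(col_normed.sum(1))          # expected: [1. 1. 1. 1.]
print(doubly_normed.sum((0, 1)))  # expected: 1.0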