
Commit f7e3bdb

yashk2810 authored and mattjj committed
Mention pmap is in maintenance mode and point to shard_map and the migration guide
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 842298387
1 parent 64b16f0 commit f7e3bdb

File tree

3 files changed

+7
-25
lines changed


docs/jax.sharding.rst

Lines changed: 0 additions & 3 deletions
@@ -16,9 +16,6 @@ Classes
 .. autoclass:: NamedSharding
   :members:
   :show-inheritance:
-.. autoclass:: PmapSharding
-  :members:
-  :show-inheritance:
 .. autoclass:: PartitionSpec
   :members:
 .. autoclass:: Mesh
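With ``PmapSharding`` dropped from the documented API surface, ``NamedSharding`` (which stays in the docs above) is the way to describe array placements. A minimal sketch, not taken from the commit, using a single-device mesh so it runs on any host; the variable names are illustrative:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over a single device for portability (a real job
# would use all of jax.devices()).
devices = np.array(jax.devices()[:1])
mesh = Mesh(devices, axis_names=('x',))

# NamedSharding pairs a mesh with a PartitionSpec naming which array
# dimensions are sharded over which mesh axes.
sharding = NamedSharding(mesh, P('x'))

# device_put lays the array out across the mesh according to the spec.
arr = jax.device_put(jnp.arange(8.0), sharding)
print(type(arr.sharding).__name__)
```

This is the sharding type that ``jax.jit`` and ``jax.shard_map`` consume, which is why the pmap-specific ``PmapSharding`` entry no longer needs a docs page.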

jax/_src/api.py

Lines changed: 7 additions & 21 deletions
@@ -1345,7 +1345,13 @@ def pmap(
     donate_argnums: int | Iterable[int] = (),
     global_arg_shapes: tuple[tuple[int, ...], ...] | None = None,
 ) -> Any:
-  """Parallel map with support for collective operations.
+  """Old way of doing parallel map. Use :py:func:`jax.shard_map` instead.
+
+  .. note::
+    While :py:func:`jax.pmap` works, you should probably use
+    :py:func:`jax.shard_map` or ``jax.smap`` instead. shard_map supports more
+    efficient autodiff, and is more composable in the multi-controller setting.
+    See https://docs.jax.dev/en/latest/notebooks/shard_map.html for examples.
 
   .. note::
     :py:func:`pmap` is now implemented in terms of :py:func:`jit` and
@@ -1510,26 +1516,6 @@ def pmap(
   are important particularly in the case of nested :py:func:`pmap` functions,
   where collective operations can operate over distinct axes:
 
-  >>> from functools import partial
-  >>> import jax
-  >>>
-  >>> @partial(pmap, axis_name='rows')
-  ... @partial(pmap, axis_name='cols')
-  ... def normalize(x):
-  ...   row_normed = x / jax.lax.psum(x, 'rows')
-  ...   col_normed = x / jax.lax.psum(x, 'cols')
-  ...   doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
-  ...   return row_normed, col_normed, doubly_normed
-  >>>
-  >>> x = jnp.arange(8.).reshape((4, 2))
-  >>> row_normed, col_normed, doubly_normed = normalize(x)  # doctest: +SKIP
-  >>> print(row_normed.sum(0))  # doctest: +SKIP
-  [ 1.  1.]
-  >>> print(col_normed.sum(1))  # doctest: +SKIP
-  [ 1.  1.  1.  1.]
-  >>> print(doubly_normed.sum((0, 1)))  # doctest: +SKIP
-  1.0
-
   On multi-process platforms, collective operations operate over all devices,
   including those on other processes. For example, assuming the following code
   runs on two processes with 4 XLA devices each:
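For readers migrating off ``pmap``, here is a hedged sketch of a ``shard_map``-based collective in the spirit of the doctest removed above. It uses a single-device mesh so it runs on any host, and imports from ``jax.experimental.shard_map``, which is where ``shard_map`` has lived in most releases (newer JAX also exposes it as ``jax.shard_map``, as the new docstring suggests); this is not code from the commit:

```python
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# A 1-D mesh over a single device; a real run would span all devices.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=('i',))

# in_specs/out_specs replace pmap's implicit leading-axis sharding;
# jax.lax.psum over the named mesh axis works just as it did under pmap.
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
def normalize(x):
  # Sum the local shard, then psum across the 'i' axis for the global total.
  return x / jax.lax.psum(jnp.sum(x), 'i')

out = normalize(jnp.arange(8.0))
print(float(out.sum()))
```

Unlike ``pmap``, the same decorated function composes with ``jax.jit`` and other shardings without reshaping inputs to have a leading device axis, which is the composability point the new docstring note is making.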

jax/_src/sharding_impls.py

Lines changed: 0 additions & 1 deletion
@@ -193,7 +193,6 @@ def pmap_sharding_devices_indices_map(
 
 @use_cpp_class(xc.PmapSharding)
 class PmapSharding(jsharding.Sharding):
-  """Describes a sharding used by :func:`jax.pmap`."""
   devices: np.ndarray
   sharding_spec: sharding_specs.ShardingSpec
   _internal_device_list: xc.DeviceList
