
Commit 4e15a87

Merge pull request #32218 from mwhittaker:fault_tolerance_docs
PiperOrigin-RevId: 842328162
2 parents eeab9e4 + 3a36656 commit 4e15a87

File tree: 11 files changed, +2756 −0 lines changed
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
os.environ['XLA_FLAGS'] = ' '.join([
    '--xla_gpu_nccl_terminate_on_error=false',
    '--xla_gpu_nccl_async_execution=true',
    '--xla_gpu_nccl_blocking_communicators=false',
])
os.environ['XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE'] = '1'
os.environ['XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'] = '1'

from absl import app
from absl import flags
from collections.abc import Sequence
import jax
import jax.numpy as jnp
import time

_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id")
_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes")


def main(_: Sequence[str]) -> None:
  jax.config.update("jax_enable_recoverability", True)
  jax.distributed.initialize(
      coordinator_address="localhost:9000",
      num_processes=_NUM_PROCESSES.value,
      process_id=_PROCESS_ID.value,
      local_device_ids=[_PROCESS_ID.value],
      heartbeat_timeout_seconds=10,
  )
  print(f'{jax.devices()=}')
  print(f'{jax.local_devices()=}')

  # Don't do this. Use live_devices instead.
  from jax.experimental.multihost_utils import _live_devices
  _live_devices(jax._src.distributed.global_state.client, jax.devices())

  n = jax.device_count()
  jax.set_mesh(jax.make_mesh((n,), ("i",)))
  x = jax.device_put(jnp.arange(n), jax.P("i"))
  while True:
    print(jnp.sum(x))
    time.sleep(1)


if __name__ == "__main__":
  app.run(main)
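
For reference, a minimal way to launch this example locally could look like the sketch below; the filename fault_tolerance_minimal.py is a hypothetical placeholder (the actual file paths are not shown in this diff), while the -i/-n flags and the localhost:9000 coordinator address come from the script above.

# Hypothetical launcher sketch (not part of this commit): starts one copy of
# the example per process on a single machine. The filename is a placeholder;
# -i/-n are the absl flags defined in the script above.
import subprocess
import sys

NUM_PROCESSES = 2  # assumption: two local processes, one device each
procs = [
    subprocess.Popen([sys.executable, "fault_tolerance_minimal.py",
                      "-i", str(i), "-n", str(NUM_PROCESSES)])
    for i in range(NUM_PROCESSES)
]
for p in procs:
  p.wait()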
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
os.environ['XLA_FLAGS'] = '--xla_gpu_nccl_terminate_on_error=false'

from absl import app
from absl import flags
from collections.abc import Sequence
import jax
import jax.numpy as jnp
import time

_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id")
_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes")


def main(_: Sequence[str]) -> None:
  jax.config.update("jax_enable_recoverability", True)
  jax.distributed.initialize(
      coordinator_address="localhost:9000",
      num_processes=_NUM_PROCESSES.value,
      process_id=_PROCESS_ID.value,
      local_device_ids=[_PROCESS_ID.value],
      heartbeat_timeout_seconds=10,
  )
  print(f'{jax.devices()=}')
  print(f'{jax.local_devices()=}')

  n = jax.device_count()
  jax.set_mesh(jax.make_mesh((n,), ("i",)))
  x = jax.device_put(jnp.arange(n), jax.P("i"))
  while True:
    print(jnp.sum(x))
    time.sleep(1)


if __name__ == "__main__":
  app.run(main)
Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
os.environ['XLA_FLAGS'] = ' '.join([
    '--xla_gpu_nccl_terminate_on_error=false',
    '--xla_gpu_nccl_async_execution=true',
    '--xla_gpu_nccl_blocking_communicators=false',
])
os.environ['XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE'] = '1'
os.environ['XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'] = '1'

from absl import app
from absl import flags
from collections.abc import Sequence
from jax.experimental.multihost_utils import live_devices
import jax
import jax.numpy as jnp
import time

_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id")
_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes")


def replicated(x: jax.Array, devices: list[jax.Device]):
  """Return x replicated across the provided devices.

  Note that replicated(x) doesn't actually move any data. It simply creates a
  logically replicated array with x as the local replica.
  """
  n = len(devices)
  mesh = jax.make_mesh((n,), ("i",), devices=devices)
  spec = jax.sharding.PartitionSpec(None)
  sharding = jax.sharding.NamedSharding(mesh, spec)
  shards = [
      jax.device_put(x.addressable_shards[0].data, d) for d in devices
      if d.process_index == jax.process_index()
  ]
  return jax.make_array_from_single_device_arrays(x.shape, sharding, shards)


def sharded(x: jax.Array, devices: list[jax.Device]):
  """Return x sharded across the provided devices.

  Note that sharded(x) doesn't actually move any data. It simply creates a
  logically sharded array. x should have the same shape as the global array.
  """
  n = len(devices)
  mesh = jax.make_mesh((n,), ("i",), devices=devices)
  spec = jax.sharding.PartitionSpec("i")
  sharding = jax.sharding.NamedSharding(mesh, spec)
  m = sharding.addressable_devices_indices_map(x.shape)
  shards = [jax.device_put(x[m[d]], d) for d in jax.local_devices()]
  return jax.make_array_from_single_device_arrays(x.shape, sharding, shards)


def main(_: Sequence[str]) -> None:
  # Parse command line arguments and initialize multi-controller JAX.
  jax.config.update("jax_enable_recoverability", True)
  jax.distributed.initialize(coordinator_address="localhost:8000",
                             process_id=_PROCESS_ID.value,
                             num_processes=_NUM_PROCESSES.value,
                             local_device_ids=[_PROCESS_ID.value],
                             heartbeat_timeout_seconds=10)
  print(f'{jax.devices()=}')
  print(f'{jax.local_devices()=}')

  # Initialize the model's weights.
  keys = iter(jax.random.split(jax.random.key(seed=42), num=3))
  weights = jax.random.normal(next(keys), shape=(1,))

  # We'll learn a trivial linear model: a*x.
  def predict(weights, X):
    return weights * X

  # We'll use mean squared error loss.
  def loss(weights, X, Y):
    return jnp.mean((predict(weights, X) - Y)**2)

  # Initialize the (noisy) training data with a=10.
  X = jax.random.permutation(next(keys), jnp.arange(-300., 300.))
  Y = 10 * X + jax.random.normal(next(keys), X.shape)

  # Hyperparameters.
  loss_and_grad = jax.jit(jax.value_and_grad(loss))
  learning_rate = 1e-6
  device_batch_size = 10

  step = 0
  while True:
    try:
      with live_devices(jax.devices()) as devices:
        print(f'=== Running step {step} with live devices = {devices} ===')

        # Replicate the model weights.
        weights = replicated(weights, devices)

        # Shard the batch.
        batch_size = device_batch_size * len(devices)
        start = (step * batch_size) % len(X)
        stop = start + batch_size
        X_batch = sharded(X[start:stop], devices)
        Y_batch = sharded(Y[start:stop], devices)

        # Compute gradients and update weights.
        l, grad = loss_and_grad(weights, X_batch, Y_batch)
        new_weights = jax.block_until_ready(weights - learning_rate * grad)
    except Exception as e:
      print(f'Step {step} failed: {e}')
    else:
      print(f'Step {step} succeeded: loss = {l}')
      step += 1
      weights = new_weights

    time.sleep(1)


if __name__ == "__main__":
  app.run(main)
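
Distilled from the training loop above, the core recovery pattern is: agree on the set of currently live devices, rebuild shardings against those devices, run the step, and treat any exception as a recoverable failure to retry on the next iteration. A minimal sketch of that skeleton (not part of the commit; step_fn is a hypothetical callback) looks like:

# Sketch of the generic recovery loop used in the training example above.
# step_fn is a hypothetical callback that re-shards state onto `devices`
# and runs one training step.
import time
import jax
from jax.experimental.multihost_utils import live_devices


def run_forever(step_fn):
  step = 0
  while True:
    try:
      # Agree with the other processes on which devices are currently alive.
      with live_devices(jax.devices()) as devices:
        step_fn(step, devices)
    except Exception as e:  # A device or process failure surfaces here.
      print(f'Step {step} failed: {e}')
    else:
      step += 1
    time.sleep(1)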
