# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
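# Configure the GPU runtime for fault tolerance before JAX creates its
# backends: don't terminate the process on an NCCL error, run collectives
# asynchronously on non-blocking communicators, abort in-flight collectives
# when a participant fails, and use the TFRT GPU client.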
os.environ['XLA_FLAGS'] = ' '.join([
    '--xla_gpu_nccl_terminate_on_error=false',
    '--xla_gpu_nccl_async_execution=true',
    '--xla_gpu_nccl_blocking_communicators=false',
])
os.environ['XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE'] = '1'
os.environ['XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'] = '1'

from absl import app
from absl import flags
from collections.abc import Sequence
from jax.experimental.multihost_utils import live_devices
import jax
import jax.numpy as jnp
import time

_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id")
_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes")
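# Example invocation (assuming this file is saved as recoverable_train.py),
# one process per GPU on a single host:
#
#   python recoverable_train.py --i=0 --n=2 &
#   python recoverable_train.py --i=1 --n=2
#
# Killing one process mid-run should let the other keep training on the
# devices that remain live.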

def replicated(x: jax.Array, devices: list[jax.Device]):
  """Return x replicated across the provided devices.

  Note that replicated(x) doesn't actually move any data. It simply creates a
  logically replicated array with x as the local replica.
  """
  n = len(devices)
  mesh = jax.make_mesh((n, ), ("i", ), devices=devices)
  spec = jax.sharding.PartitionSpec(None)
  sharding = jax.sharding.NamedSharding(mesh, spec)
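  # Reuse the buffer this process already holds as the replica for each of its
  # own live devices; nothing is copied between hosts.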
  shards = [
      jax.device_put(x.addressable_shards[0].data, d) for d in devices
      if d.process_index == jax.process_index()
  ]
  return jax.make_array_from_single_device_arrays(x.shape, sharding, shards)


def sharded(x: jax.Array, devices: list[jax.Device]):
  """Return x sharded across the provided devices.

  Note that sharded(x) doesn't actually move any data. It simply creates a
  logically sharded array. x should have the same shape as the global array.
  """
  n = len(devices)
  mesh = jax.make_mesh((n, ), ("i", ), devices=devices)
  spec = jax.sharding.PartitionSpec("i")
  sharding = jax.sharding.NamedSharding(mesh, spec)
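  # Map each locally addressable device to the slice of the global array it
  # owns and place only those slices; remote shards are never materialized
  # here.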
  m = sharding.addressable_devices_indices_map(x.shape)
  shards = [jax.device_put(x[m[d]], d) for d in jax.local_devices()]
  return jax.make_array_from_single_device_arrays(x.shape, sharding, shards)


def main(_: Sequence[str]) -> None:
  # Parse command line arguments and initialize multi-controller JAX.
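  # Enabling recoverability together with a short heartbeat timeout lets the
  # job detect failed processes quickly and keep running on the devices that
  # remain, instead of crashing the whole job.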
  jax.config.update("jax_enable_recoverability", True)
  jax.distributed.initialize(coordinator_address="localhost:8000",
                             process_id=_PROCESS_ID.value,
                             num_processes=_NUM_PROCESSES.value,
                             local_device_ids=[_PROCESS_ID.value],
                             heartbeat_timeout_seconds=10)
  print(f'{jax.devices()=}')
  print(f'{jax.local_devices()=}')

  # Initialize the model's weights.
  keys = iter(jax.random.split(jax.random.key(seed=42), num=3))
  weights = jax.random.normal(next(keys), shape=(1, ))

  # We'll learn a trivial linear model: a*x.
  def predict(weights, X):
    return weights * X

  # We'll use mean squared error loss.
  def loss(weights, X, Y):
    return jnp.mean((predict(weights, X) - Y)**2)

  # Initialize the (noisy) training data with a=10.
  X = jax.random.permutation(next(keys), jnp.arange(-300., 300.))
  Y = 10 * X + jax.random.normal(next(keys), X.shape)

  # Hyperparameters.
  loss_and_grad = jax.jit(jax.value_and_grad(loss))
  learning_rate = 1e-6
  device_batch_size = 10

  step = 0
  while True:
    try:
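      # Run this step only on the devices that are currently live. If a
      # device fails mid-step, the step raises an exception and is retried
      # with whatever devices remain.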
      with live_devices(jax.devices()) as devices:
        print(f'=== Running step {step} with live devices = {devices} ===')

        # Replicate the model weights.
        weights = replicated(weights, devices)

        # Shard the batch.
        batch_size = device_batch_size * len(devices)
        start = (step * batch_size) % len(X)
        stop = start + batch_size
        X_batch = sharded(X[start:stop], devices)
        Y_batch = sharded(Y[start:stop], devices)

        # Compute gradients and update weights.
        l, grad = loss_and_grad(weights, X_batch, Y_batch)
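        # Block here so that any asynchronous failure from this step surfaces
        # inside the try block rather than at a later use of the result.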
        new_weights = jax.block_until_ready(weights - learning_rate * grad)
    except Exception as e:
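      # The step failed, e.g. because a participant died. Keep the old weights
      # and step counter; the next iteration retries with the remaining
      # devices.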
      print(f'Step {step} failed: {e}')
    else:
      print(f'Step {step} succeeded: loss = {l}')
      step += 1
      weights = new_weights

    time.sleep(1)


if __name__ == "__main__":
  app.run(main)