-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Adding Tensor_layout for Tensor parallelism for Autosharding #21792
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
buildwithsuhana
wants to merge
14
commits into
keras-team:master
Choose a base branch
from
buildwithsuhana:tensor_parallel
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 2 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
06bb3bb
Adding tensor layout for TP autosharding
buildwithsuhana 41f8025
formatting files
buildwithsuhana e74eab2
Updating the docstring
buildwithsuhana 2cddf39
refactoring the code
buildwithsuhana fee036e
Merge branch 'tensor_parallel' of https://github.com/buildwithsuhana/…
buildwithsuhana 9bed6e4
Merge branch 'keras-team:master' into tensor_parallel
buildwithsuhana 5365f14
fixing test
buildwithsuhana bc4d094
fixing test
buildwithsuhana 4d32e49
adding autoconfig and coordinated_optimizer
buildwithsuhana 119ac15
updating docstrings and code format
buildwithsuhana 7851615
refactored autoconfig to not use recursion
buildwithsuhana 4707c2b
updating docstrings
buildwithsuhana 45aa44c
removing redundancies
buildwithsuhana 8bb39f6
added tests for autoconfig and coordinated optimizer
buildwithsuhana File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| import functools | ||
| import os | ||
|
|
||
| import jax | ||
|
|
@@ -9,6 +10,8 @@ | |
| from keras.src import backend | ||
| from keras.src import testing | ||
| from keras.src.backend.config import is_nnx_enabled | ||
| from keras.src.backend.jax.core import all_gather | ||
| from keras.src.backend.jax.core import all_reduce | ||
|
|
||
| if is_nnx_enabled(): | ||
| from flax import nnx | ||
|
|
@@ -66,3 +69,78 @@ def test_keras_variable_nnx_split_merge_sync(self): | |
| state = jax.tree.map(lambda x: x + 1, state) | ||
| variable2 = nnx.merge(graphdef, state) | ||
| self.assertEqual(variable2._value, variable2.value) | ||
|
|
||
|
|
||
| @pytest.mark.skipif( | ||
| backend.backend() != "jax", | ||
| reason="JAX backend specific test for collective operations.", | ||
| ) | ||
| @pytest.mark.skipif( | ||
| jax.local_device_count() < 2, | ||
| reason="Requires multiple local devices for testing.", | ||
| ) | ||
|
Comment on lines
+78
to
+81
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These tests will never run because we don't have a setup with 2 accelerators right now (we have 1 and 8). You should move these tests to |
||
| class JaxCollectiveOpsTest(testing.TestCase): | ||
| def test_all_reduce_sum(self): | ||
| """Tests the all_reduce operation with the 'sum' reduction.""" | ||
| num_devices = jax.local_device_count() | ||
| local_value = 10.0 | ||
|
|
||
| local_inputs = jax.numpy.array([local_value] * num_devices) | ||
|
|
||
| @functools.partial( | ||
| jax.pmap, axis_name="all", devices=jax.devices("cpu") | ||
| ) | ||
| def reduce_sum_fn(x): | ||
| return all_reduce(x, op="sum", axis_name="all") | ||
|
|
||
| result = reduce_sum_fn(local_inputs) | ||
| expected_sum = local_value * num_devices | ||
|
|
||
| self.assertTrue(np.allclose(result, expected_sum)) | ||
| self.assertEqual(result.shape, (num_devices,)) | ||
|
|
||
| def test_all_reduce_mean(self): | ||
| """Tests the all_reduce operation with the 'mean' reduction.""" | ||
| num_devices = jax.local_device_count() | ||
| local_value = 10.0 | ||
|
|
||
| local_inputs = jax.numpy.array([local_value] * num_devices) | ||
|
|
||
| @functools.partial( | ||
| jax.pmap, axis_name="all", devices=jax.devices("cpu") | ||
| ) | ||
| def reduce_mean_fn(x): | ||
| return all_reduce(x, op="mean", axis_name="all") | ||
|
|
||
| result = reduce_mean_fn(local_inputs) | ||
| expected_mean = local_value | ||
|
|
||
| self.assertTrue(np.allclose(result, expected_mean)) | ||
| self.assertEqual(result.shape, (num_devices,)) | ||
|
|
||
| def test_all_gather(self): | ||
| """Tests the all_gather operation.""" | ||
| num_devices = jax.local_device_count() | ||
| local_data = np.arange(5) | ||
|
|
||
| local_inputs = jax.numpy.stack( | ||
| [local_data + (i * 5) for i in range(num_devices)] | ||
| ) | ||
|
|
||
| @functools.partial( | ||
| jax.pmap, axis_name="all", devices=jax.devices("cpu") | ||
| ) | ||
| def gather_fn(x): | ||
| return all_gather(x, axis=0, axis_name="all") | ||
|
|
||
| result_array_on_devices = gather_fn(local_inputs) | ||
|
|
||
| expected_shape = (num_devices, num_devices * local_data.shape[0]) | ||
| self.assertEqual(result_array_on_devices.shape, expected_shape) | ||
|
|
||
| expected_gathered_data = np.arange(num_devices * local_data.shape[0]) | ||
|
|
||
| for i in range(num_devices): | ||
| self.assertTrue( | ||
| np.allclose(result_array_on_devices[i], expected_gathered_data) | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| import collections | ||
|
|
||
| from keras.src import ops | ||
|
|
||
|
|
||
| def split_tensor_for_parallelism(tensor, index, device_count, dim): | ||
| """Calculates a slice of a tensor along a specified dimension for a | ||
| given index. | ||
| This utility is used in tensor parallelism API to distribute a | ||
| tensor across multiple devices. | ||
| Args: | ||
| tensor: The full tensor to be sharded. | ||
| index: The index of the device/shard to return (e.g., 0, 1, 2...). | ||
| device_count: The total number of parallel devices or splits. | ||
| dim: The dimension along which to split the tensor. If -1, the | ||
| last dimension is used. | ||
| Returns: | ||
| A tensor slice corresponding to the given `index`. | ||
| """ | ||
buildwithsuhana marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if dim == -1: | ||
| static_shape = getattr(tensor, "shape", None) | ||
| if static_shape is not None: | ||
| rank = len(static_shape) | ||
| else: | ||
| rank = None | ||
|
|
||
| if rank is not None: | ||
| split_dim = rank - 1 | ||
| else: | ||
| split_dim = ops.ndim(tensor) - 1 | ||
| else: | ||
| split_dim = dim | ||
buildwithsuhana marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| splits = ops.array_split( | ||
| tensor, indices_or_sections=device_count, axis=split_dim | ||
| ) | ||
| return splits[index] | ||
|
|
||
|
|
||
| LayoutMap = collections.namedtuple("LayoutMap", ["state_rules", "output_rules"]) | ||
163 changes: 163 additions & 0 deletions
163
keras/src/distribution/tensor_parallel/tensor_layout_test.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,163 @@ | ||
| from keras.src import ops | ||
| from keras.src import testing | ||
| from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap | ||
| from keras.src.distribution.tensor_parallel.tensor_layout import ( | ||
| split_tensor_for_parallelism, | ||
| ) | ||
|
|
||
|
|
||
| class LayoutTest(testing.TestCase): | ||
| """Test suite for tensor layout actions and mappings.""" | ||
|
|
||
| def test_split_with_even_division(self): | ||
| """Tests splitting a tensor that divides evenly among workers.""" | ||
| device_count = 4 | ||
| dim = 0 | ||
| tensor = ops.reshape(ops.arange(16, dtype="float32"), (8, 2)) | ||
|
|
||
| expected_shard_0 = ops.array([[0.0, 1.0], [2.0, 3.0]]) | ||
| expected_shard_2 = ops.array([[8.0, 9.0], [10.0, 11.0]]) | ||
|
|
||
| shard_0 = split_tensor_for_parallelism( | ||
| tensor, index=0, device_count=device_count, dim=dim | ||
| ) | ||
| shard_2 = split_tensor_for_parallelism( | ||
| tensor, index=2, device_count=device_count, dim=dim | ||
| ) | ||
|
|
||
| self.assertAllClose(shard_0, expected_shard_0) | ||
| self.assertAllClose(shard_2, expected_shard_2) | ||
| self.assertEqual(shard_0.shape, (2, 2)) | ||
|
|
||
| def test_split_with_uneven_division(self): | ||
| """Tests splitting tensor where remainder is distributed correctly.""" | ||
| device_count = 3 | ||
| dim = 0 | ||
| tensor = ops.reshape(ops.arange(10, dtype="float32"), (10, 1)) | ||
|
|
||
| shard_0 = split_tensor_for_parallelism( | ||
| tensor, index=0, device_count=device_count, dim=dim | ||
| ) | ||
| self.assertEqual(shard_0.shape, (4, 1)) | ||
| self.assertAllClose(shard_0, ops.array([[0.0], [1.0], [2.0], [3.0]])) | ||
|
|
||
| shard_1 = split_tensor_for_parallelism( | ||
| tensor, index=1, device_count=device_count, dim=dim | ||
| ) | ||
| self.assertEqual(shard_1.shape, (3, 1)) | ||
| self.assertAllClose(shard_1, ops.array([[4.0], [5.0], [6.0]])) | ||
|
|
||
| shard_2 = split_tensor_for_parallelism( | ||
| tensor, index=2, device_count=device_count, dim=dim | ||
| ) | ||
| self.assertEqual(shard_2.shape, (3, 1)) | ||
| self.assertAllClose(shard_2, ops.array([[7.0], [8.0], [9.0]])) | ||
|
|
||
| def test_split_and_undo_cycle_even_removed(self): | ||
| """ | ||
| Confirms that the original tensor can be reconstructed. | ||
| """ | ||
| device_count = 2 | ||
| dim = 0 | ||
| original_tensor = ops.reshape(ops.arange(12, dtype="float32"), (6, 2)) | ||
|
|
||
| shards = [ | ||
| split_tensor_for_parallelism( | ||
| original_tensor, index=i, device_count=device_count, dim=dim | ||
| ) | ||
| for i in range(device_count) | ||
| ] | ||
|
|
||
| reconstructed_tensor = ops.concatenate(shards, axis=dim) | ||
|
|
||
| self.assertAllClose(original_tensor, reconstructed_tensor) | ||
|
|
||
| def test_split_and_undo_cycle_uneven_removed(self): | ||
| """ | ||
| Confirms that original tensor can be reconstructed with uneven split. | ||
| """ | ||
| device_count = 4 | ||
| dim = 0 | ||
| original_tensor = ops.reshape(ops.arange(22, dtype="float32"), (11, 2)) | ||
|
|
||
| shards = [ | ||
| split_tensor_for_parallelism( | ||
| original_tensor, index=i, device_count=device_count, dim=dim | ||
| ) | ||
| for i in range(device_count) | ||
| ] | ||
|
|
||
| self.assertEqual(shards[0].shape, (3, 2)) | ||
| self.assertEqual(shards[1].shape, (3, 2)) | ||
| self.assertEqual(shards[2].shape, (3, 2)) | ||
| self.assertEqual(shards[3].shape, (2, 2)) | ||
|
|
||
| reconstructed_tensor = ops.concatenate(shards, axis=dim) | ||
| self.assertAllClose(original_tensor, reconstructed_tensor) | ||
|
|
||
| def test_split_last_dimension(self): | ||
| """Tests splitting on the last dimension using dim=-1.""" | ||
| device_count = 3 | ||
| dim = -1 | ||
| original_tensor = ops.reshape( | ||
| ops.arange(30, dtype="float32"), (2, 5, 3) | ||
| ) | ||
|
|
||
| shards = [ | ||
| split_tensor_for_parallelism( | ||
| original_tensor, index=i, device_count=device_count, dim=dim | ||
| ) | ||
| for i in range(device_count) | ||
| ] | ||
|
|
||
| self.assertEqual(shards[0].shape, (2, 5, 1)) | ||
| self.assertEqual(shards[1].shape, (2, 5, 1)) | ||
| self.assertEqual(shards[2].shape, (2, 5, 1)) | ||
|
|
||
| def test_split_with_sharding_type_hint(self): | ||
| """Tests using 'row' and 'column' sharding hints for 2D tensors.""" | ||
| device_count = 2 | ||
| tensor = ops.reshape(ops.arange(16, dtype="float32"), (4, 4)) | ||
|
|
||
| row_dim = 0 | ||
| shard_row_0 = split_tensor_for_parallelism( | ||
| tensor, index=0, device_count=device_count, dim=row_dim | ||
| ) | ||
| self.assertAllClose(shard_row_0, tensor[:2, :]) | ||
|
|
||
| col_dim = 1 | ||
| shard_col_0 = split_tensor_for_parallelism( | ||
| tensor, index=0, device_count=device_count, dim=col_dim | ||
| ) | ||
| self.assertAllClose(shard_col_0, tensor[:, :2]) | ||
|
|
||
| def test_layout_map_namedtuple_behavior(self): | ||
| """Tests basic behavior of the LayoutMap namedtuple.""" | ||
|
|
||
| def rule_kernel(tensor, index): | ||
| return split_tensor_for_parallelism( | ||
| tensor, index=index, device_count=2, dim=0 | ||
| ) | ||
|
|
||
| def rule_output(tensor, index): | ||
| return split_tensor_for_parallelism( | ||
| tensor, index=index, device_count=2, dim=-1 | ||
| ) | ||
|
|
||
| state_rules = {"kernel": rule_kernel} | ||
| output_rules = {"output": rule_output} | ||
|
|
||
| layout_map = LayoutMap( | ||
| state_rules=state_rules, output_rules=output_rules | ||
| ) | ||
|
|
||
| self.assertIs(layout_map.state_rules, state_rules) | ||
| self.assertIs(layout_map.output_rules, output_rules) | ||
|
|
||
| self.assertIs(layout_map[0], state_rules) | ||
| self.assertIs(layout_map[1], output_rules) | ||
|
|
||
| with self.assertRaises(AttributeError): | ||
| layout_map.state_rules = {} | ||
|
|
||
| self.assertTrue(callable(layout_map.state_rules["kernel"])) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now that I see this, it's a bit weird to have these ops in
corebecause all backends must implement all ops incorebecause they are core ops.Also, this is distribution related. Can you move them to
keras/src/backend/jax/distribution_lib.pyat least for now.Maybe we'll need a specific namespace for distribution ops, but since they're not exported, I think
distribution_libis fine for now.