diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py
index 15f1af2145d5..b22ea22547bb 100644
--- a/keras/src/backend/__init__.py
+++ b/keras/src/backend/__init__.py
@@ -37,6 +37,8 @@
 if backend() == "tensorflow":
     from keras.src.backend.tensorflow import *  # noqa: F403
     from keras.src.backend.tensorflow.core import Variable as BackendVariable
+
+    distributed_backend = None
 elif backend() == "jax":
     from keras.src.backend.jax import *  # noqa: F403
     from keras.src.backend.jax.core import Variable as BackendVariable
@@ -44,17 +46,20 @@
     from keras.src.backend.torch import *  # noqa: F403
     from keras.src.backend.torch.core import Variable as BackendVariable
 
+    distributed_backend = None
     distribution_lib = None
 elif backend() == "numpy":
     from keras.src.backend.numpy import *  # noqa: F403
     from keras.src.backend.numpy.core import Variable as BackendVariable
 
     distribution_lib = None
+    distributed_backend = None
 elif backend() == "openvino":
     from keras.src.backend.openvino import *  # noqa: F403
     from keras.src.backend.openvino.core import Variable as BackendVariable
 
     distribution_lib = None
+    distributed_backend = None
 else:
     raise ValueError(f"Unable to import backend : {backend()}")
 
diff --git a/keras/src/backend/jax/__init__.py b/keras/src/backend/jax/__init__.py
index 89ac0fa71c8c..0a275fb70cf1 100644
--- a/keras/src/backend/jax/__init__.py
+++ b/keras/src/backend/jax/__init__.py
@@ -1,5 +1,6 @@
 from keras.src.backend.config import is_nnx_enabled
 from keras.src.backend.jax import core
+from keras.src.backend.jax import distributed_backend
 from keras.src.backend.jax import distribution_lib
 from keras.src.backend.jax import image
 from keras.src.backend.jax import linalg
diff --git a/keras/src/backend/jax/distributed_backend.py b/keras/src/backend/jax/distributed_backend.py
new file mode 100644
index 000000000000..e767793a2b40
--- /dev/null
+++ b/keras/src/backend/jax/distributed_backend.py
@@ -0,0 +1,95 @@
+import jax
+import jax.lax as lax
+
+
+def get_device_info():
+    """Retrieves information about the available JAX devices.
+
+    This function queries the JAX backend to identify the type and number
+    of available computational devices (e.g., CPU, GPU, TPU).
+
+    Returns:
+        dict: A dictionary containing the backend name ('jax'), a list of
+            device string representations, and the total count of devices.
+    """
+    available_devices = jax.devices()
+    return {
+        "backend": "jax",
+        "devices": [str(d) for d in available_devices],
+        "device_count": len(available_devices),
+    }
+
+
+def is_multi_device_capable():
+    """Checks if more than one JAX device is available for computation.
+
+    Returns:
+        bool: True if the local JAX environment has more than one device,
+            False otherwise.
+    """
+    return jax.local_device_count() > 1
+
+
+def get_communication_ops():
+    """Provides a dictionary of JAX collective communication operations.
+
+    Returns:
+        dict: A dictionary mapping operation names (e.g., 'all_reduce') to
+            their corresponding JAX implementation functions.
+    """
+
+    def all_reduce(x, op="sum", axis_name="model"):
+        """Reduces a tensor across a device mesh axis using a collective.
+
+        This function assumes it is called within a `pjit` context that has
+        a device mesh with the specified `axis_name`. It performs a
+        collective reduction operation (like sum or mean) across all
+        devices mapped to that axis.
+
+        Args:
+            x (jax.Array): The input JAX array (tensor) on the local device.
+            op (str, optional): The reduction operation to perform. Supported
+                values are 'sum' and 'mean'. Defaults to 'sum'.
+            axis_name (str, optional): The name of the mapped axis in the
+                device mesh over which to communicate. Defaults to 'model'.
+
+        Returns:
+            jax.Array: The reduced JAX array, which is identical across all
+                devices participating in the reduction.
+        """
+        if op == "sum":
+            return lax.psum(x, axis_name=axis_name)
+        elif op == "mean":
+            return lax.pmean(x, axis_name=axis_name)
+        else:
+            raise ValueError(
+                f"Unsupported reduction operation: {op}. "
+                "Supported options are 'sum' and 'mean'."
+            )
+
+    def all_gather(x, axis, axis_name="model"):
+        """Gathers and concatenates tensors from all devices across a mesh
+        axis.
+
+        This function assumes it is called within a `pjit` context. It takes
+        the local shard `x` from each device along the `axis_name` of the
+        mesh and concatenates them along the specified tensor `axis` to form
+        a single, larger tensor that is then replicated on all participating
+        devices.
+
+        Args:
+            x (jax.Array): The input JAX array (tensor) shard on local device.
+            axis (int): The tensor axis along which to concatenate the
+                gathered shards.
+            axis_name (str, optional): The name of the mesh axis to gather
+                from. Defaults to 'model'.
+
+        Returns:
+            jax.Array: The full, gathered JAX array, which is identical
+                across all devices participating in the gather.
+        """
+        return lax.all_gather(x, axis_name=axis_name, axis=axis, tiled=True)
+
+    return {
+        "all_reduce": all_reduce,
+        "all_gather": all_gather,
+    }
diff --git a/keras/src/backend/jax/distributed_backend_test.py b/keras/src/backend/jax/distributed_backend_test.py
new file mode 100644
index 000000000000..43313ec5eba7
--- /dev/null
+++ b/keras/src/backend/jax/distributed_backend_test.py
@@ -0,0 +1,96 @@
+import os
+
+os.environ["JAX_PLATFORM_NAME"] = "cpu"
+os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
+
+import jax
+import jax.numpy as jnp
+import pytest
+
+from keras.src import backend
+from keras.src import ops
+from keras.src import testing
+from keras.src.backend import distributed_backend
+
+
+@pytest.mark.skipif(
+    backend.backend() != "jax" or jax.device_count() < 2,
+    reason="Test requires JAX backend and at least 2 devices",
+)
+class TestJaxDistributedFunctions(testing.TestCase):
+    """Unit tests for the JAX distributed backend functions."""
+
+    def setUp(self):
+        """Set up common variables for the tests."""
+        super().setUp()
+        self.comm_ops = distributed_backend.get_communication_ops()
+        self.devices = jax.devices()
+        self.world_size = len(self.devices)
+
+    def test_get_device_info(self):
+        """Test retrieving device information from the JAX backend."""
+        info = distributed_backend.get_device_info()
+        self.assertEqual(info["backend"], "jax")
+        self.assertIsInstance(info["devices"], list)
+        self.assertEqual(info["device_count"], self.world_size)
+        self.assertEqual(self.world_size, 8)
+
+    def test_is_multi_device_capable(self):
+        """Test the boolean check for multi-device capability."""
+        self.assertTrue(distributed_backend.is_multi_device_capable())
+
+    def test_ops_raise_error_outside_parallel_context(self):
+        """Verify that communication ops fail when not in pmap/pjit context."""
+        x = ops.array([1.0, 2.0])
+        with self.assertRaisesRegex(NameError, "unbound axis name: model"):
+            self.comm_ops["all_reduce"](x)
+
+    def test_all_reduce_sums_inputs_in_pmap(self):
+        """Tests that all_reduce with sum works correctly in pmap context."""
+        x_reduce = ops.array([[1.0, 2.0], [3.0, 4.0]])
+        sharded_reduce_input = jnp.stack([x_reduce] * self.world_size)
+
+        pmapped_reduce = jax.pmap(
+            lambda x: self.comm_ops["all_reduce"](
+                x, op="sum", axis_name="data"
+            ),
+            axis_name="data",
+        )
+        reduced_result = pmapped_reduce(sharded_reduce_input)
+
+        expected_reduce = ops.multiply(x_reduce, float(self.world_size))
+        self.assertAllClose(reduced_result[0], expected_reduce)
+
+    def test_all_reduce_averages_inputs_in_pmap(self):
+        """Tests that all_reduce with mean works correctly in pmap context."""
+        x_reduce = ops.array([[1.0, 2.0], [3.0, 4.0]])
+        sharded_reduce_input = jnp.stack(
+            [x_reduce + i for i in range(self.world_size)]
+        )
+
+        pmapped_reduce = jax.pmap(
+            lambda x: self.comm_ops["all_reduce"](
+                x, op="mean", axis_name="data"
+            ),
+            axis_name="data",
+        )
+        reduced_result = pmapped_reduce(sharded_reduce_input)
+
+        expected_reduce = jnp.mean(sharded_reduce_input, axis=0)
+        self.assertAllClose(reduced_result[0], expected_reduce)
+
+    def test_all_gather_collects_inputs_in_pmap(self):
+        """Tests that all_gather correctly collects inputs from all devices."""
+        x_gather = jnp.arange(self.world_size * 2, dtype="float32").reshape(
+            (self.world_size, 2)
+        )
+
+        pmapped_gather = jax.pmap(
+            lambda x: self.comm_ops["all_gather"](x, axis=0, axis_name="data"),
+            axis_name="data",
+        )
+        gathered_result = pmapped_gather(x_gather)
+
+        self.assertAllClose(
+            gathered_result[0].reshape(x_gather.shape), x_gather
+        )
diff --git a/keras/src/distribution/tensor_parallel/tensor_layout.py b/keras/src/distribution/tensor_parallel/tensor_layout.py
new file mode 100644
index 000000000000..bf80b45e7e82
--- /dev/null
+++ b/keras/src/distribution/tensor_parallel/tensor_layout.py
@@ -0,0 +1,169 @@
+import keras
+
+
+class LayoutAction:
+    """Abstract base class for actions that transform tensors for distribution.
+
+    A LayoutAction defines a rule for how a single tensor should be physically
+    represented across multiple devices. It includes a forward operation
+    (`__call__`) to shard the tensor and a reverse operation (`undo`)
+    to reconstruct it."""
+
+    def __call__(self, tensor, rank):
+        """Applies the distribution action to a tensor for a specific worker.
+
+        Args:
+            tensor: The input tensor to be distributed.
+            rank: The integer rank of the current worker/device.
+
+        Raises:
+            NotImplementedError: This is an abstract method and must be
+                implemented by subclasses.
+
+        Returns:
+            A shard or transformation of the input tensor specific to the given
+            rank.
+        """
+        raise NotImplementedError
+
+    def undo(self, tensors):
+        """Reverses the distribution action, reconstructing the original tensor.
+
+        Args:
+            tensors: A sequence of tensor shards, one from each worker.
+
+        Raises:
+            NotImplementedError: This is an abstract method and must be
+                implemented by subclasses.
+
+        Returns:
+            The reconstructed, single tensor.
+        """
+        raise NotImplementedError
+
+
+class _ConcatenateMixin:
+    """A mixin class providing a common `undo` method via concatenation.
+
+    This class is intended to be used as a mixin for `LayoutAction` subclasses
+    that can be undone by simple concatenation along a specified axis.
+    """
+
+    def undo(self, tensors):
+        """Concatenates a sequence of tensors to reconstruct original tensor.
+
+        Args:
+            tensors: A sequence of tensor shards, one from each worker.
+
+        Returns:
+            The single tensor reconstructed by concatenating the shards.
+        """
+        if self.dim == -1:
+            dim = keras.ops.ndim(tensors[0]) - 1
+        else:
+            dim = self.dim
+        return keras.ops.concatenate(tensors, axis=dim)
+
+
+class Split(_ConcatenateMixin, LayoutAction):
+    """Splits a tensor into shards along a specified dimension.
+
+    This is an internal utility used by a higher-level distribution API.
+    It implements sharding by slicing a tensor along one of its axes.
+    It handles cases where the dimension size is not perfectly divisible by the
+    number of workers by distributing the remainder elements one by one to the
+    first few workers.
+
+    The `undo` operation is provided by the `_ConcatenateMixin`.
+    """
+
+    def __init__(self, world_size, dim, sharding_type="auto"):
+        """Initializes the Split action.
+
+        Args:
+            world_size: The total number of workers/shards.
+            dim: The dimension along which to split the tensor. If -1, the
+                last dimension is used.
+            sharding_type: If `dim` is -1, this can be 'row' (dim=0) or
+                'column' (dim=1) to infer the split axis for 2D tensors.
+                Defaults to "auto".
+        """
+        super().__init__()
+        self.world_size = world_size
+        self.dim = dim
+        self.sharding_type = sharding_type
+
+        if dim == -1 and sharding_type != "auto":
+            if sharding_type == "row":
+                self.dim = 0
+            elif sharding_type == "column":
+                self.dim = 1
+
+    def __call__(self, tensor, rank):
+        """Splits the tensor and returns the shard corresponding to the rank.
+
+        This method calculates the correct slice of the tensor for a given
+        worker rank, handling uneven distributions gracefully.
+
+        Args:
+            tensor: The full tensor to be sharded.
+            rank: The rank of the worker for which to get the shard.
+
+        Returns:
+            A tensor shard corresponding to the given rank.
+        """
+        if self.dim == -1:
+            dim = keras.ops.ndim(tensor) - 1
+        else:
+            dim = self.dim
+
+        total_size = tensor.shape[dim]
+        split_size = total_size // self.world_size
+        remainder = total_size % self.world_size
+
+        start_idx = rank * split_size + min(rank, remainder)
+        end_idx = start_idx + split_size + (1 if rank < remainder else 0)
+
+        slices = [slice(None)] * keras.ops.ndim(tensor)
+        slices[dim] = slice(start_idx, end_idx)
+        return tensor[tuple(slices)]
+
+
+class LayoutMap:
+    """A mapping that defines layout rules for model states and outputs.
+
+    This is an internal configuration object used to hold layout rules for
+    how model variables and layer outputs should be distributed across a set
+    of devices. It acts as a container for `LayoutAction` instances.
+
+    Attributes:
+        state_rules: A dictionary mapping variable names or patterns to
+            `LayoutAction` instances.
+        output_rules: A dictionary mapping layer output names or
+            patterns to `LayoutAction` instances.
+    """
+
+    def __init__(self, state_rules, output_rules):
+        """Initializes the LayoutMap.
+
+        Args:
+            state_rules: A dictionary of distribution rules for model states.
+            output_rules: A dictionary of distribution rules for model outputs.
+        """
+        self.state_rules = state_rules
+        self.output_rules = output_rules
+
+    def create_collective_ops(self, devices):
+        """Creates the necessary collective communication operations.
+
+        This method is a placeholder for backend-specific logic that would
+        translate the layout rules into actual communication primitives
+        (e.g., all-gather, reduce-scatter).
+
+        Args:
+            devices: A sequence of device identifiers.
+
+        Returns:
+            The `LayoutMap` instance itself, allowing for method chaining.
+        """
+        return self
diff --git a/keras/src/distribution/tensor_parallel/tensor_layout_test.py b/keras/src/distribution/tensor_parallel/tensor_layout_test.py
new file mode 100644
index 000000000000..1135cf3b24dc
--- /dev/null
+++ b/keras/src/distribution/tensor_parallel/tensor_layout_test.py
@@ -0,0 +1,163 @@
+import pytest
+
+import keras
+from keras.src import backend
+from keras.src import testing
+from keras.src.distribution.tensor_parallel.tensor_layout import LayoutAction
+from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap
+from keras.src.distribution.tensor_parallel.tensor_layout import Split
+
+
+@pytest.mark.skipif(
+    backend.backend() != "jax",
+    reason="Test requires JAX backend and at least 2 devices",
+)
+class LayoutTest(testing.TestCase):
+    """Test suite for tensor layout actions and mappings."""
+
+    def test_layout_action_abstract_methods_raise_error(self):
+        action = LayoutAction()
+        with self.assertRaises(NotImplementedError):
+            action(tensor=None, rank=0)
+        with self.assertRaises(NotImplementedError):
+            action.undo(tensors=None)
+
+    # --- Split Action Tests ---
+
+    def test_split_with_even_division(self):
+        """Tests splitting a tensor that divides evenly among workers."""
+        world_size = 4
+        # Create a tensor of shape (8, 2)
+        tensor = keras.ops.reshape(
+            keras.ops.arange(16, dtype="float32"), (8, 2)
+        )
+        action = Split(world_size=world_size, dim=0)
+
+        # Expected shard for rank 0 has shape (2, 2)
+        expected_shard_0 = keras.ops.array([[0.0, 1.0], [2.0, 3.0]])
+        # Expected shard for rank 2 has shape (2, 2)
+        expected_shard_2 = keras.ops.array([[8.0, 9.0], [10.0, 11.0]])
+
+        shard_0 = action(tensor, rank=0)
+        shard_2 = action(tensor, rank=2)
+
+        self.assertAllClose(shard_0, expected_shard_0)
+        self.assertAllClose(shard_2, expected_shard_2)
+        self.assertEqual(shard_0.shape, (2, 2))
+
+    def test_split_with_uneven_division(self):
+        """Tests splitting a tensor where remainder is distributed correctly."""
+        world_size = 3
+        # Create a tensor of shape (10, 1). 10 / 3 = 3 with remainder 1.
+        tensor = keras.ops.reshape(
+            keras.ops.arange(10, dtype="float32"), (10, 1)
+        )
+        action = Split(world_size=world_size, dim=0)
+
+        # Rank 0 should get 3 + 1 = 4 rows.
+        shard_0 = action(tensor, rank=0)
+        self.assertEqual(shard_0.shape, (4, 1))
+        self.assertAllClose(
+            shard_0, keras.ops.array([[0.0], [1.0], [2.0], [3.0]])
+        )
+
+        # Rank 1 should get 3 rows.
+        shard_1 = action(tensor, rank=1)
+        self.assertEqual(shard_1.shape, (3, 1))
+        self.assertAllClose(shard_1, keras.ops.array([[4.0], [5.0], [6.0]]))
+
+        # Rank 2 should get 3 rows.
+        shard_2 = action(tensor, rank=2)
+        self.assertEqual(shard_2.shape, (3, 1))
+        self.assertAllClose(shard_2, keras.ops.array([[7.0], [8.0], [9.0]]))
+
+    def test_split_and_undo_cycle_even(self):
+        """Tests the splitting and reconstructing of evenly divisible tensor."""
+        world_size = 2
+        original_tensor = keras.ops.reshape(
+            keras.ops.arange(12, dtype="float32"), (6, 2)
+        )
+        action = Split(world_size=world_size, dim=0)
+
+        # Create all shards
+        shards = [action(original_tensor, rank=i) for i in range(world_size)]
+
+        # Reconstruct the tensor
+        reconstructed_tensor = action.undo(shards)
+
+        self.assertAllClose(original_tensor, reconstructed_tensor)
+
+    def test_split_and_undo_cycle_uneven(self):
+        """Tests full cycle for an unevenly distributed tensor."""
+        world_size = 4
+        # 11 / 4 = 2 with a remainder of 3.
+        original_tensor = keras.ops.reshape(
+            keras.ops.arange(22, dtype="float32"), (11, 2)
+        )
+        action = Split(world_size=world_size, dim=0)
+
+        shards = [action(original_tensor, rank=i) for i in range(world_size)]
+
+        # Verify shard shapes: first 3 get 2+1=3 rows, last one gets 2.
+        self.assertEqual(shards[0].shape, (3, 2))
+        self.assertEqual(shards[1].shape, (3, 2))
+        self.assertEqual(shards[2].shape, (3, 2))
+        self.assertEqual(shards[3].shape, (2, 2))
+
+        reconstructed_tensor = action.undo(shards)
+        self.assertAllClose(original_tensor, reconstructed_tensor)
+
+    def test_split_last_dimension_with_undo(self):
+        """Tests splitting on the last dimension using dim=-1."""
+        world_size = 3
+        original_tensor = keras.ops.reshape(
+            keras.ops.arange(30, dtype="float32"), (2, 5, 3)
+        )
+        action = Split(world_size=world_size, dim=-1)
+
+        shards = [action(original_tensor, rank=i) for i in range(world_size)]
+
+        # Each shard should have the last dimension split.
+        self.assertEqual(shards[0].shape, (2, 5, 1))
+        self.assertEqual(shards[1].shape, (2, 5, 1))
+        self.assertEqual(shards[2].shape, (2, 5, 1))
+
+        reconstructed_tensor = action.undo(shards)
+        self.assertAllClose(original_tensor, reconstructed_tensor)
+
+    def test_split_with_sharding_type_hint(self):
+        """Tests using 'row' and 'column' sharding hints for 2D tensors."""
+        world_size = 2
+        tensor = keras.ops.reshape(
+            keras.ops.arange(16, dtype="float32"), (4, 4)
+        )
+
+        # Row sharding should split along axis 0
+        action_row = Split(world_size=world_size, dim=-1, sharding_type="row")
+        shard_row_0 = action_row(tensor, rank=0)
+        self.assertAllClose(shard_row_0, tensor[:2, :])
+        self.assertEqual(action_row.dim, 0)
+
+        # Column sharding should split along axis 1
+        action_col = Split(
+            world_size=world_size, dim=-1, sharding_type="column"
+        )
+        shard_col_0 = action_col(tensor, rank=0)
+        self.assertAllClose(shard_col_0, tensor[:, :2])
+        self.assertEqual(action_col.dim, 1)
+
+    # --- LayoutMap Tests ---
+
+    def test_layout_map_initialization_and_methods(self):
+        """Tests basic initialization and method behavior of LayoutMap class."""
+        state_rules = {"kernel": Split(world_size=2, dim=0)}
+        output_rules = {"output": Split(world_size=2, dim=-1)}
+
+        layout_map = LayoutMap(state_rules, output_rules)
+
+        self.assertIs(layout_map.state_rules["kernel"], state_rules["kernel"])
+        self.assertIs(
+            layout_map.output_rules["output"], output_rules["output"]
+        )
+
+        self.assertIs(
+            layout_map.create_collective_ops(devices=["cpu:0"]), layout_map
+        )
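Usage sketch (illustrative, not part of the patch): how the new JAX communication ops in `keras/src/backend/jax/distributed_backend.py` are meant to be driven from `jax.pmap`, mirroring `distributed_backend_test.py` above. The 8-device CPU host forced via `XLA_FLAGS` is an assumption for illustration only and must be set before JAX initializes.

import os

os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp

from keras.src.backend.jax import distributed_backend

# Query what the backend reports about the local devices.
info = distributed_backend.get_device_info()
print(info["backend"], info["device_count"])  # e.g. "jax", 8 with the flags above

# Fetch the collective ops and run them under pmap, where the named
# axis ("data" here) is bound.
comm_ops = distributed_backend.get_communication_ops()
n = info["device_count"]

x = jnp.ones((n, 4), dtype="float32")  # one row per device

summed = jax.pmap(
    lambda v: comm_ops["all_reduce"](v, op="sum", axis_name="data"),
    axis_name="data",
)(x)
# Every device now holds the same reduced row, e.g. [8., 8., 8., 8.].

gathered = jax.pmap(
    lambda v: comm_ops["all_gather"](v, axis=0, axis_name="data"),
    axis_name="data",
)(x)
# Each device holds the concatenation of all local rows; the stacked
# pmap output has shape (n, n * 4).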
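Similarly, a minimal sketch of the `Split` layout action and `LayoutMap` container from `tensor_layout.py`, based only on the behavior exercised in `tensor_layout_test.py`; the variable and rule names ("dense/kernel", "dense/output") are hypothetical.

import keras

from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap
from keras.src.distribution.tensor_parallel.tensor_layout import Split

world_size = 2

# Shard a (4, 6) kernel column-wise (dim=1) across two ranks.
kernel = keras.ops.reshape(keras.ops.arange(24, dtype="float32"), (4, 6))
split_columns = Split(world_size=world_size, dim=1)

shards = [split_columns(kernel, rank=r) for r in range(world_size)]
# shards[0].shape == shards[1].shape == (4, 3)

# `undo` concatenates the shards back into the original kernel.
restored = split_columns.undo(shards)
# keras.ops.all(restored == kernel) -> True

# A LayoutMap simply records which rule applies to which variable/output;
# create_collective_ops is currently a placeholder that returns the map.
layout_map = LayoutMap(
    state_rules={"dense/kernel": split_columns},
    output_rules={"dense/output": Split(world_size=world_size, dim=-1)},
)
layout_map.create_collective_ops(devices=["cpu:0", "cpu:1"])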