diff --git a/keras/src/backend/jax/distribution_lib.py b/keras/src/backend/jax/distribution_lib.py
index 6b5bf37314c0..d3989c0d08fc 100644
--- a/keras/src/backend/jax/distribution_lib.py
+++ b/keras/src/backend/jax/distribution_lib.py
@@ -1,4 +1,14 @@
-"""Utilities for distribution strategy with JAX backend."""
+"""Utilities for distribution strategy with JAX backend.
+
+This module contains the core JAX distribution primitives used by Keras,
+along with higher-level device management and auto-configuration utilities.
+"""
+
+import logging
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional

 import jax
 import numpy as np
@@ -8,6 +18,8 @@
 from keras.src.utils import jax_utils
 from keras.src.utils import rng_utils

+logger = logging.getLogger(__name__)
+

 def list_devices(device_type=None):
     """Return all the available devices based on the device type.
@@ -27,6 +39,153 @@ def list_devices(device_type=None):
     return [f"{device.platform}:{device.id}" for device in jax_devices]


+def get_device_info(device_id: str) -> Dict[str, Any]:
+    """
+    Get detailed information about a specific device.
+
+    Args:
+        device_id: Device identifier (e.g., 'gpu:0', 'tpu:0', 'cpu:0')
+
+    Returns:
+        Dictionary containing device information
+    """
+    device_info = {
+        "id": device_id,
+        "type": None,
+        "index": None,
+        "memory": None,
+        "capabilities": None,
+    }
+
+    device_type, device_index = device_id.split(":")
+    device_info["type"] = device_type.upper()
+    device_info["index"] = int(device_index)
+
+    return device_info
+
+
+def get_best_devices(count: int = 1) -> List[str]:
+    """
+    Get the best available devices for tensor parallelism.
+
+    Args:
+        count: Number of devices needed
+
+    Returns:
+        List of best device identifiers
+    """
+    all_devices = list_devices()
+
+    if count <= 0:
+        return []
+
+    if count > len(all_devices):
+        logger.warning(
+            f"Requested {count} devices but only {len(all_devices)} available"
+        )
+        count = len(all_devices)
+
+    return all_devices[:count]
+
+
+def get_device_backend(device_type: str) -> str:
+    """
+    Get the recommended backend for a device type.
+
+    Args:
+        device_type: Device type ('tpu', 'gpu', 'cpu')
+
+    Returns:
+        Recommended backend name
+    """
+    backend_mapping = {"tpu": "jax", "gpu": "jax", "cpu": "jax"}
+
+    return backend_mapping.get(device_type.lower(), "jax")
+
+
+def validate_device_placement(device_id: str) -> bool:
+    """
+    Validate if a device can be used for tensor operations.
+
+    Args:
+        device_id: Device identifier
+
+    Returns:
+        True if device is valid and available
+    """
+    all_devices = list_devices()
+    return device_id in all_devices
+
+
+def get_device_memory_info(device_id: str) -> Optional[Dict[str, Any]]:
+    """
+    Get memory information for a device (if available).
+
+    Args:
+        device_id: Device identifier
+
+    Returns:
+        Memory information dictionary or None if not available
+    """
+    if device_id.startswith("gpu:"):
+        return {
+            "type": "GPU",
+            "index": int(device_id.split(":")[1]),
+            "memory": "Available",
+        }
+    elif device_id.startswith("tpu:"):
+        return {
+            "type": "TPU",
+            "index": int(device_id.split(":")[1]),
+            "memory": "TPU Memory",
+        }
+    elif device_id.startswith("cpu:"):
+        return {
+            "type": "CPU",
+            "index": int(device_id.split(":")[1]),
+            "memory": "System RAM",
+        }
+
+    return None
+
+
+def auto_configure_tensor_parallel(
+    world_size: Optional[int] = None, backend: Optional[str] = None
+) -> Dict[str, Any]:
+    """
+    Automatically configure tensor parallelism with the best available devices.
+
+    Args:
+        world_size: Number of devices to use (if None, uses all available)
+        backend: Backend to use. Currently always resolves to 'jax'.
+
+    Returns:
+        Dictionary with the selected devices, world size, and backend.
+    """
+    all_devices = list_devices()
+
+    if not all_devices:
+        raise RuntimeError("No devices available for tensor parallelism")
+
+    if world_size is None:
+        world_size = len(all_devices)
+    else:
+        world_size = min(world_size, len(all_devices))
+
+    selected_devices = all_devices[:world_size]
+
+    recommended_backend = "jax"
+
+    config = {
+        "devices": selected_devices,
+        "world_size": world_size,
+        "backend": recommended_backend,
+    }
+
+    logger.info(f"Auto-configured tensor parallelism: {config}")
+    return config
+
+
 def distribute_variable(value, layout):
     """Create a distributed variable for JAX.
diff --git a/keras/src/distribution/distribution_lib.py b/keras/src/distribution/distribution_lib.py
index 2daef40a2ed8..a20b629c1c18 100644
--- a/keras/src/distribution/distribution_lib.py
+++ b/keras/src/distribution/distribution_lib.py
@@ -39,6 +39,24 @@ def list_devices(device_type=None):
     return distribution_lib.list_devices(device_type)


+@keras_export("keras.distribution.get_best_devices")
+def get_best_devices(count):
+    """Return the best available devices for distributed computation.
+
+    Note: in a distributed setting, global devices are returned.
+
+    Args:
+        count: int, the number of devices to return. If more devices are
+            requested than are available, the list is truncated to the
+            number of available devices.
+
+    Returns:
+        List of device identifiers (e.g. `"gpu:0"`) that are best suited
+        for distributed computation.
+    """
+    return distribution_lib.get_best_devices(count)
+
+
 @keras_export("keras.distribution.initialize")
 def initialize(job_addresses=None, num_processes=None, process_id=None):
     """Initialize the distribution system for multi-host/process setting.
@@ -896,3 +914,183 @@ def set_distribution(value):
         value: a `Distribution` instance.
     """
     global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, value)
+
+
+@keras_export("keras.distribution.AutoTPDistribution")
+class AutoTPDistribution(Distribution):
+    """A distribution strategy for automated tensor and data parallelism.
+
+    This distribution strategy provides a high-level abstraction for combining
+    both data parallelism and tensor parallelism. It automatically shards a
+    Keras model's layers across multiple devices (tensor parallelism) while
+    also distributing the input data across those devices (data parallelism).
+
+    It uses a `DeviceMesh` to represent the grid of computational devices. If no
+    mesh is provided, it creates one using all available devices. The mesh must
+    have a 'data' axis for data sharding and a 'model' axis for model sharding.
+
+    Internally, this class wraps the user-provided Keras `Model` with the
+    `TensorParallelKeras` utility to handle the model sharding.
+
+    Args:
+        model: A `keras.Model` instance to be distributed.
+        device_mesh: (Optional) A `keras.distribution.DeviceMesh` instance.
+            If not provided, a `DeviceMesh` will be automatically created using
+            all available devices, arranging them for both data and model
+            parallelism.
+        auto_shard_dataset: (Optional) A boolean indicating whether to
+            automatically shard `tf.data.Dataset` instances across multiple
+            processes. Defaults to `True`.
+
+    Attributes:
+        model: The wrapped, tensor-parallel `keras.Model` instance that is
+            ready for distributed training.
+        device_mesh: The `DeviceMesh` instance used for distribution.
+ + Raises: + RuntimeError: If no computational devices are found and `device_mesh` + is not provided. + ValueError: If the provided `device_mesh` does not have a 'data' axis. + + Example: + + ```python + # Create a simple Keras model + inputs = keras.Input(shape=(64,)) + x = keras.layers.Dense(128, activation="relu")(inputs) + outputs = keras.layers.Dense(10)(x) + model = keras.Model(inputs=inputs, outputs=outputs) + + # Create the distribution strategy with the model + # It will automatically use all available GPUs/TPUs + distribution = keras.distribution.AutoTPDistribution(model) + + # The distributed model is accessed via the .model attribute + distributed_model = distribution.model + + # Compile the model as usual + distributed_model.compile(optimizer="adam", loss="mse") + + # Prepare a dataset + input_data = np.random.rand(32, 64) + target_data = np.random.rand(32, 10) + + # Train the model + distributed_model.fit(input_data, target_data) + ``` + """ + + def __init__(self, model, device_mesh=None, auto_shard_dataset=True): + if device_mesh is None: + all_devices = list_devices() + if not all_devices: + raise RuntimeError("No computational devices found.") + device_mesh = DeviceMesh( + shape=(1, len(all_devices)), + axis_names=("data", "model"), + devices=all_devices, + ) + + if "data" not in device_mesh.axis_names: + raise ValueError( + "DeviceMesh for AutoTPDistribution must have a 'data' axis." + ) + batch_dim_name = "data" + + super().__init__(device_mesh, batch_dim_name, auto_shard_dataset) + + self._original_model = model + self._num_process = distribution_lib.num_processes() + self._process_id = distribution_lib.process_id() + self._is_multi_process = self._num_process > 1 + from keras.src.distribution.tensor_parallel.tensor_parallel import ( + TensorParallelKeras, + ) + + self.model = TensorParallelKeras( + model=self._original_model, + world_size=np.prod(self.device_mesh.shape), + device_ids=self.device_mesh.devices.flatten().tolist(), + ) + + def get_data_layout(self, data_shape): + data_shard_spec = [None] * len(data_shape) + data_shard_spec[0] = self.batch_dim_name + return TensorLayout(data_shard_spec, self.device_mesh) + + def get_variable_layout(self, variable): + warnings.warn( + "Variable layout is determined automatically within " + "AutoTPDistribution. This method will return a replicated layout." + ) + return TensorLayout([None] * len(variable.shape), self.device_mesh) + + def get_tensor_layout(self, path): + return None + + def distribute_dataset(self, dataset): + """Distributes the dataset across processes based on the device mesh.""" + if not self._is_multi_process or not self.auto_shard_dataset: + return dataset + + from keras.src.utils.module_utils import tensorflow as tf + + if not tf.available or not isinstance(dataset, tf.data.Dataset): + raise ValueError( + "Only `tf.data.Dataset` is supported for auto-sharding, " + f"got {type(dataset)}" + ) + + from tensorflow.python.data.experimental.ops import ( + distribute as tf_data_distribute, + ) + + global_batch_size = tf_data_distribute.compute_batch_size(dataset) + if global_batch_size.numpy() < 0: + raise ValueError( + "The batch size of the input dataset is unknown. 
" + "Please configure the batch size for the input dataset, " + "e.g., via `dataset.batch(batch_size)`" + ) + + mesh_batch_dim_index = self.device_mesh.axis_names.index( + self.batch_dim_name + ) + num_model_replicas = self.device_mesh.shape[mesh_batch_dim_index] + + if num_model_replicas == 1: + return dataset.prefetch(tf.data.AUTOTUNE) + + num_model_replicas_per_process = num_model_replicas / self._num_process + if num_model_replicas_per_process >= 1: + if global_batch_size % self._num_process != 0: + raise ValueError( + "Global batch size must be divisible by the number of " + f"processes. `global_batch_size`={global_batch_size} and " + f"`num_process`={self._num_process}" + ) + per_process_batch_size = global_batch_size // self._num_process + distributed_dataset = dataset.rebatch(per_process_batch_size) + distributed_dataset = distributed_dataset.shard( + num_shards=self._num_process, + index=self._process_id, + ) + return distributed_dataset.prefetch(tf.data.AUTOTUNE) + else: + if global_batch_size % num_model_replicas != 0: + raise ValueError( + "Global batch size must be divisible by the number of " + f"replicas. `global_batch_size`={global_batch_size} and " + f"`num_model_replicas`={num_model_replicas}" + ) + per_replica_batch_size = global_batch_size // num_model_replicas + distributed_dataset = dataset.rebatch(per_replica_batch_size) + + processes_per_replica = self._num_process // num_model_replicas + data_shard_id = self._process_id // processes_per_replica + + distributed_dataset = distributed_dataset.shard( + num_shards=num_model_replicas, + index=data_shard_id, + ) + return distributed_dataset.prefetch(tf.data.AUTOTUNE) diff --git a/keras/src/distribution/distribution_lib_test.py b/keras/src/distribution/distribution_lib_test.py index 66f996b3fb68..f5599d4d5a04 100644 --- a/keras/src/distribution/distribution_lib_test.py +++ b/keras/src/distribution/distribution_lib_test.py @@ -1,16 +1,36 @@ """Test for distribution_lib.py.""" import os + +# FILE: keras/src/distribution/distribution_lib_test.py + + +# --- TOP-LEVEL ENVIRONMENT SETUP --- +# This MUST be at the top of the file, before any Keras/TF imports. +# It configures the environment for all tests in this file. +os.environ["KERAS_BACKEND"] = "jax" +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" + +# --- Now continue with the rest of the imports --- +# ... 
and so on from unittest import mock import numpy as np import pytest import tensorflow as tf +import keras from keras.src import backend from keras.src import testing from keras.src.backend import distribution_lib as backend_dlib from keras.src.distribution import distribution_lib +from keras.src.distribution.distribution_lib import AutoTPDistribution + +try: + import keras_hub +except ImportError: + keras_hub = None @pytest.mark.skipif( @@ -535,3 +555,120 @@ def test_iter(self): # ValueError, "Cannot create sharding when device mesh is not set" # ): # backend_dlib._to_dtensor_layout(layout) + + +class AutoTPDistributionTest(testing.TestCase): + def setUp(self): + super().setUp() + self.devices = distribution_lib.list_devices() + if len(self.devices) < 2: + self.skipTest("This test requires at least 2 devices.") + inputs = keras.Input(shape=(4,), name="input_layer") + x = keras.layers.Dense(8, name="dense_1")(inputs) + outputs = keras.layers.Dense(2, name="dense_2")(x) + self.model = keras.Model(inputs, outputs) + + def test_init_with_explicit_device_mesh(self): + """Tests initialization with a user-provided DeviceMesh.""" + device_mesh = distribution_lib.DeviceMesh( + shape=(1, 2), axis_names=["data", "model"], devices=self.devices + ) + distribution = AutoTPDistribution(self.model, device_mesh=device_mesh) + + self.assertIs(distribution.device_mesh, device_mesh) + self.assertEqual(distribution.batch_dim_name, "data") + self.assertIsInstance( + distribution.model, + keras.src.distribution.tensor_parallel.tensor_parallel.TensorParallelKeras, + ) + self.assertEqual(distribution.model.world_size, 2) + + @mock.patch.object( + distribution_lib, + "list_devices", + return_value=[f"cpu:{i}" for i in range(2)], + ) + def test_init_without_device_mesh_for_auto_creation( + self, mock_list_devices + ): + """Tests the automatic creation of DeviceMesh when none is provided.""" + distribution = AutoTPDistribution(self.model, device_mesh=None) + mock_list_devices.assert_called_once() + + device_mesh = distribution.device_mesh + self.assertEqual(device_mesh.shape, (1, 2)) + self.assertEqual(device_mesh.axis_names, ("data", "model")) + self.assertEqual(distribution.batch_dim_name, "data") + self.assertEqual(distribution.model.world_size, 2) + + def test_init_raises_error_on_missing_data_axis(self): + """Ensures an error is raised if the DeviceMesh lacks a 'data' axis.""" + device_mesh = distribution_lib.DeviceMesh( + shape=(2,), axis_names=["model"], devices=self.devices + ) + with self.assertRaisesRegex(ValueError, "must have a 'data' axis"): + AutoTPDistribution(self.model, device_mesh=device_mesh) + + def test_get_data_layout(self): + """Verifies the layout for input data sharding.""" + distribution = AutoTPDistribution(self.model) + data_shape = (16, 4) + layout = distribution.get_data_layout(data_shape) + + self.assertEqual(layout.axes, ("data", None)) + self.assertIs(layout.device_mesh, distribution.device_mesh) + + def test_get_variable_layout_warns_and_returns_replicated(self): + """Verifies that variable layout is handled internally.""" + distribution = AutoTPDistribution(self.model) + dummy_variable = backend.Variable(initializer=np.zeros((8, 2))) + + with self.assertWarns(UserWarning) as w: + layout = distribution.get_variable_layout(dummy_variable) + + self.assertIn( + "Variable layout is determined automatically", + str(w.warnings[0].message), + ) + + self.assertEqual(layout.axes, (None, None)) + + def test_distribute_dataset_in_single_process_mode(self): + """Tests dataset distribution in a 
single-process environment.""" + distribution = AutoTPDistribution(self.model) + dataset = tf.data.Dataset.from_tensor_slices( + (np.zeros((16, 4)), np.zeros((16, 1))) + ) + + distributed_dataset = distribution.distribute_dataset(dataset) + self.assertIs(dataset, distributed_dataset) + + def test_full_compile_and_fit_integration(self): + """A test to ensure the distributed model can compile and train.""" + distribution = AutoTPDistribution(self.model) + + x_train = np.random.rand(16, 4).astype("float32") + y_train = np.random.randint(0, 2, size=(16, 1)) + + dist_model = distribution.model + + with distribution.scope(): + dist_model.compile( + optimizer=keras.optimizers.Adam(0.01), + loss=keras.losses.SparseCategoricalCrossentropy( + from_logits=True + ), + metrics=["accuracy"], + ) + + self.assertEqual(self.model.count_params(), dist_model.count_params()) + + history = dist_model.fit( + x_train, + y_train, + epochs=1, + batch_size=4, + verbose=0, + ) + self.assertIn("loss", history.history) + self.assertIn("accuracy", history.history) diff --git a/keras/src/distribution/tensor_parallel/tensor_parallel.py b/keras/src/distribution/tensor_parallel/tensor_parallel.py new file mode 100644 index 000000000000..ac0f5d82677b --- /dev/null +++ b/keras/src/distribution/tensor_parallel/tensor_parallel.py @@ -0,0 +1,437 @@ +""" +Tensor Parallel implementation for Keras 3.0 +Port of the PyTorch tensor_parallel library +""" + +import re + +import numpy as np + +import keras +from keras import ops +from keras.src.distribution import list_devices +from keras.src.distribution.tensor_parallel.autoconfig import ( + get_default_config_keras, +) +from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + TensorParallelOptimizer, +) +from keras.src.distribution.tensor_parallel.parameter_sharding import ( + make_parameter_sharded_model, +) +from keras.src.models import Model + + +class TensorParallelKeras(Model): + """ + A Keras Model wrapper that implements tensor parallelism. + + This class takes a standard Keras model and shards its weights across + multiple devices, enabling the model to handle larger sizes than would fit + on a single device. It automatically handles the sharding, communication, + and coordination required for training and inference. + + Args: + model: The Keras model to be parallelized. + world_size (int, optional): The number of devices to parallelize across. + If None, it's auto-detected. Defaults to None. + device_ids (list, optional): A list of device IDs to use. If None, + they are auto-detected. Defaults to None. + distributed_backend (str, optional): The backend to use for distributed + communication. Defaults to "auto". + **kwargs: Additional arguments passed to the base `keras.Model`. 
+ """ + + def __init__( + self, + model, + world_size=None, + device_ids=None, + distributed_backend="auto", + **kwargs, + ): + super().__init__(**kwargs) + + self._original_model = model + + if world_size is None: + world_size, device_ids = self._auto_detect_parallelism() + elif device_ids is None: + device_ids = self._auto_configure_devices( + world_size, distributed_backend + ) + + self.world_size = world_size + self.device_ids = device_ids + self.sharding_strategy = "auto" + self.distributed_backend = distributed_backend + + self.tensor_parallel_config = None + self.distributed = True + + self.sharded_models = [self._original_model] + + accel_devices = list_devices() + device_ids = list(self.check_device_ids(device_ids)) + + if accel_devices: + if len(accel_devices) >= world_size: + device_ids = accel_devices[:world_size] + else: + world_size = len(accel_devices) + device_ids = accel_devices[:world_size] + + if not device_ids: + device_ids = self._auto_configure_devices( + world_size, distributed_backend + ) + + if len(device_ids) != world_size: + device_ids = self._adjust_device_list(device_ids, world_size) + + self.devices = device_ids + self.world_size = world_size + + if self.world_size <= 1: + self.model_shards = [model] + self.distributed = False + if len(self.devices) == 1: + from keras import device + + with device(self.devices[0]): + self.model_shards[0] = model + self.built = True + self.assembled_model = self._original_model + return + + if self.tensor_parallel_config is None: + device_names = [str(d) for d in self.devices] + self.tensor_parallel_config = get_default_config_keras( + model, device_names + ) + config_with_ops = self.tensor_parallel_config.create_collective_ops( + self.devices + ) + self._is_multi_layer_model = len(model.layers) > 2 + self.model_shards = [] + self.modified_parameters_names = set() + + for rank, device_id in enumerate(self.devices): + shard, modified_parameters_names = make_parameter_sharded_model( + model, + config_with_ops, + rank=rank, + world_size=self.world_size, + device_id=device_id, + ) + self.model_shards.append(shard) + self.modified_parameters_names.update(modified_parameters_names) + + params_per_shard = [] + for i, shard in enumerate(self.model_shards): + total_params = sum(np.prod(p.shape) for p in shard.weights) + params_per_shard.append(int(total_params)) + + self.distributed_backend_name = distributed_backend + from keras.src.backend import distributed_backend + + self.distributed_backend = distributed_backend + + self.built = True + if self.distributed: + self.assembled_model = self.build_assembled_model() + else: + self.assembled_model = self._original_model + + @property + def variables(self): + """Returns a list of all unique variables from all model shards.""" + unique_vars = { + id(var): var + for shard in self.model_shards + for var in shard.variables + } + return list(unique_vars.values()) + + @property + def trainable_variables(self): + """Returns list of all unique trainable variables from model shards.""" + unique_vars = { + id(var): var + for shard in self.model_shards + for var in shard.trainable_variables + } + return list(unique_vars.values()) + + @property + def non_trainable_variables(self): + """Returns list of unique non-trainable variables from model shards.""" + unique_vars = { + id(var): var + for shard in self.model_shards + for var in shard.non_trainable_variables + } + return list(unique_vars.values()) + + @property + def weights(self): + """Returns a list of all unique weights from all model shards.""" 
+ unique_vars = { + id(var): var for shard in self.model_shards for var in shard.weights + } + return list(unique_vars.values()) + + @property + def trainable_weights(self): + """Returns a list of all unique trainable weights from model shards.""" + unique_vars = { + id(var): var + for shard in self.model_shards + for var in shard.trainable_weights + } + return list(unique_vars.values()) + + @property + def non_trainable_weights(self): + """Returns list of unique non-trainable weights from model shards.""" + unique_vars = { + id(var): var + for shard in self.model_shards + for var in shard.non_trainable_weights + } + return list(unique_vars.values()) + + def _auto_detect_parallelism(self): + """Auto-detects the number of available devices and sets world size.""" + from keras.src.distribution import get_best_devices + + available_devices = list_devices() + world_size = len(available_devices) + + device_ids = get_best_devices(world_size) + + return world_size, device_ids + + def _adjust_device_list(self, device_ids, target_world_size): + """Adjusts the device list to match the target world size.""" + current_size = len(device_ids) + if current_size >= target_world_size: + return device_ids[:target_world_size] + + return list(device_ids) + [ + f"cpu:{i}" for i in range(current_size, target_world_size) + ] + + def _auto_configure_devices(self, world_size, distributed_backend): + """Automatically configures the devices to be used for parallelism.""" + available_devices = list_devices() + if available_devices: + devices = available_devices[:world_size] + return devices + else: + return ["cpu:0"] + + def check_device_ids(self, device_ids): + """Validates and normalizes a sequence of device IDs.""" + if device_ids is None: + device_ids = self._get_all_device_indices() + + return tuple(self.canonicalize_device(d) for d in device_ids) + + def _get_all_device_indices(self): + """Retrieves all available device indices from distribution library.""" + return list_devices() + + def build_assembled_model(self): + """ + Builds a single Keras Functional model that encapsulates tensor + parallel logic. + + This method creates unified model that takes original model's inputs, + distributes the computation across the sharded models, and assembles + the final output. This assembled model is JIT-compilation friendly. + + Returns: + A `keras.Model` instance representing the assembled parallel model. 
+ """ + if not self.distributed: + return self._original_model + + input_layers = { + inp.name.split(":")[0]: keras.Input( + shape=inp.shape[1:], + dtype=inp.dtype, + name=inp.name.split(":")[0], + ) + for inp in self._original_model.inputs + } + + partial_outputs = [model(input_layers) for model in self.sharded_models] + + final_layer = self._original_model.layers[-1] + sharding_type = "unknown" + final_kernel_name = f"{final_layer.name}.kernel" + if hasattr(self._original_model, "name") and self._original_model.name: + final_kernel_name = ( + f"{self._original_model.name}.{final_kernel_name}" + ) + + for pattern, action in self.tensor_parallel_config.state_rules.items(): + if re.search(pattern, final_kernel_name): + if hasattr(action, "sharding_type"): + sharding_type = action.sharding_type + break + + if sharding_type == "column": + final_output = ops.concatenate(partial_outputs, axis=-1) + original_output_dim = self._original_model.output_shape[-1] + if final_output.shape[-1] != original_output_dim: + final_output = keras.layers.Lambda( + lambda x: x[..., :original_output_dim] + )(final_output) + elif sharding_type == "row": + if len(partial_outputs) > 1: + summed_output = keras.layers.Add()(partial_outputs) + else: + summed_output = partial_outputs[0] + + if final_layer.use_bias: + bias = final_layer.bias + final_output = keras.layers.Lambda( + lambda x: x - bias * (self.world_size - 1) + )(summed_output) + else: + final_output = summed_output + else: + final_output = partial_outputs[0] + + assembled_model = keras.Model( + inputs=list(input_layers.values()), outputs=final_output + ) + return assembled_model + + def canonicalize_device(self, device_spec): + """ + Converts a device specification to its canonical string form. + + Args: + device_spec: The device identifier (e.g., an int like 0, or a + string like "gpu:0", "cuda:0", "cpu"). + + Returns: + A string representing the canonical device name + (e.g., "gpu:0", "cpu"). + """ + if isinstance(device_spec, int): + if device_spec == -1: + return "cpu" + else: + return f"gpu:{device_spec}" + elif isinstance(device_spec, str): + if device_spec == "cpu": + return "cpu" + elif device_spec.startswith("gpu:"): + return device_spec + elif device_spec.startswith("cuda:"): + return f"gpu:{device_spec.split(':')[1]}" + else: + return device_spec + else: + return "cpu" + + def call(self, inputs, training=None, **kwargs): + """ + Defines the forward pass of the tensor-parallel model. + + This method delegates the call to the internal `assembled_model`, + which handles the distributed computation. + + Args: + inputs: Input tensors. + training (bool, optional): Indicates whether the model is in + training mode. Defaults to None. + **kwargs: Additional arguments. + + Returns: + The output tensor(s) of the model. + """ + return self.assembled_model(inputs, training=training, **kwargs) + + def compile(self, optimizer=None, loss=None, metrics=None, **kwargs): + """ + Configures the model for training. + + If an optimizer is provided and the model is distributed across more + than one device, it wraps the optimizer in a `TensorParallelOptimizer` + to coordinate gradients across all shards. + + Args: + optimizer: The optimizer instance. + loss: The loss function. + metrics: A list of metrics to be evaluated by the model. + **kwargs: Additional arguments passed to `keras.Model.compile`. 
+ """ + if len(self.model_shards) > 1 and optimizer is not None: + backend_name = getattr(self, "distributed_backend_name", "auto") + + self.coordinated_optimizer = TensorParallelOptimizer( + optimizer, + self.world_size, + distributed_backend=backend_name, + tensor_parallel_config=self.tensor_parallel_config, + ) + self.coordinated_optimizer._shard_models = self.model_shards + + var_map = {} + assembled = getattr(self, "assembled_model", None) + assembled_vars = ( + assembled.variables if assembled is not None else [] + ) + + for a_var in assembled_vars: + key = getattr(a_var, "path", None) or a_var.name + suffix = key.split("/")[-1] + per_shard = [] + for shard in self.model_shards: + match = next( + (v for v in shard.variables if v.name.endswith(suffix)), + None, + ) + per_shard.append(match) + var_map[key] = per_shard + + self.coordinated_optimizer._shard_var_map = var_map + inner = getattr( + self.coordinated_optimizer, "coordinated_optimizer", None + ) + if inner is not None: + inner._shard_models = self.model_shards + inner._shard_var_map = var_map + + super().compile( + optimizer=self.coordinated_optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) + + else: + super().compile(optimizer, loss, metrics, **kwargs) + + def fit(self, x=None, y=None, **kwargs): + """ + Trains the model for a fixed number of epochs (iterations on a dataset). + + This method uses the standard Keras `fit` method, which correctly + handles the custom `train_step` implicitly managed by compiled model. + + Args: + x: Input data. + y: Target data. + **kwargs: Additional arguments passed to `keras.Model.fit`. + + Returns: + A `History` object. Its `history` attribute is a record of training + loss values and metric values at successive epochs. + """ + return super().fit(x, y, **kwargs) diff --git a/keras/src/distribution/tensor_parallel/tensor_parallel_test.py b/keras/src/distribution/tensor_parallel/tensor_parallel_test.py new file mode 100644 index 000000000000..e72490f2bd49 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/tensor_parallel_test.py @@ -0,0 +1,130 @@ +import numpy as np +import pytest + +import keras +from keras import layers +from keras.src import backend +from keras.src.distribution.tensor_parallel.tensor_parallel import ( + TensorParallelKeras, +) +from keras.src.testing import TestCase + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="This test is for the JAX backend only.", +) +class TensorParallelKerasTest(TestCase): + """ + Test suite for the TensorParallelKeras class running on the JAX backend. + """ + + def setUp(self): + """Set up a reusable model and data for all tests.""" + super().setUp() + + inputs = keras.Input(shape=(64,), name="input_layer") + x = layers.Dense(128, activation="relu", name="dense_column_sharded")( + inputs + ) + outputs = layers.Dense(10, name="dense_row_sharded")(x) + self.original_model = keras.Model( + inputs=inputs, outputs=outputs, name="test_mlp" + ) + + self.input_data = np.random.rand(32, 64).astype("float32") + self.target_data = np.random.rand(32, 10).astype("float32") + + self.world_size = 2 + self.device_ids = [f"cpu:{i}" for i in range(self.world_size)] + + def test_initialization_and_sharding_verification(self): + """ + Tests if model is correctly initialized and parameter sharding occurs. 
+ """ + tp_model = TensorParallelKeras( + self.original_model, + world_size=self.world_size, + device_ids=self.device_ids, + ) + + self.assertTrue(tp_model.distributed) + self.assertEqual(tp_model.world_size, self.world_size) + self.assertEqual(len(tp_model.model_shards), self.world_size) + + original_params = self.original_model.count_params() + shard_0_params = tp_model.model_shards[0].count_params() + + self.assertLess(shard_0_params, original_params) + + tp_model_total_params = sum(np.prod(v.shape) for v in tp_model.weights) + self.assertEqual(tp_model_total_params, original_params) + + def test_non_distributed_case_world_size_one(self): + """ + Tests the behavior when world_size is 1, ensuring it gracefully degrades + to a standard, non-distributed model. + """ + tp_model = TensorParallelKeras(self.original_model, world_size=1) + + self.assertFalse(tp_model.distributed) + self.assertEqual(tp_model.world_size, 1) + self.assertEqual(len(tp_model.model_shards), 1) + self.assertIs(tp_model.assembled_model, self.original_model) + + output = tp_model.predict(self.input_data, verbose=0) + self.assertEqual(output.shape, (32, 10)) + + def test_forward_pass_correctness(self): + """ + Tests if the output of the sharded model is numerically identical + to the original model. + """ + inputs = keras.Input(shape=(64,), name="input_layer") + x = layers.Dense( + 128, activation="relu", kernel_initializer="glorot_uniform" + )(inputs) + outputs = layers.Dense(10, kernel_initializer="glorot_uniform")(x) + original_model = keras.Model(inputs=inputs, outputs=outputs) + + input_data = np.random.rand(32, 64).astype("float32") + + original_output = original_model(input_data, training=False) + + tp_model = TensorParallelKeras( + original_model, + world_size=self.world_size, + device_ids=self.device_ids, + ) + + tp_output = tp_model(input_data, training=False) + + self.assertAllClose(original_output, tp_output, atol=1e-5, rtol=1e-5) + + def test_distributed_training_workflow(self): + """ + Tests if model can be compiled and trained for one step without errors. + """ + tp_model = TensorParallelKeras( + self.original_model, + world_size=self.world_size, + device_ids=self.device_ids, + ) + + tp_model.compile( + optimizer=keras.optimizers.Adam(learning_rate=0.01), + loss="mse", + ) + + self.assertTrue(hasattr(tp_model, "coordinated_optimizer")) + + history = tp_model.fit( + self.input_data, + self.target_data, + epochs=1, + batch_size=16, + verbose=0, + ) + + self.assertIn("loss", history.history) + self.assertIsNotNone(history.history["loss"][0])