Skip to content
Open
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from keras.src.callbacks.model_checkpoint import (
ModelCheckpoint as ModelCheckpoint,
)
from keras.src.callbacks.orbax_checkpoint import (
OrbaxCheckpoint as OrbaxCheckpoint,
)
from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import (
ReduceLROnPlateau as ReduceLROnPlateau,
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/distribution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from keras.src.distribution.distribution_lib import distribution as distribution
from keras.src.distribution.distribution_lib import initialize as initialize
from keras.src.distribution.distribution_lib import list_devices as list_devices
from keras.src.distribution.distribution_lib import process_id as process_id
from keras.src.distribution.distribution_lib import (
set_distribution as set_distribution,
)
3 changes: 3 additions & 0 deletions keras/api/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from keras.src.callbacks.model_checkpoint import (
ModelCheckpoint as ModelCheckpoint,
)
from keras.src.callbacks.orbax_checkpoint import (
OrbaxCheckpoint as OrbaxCheckpoint,
)
from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import (
ReduceLROnPlateau as ReduceLROnPlateau,
Expand Down
1 change: 1 addition & 0 deletions keras/api/distribution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from keras.src.distribution.distribution_lib import distribution as distribution
from keras.src.distribution.distribution_lib import initialize as initialize
from keras.src.distribution.distribution_lib import list_devices as list_devices
from keras.src.distribution.distribution_lib import process_id as process_id
from keras.src.distribution.distribution_lib import (
set_distribution as set_distribution,
)
2 changes: 1 addition & 1 deletion keras/src/backend/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from keras.src.backend.config import is_nnx_enabled
from keras.src.backend.jax import core
from keras.src.backend.jax import distribution_lib
from keras.src.backend.jax import image
from keras.src.backend.jax import linalg
from keras.src.backend.jax import math
Expand All @@ -25,6 +24,7 @@
from keras.src.backend.jax.core import shape
from keras.src.backend.jax.core import stop_gradient
from keras.src.backend.jax.core import vectorized_map
from keras.src.backend.jax.distribution_lib import process_id
from keras.src.backend.jax.rnn import cudnn_ok
from keras.src.backend.jax.rnn import gru
from keras.src.backend.jax.rnn import lstm
Expand Down
1 change: 1 addition & 0 deletions keras/src/backend/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from keras.src.backend.numpy.core import random_seed_dtype
from keras.src.backend.numpy.core import shape
from keras.src.backend.numpy.core import vectorized_map
from keras.src.backend.numpy.distribution_lib import process_id
from keras.src.backend.numpy.rnn import cudnn_ok
from keras.src.backend.numpy.rnn import gru
from keras.src.backend.numpy.rnn import lstm
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/numpy/distribution_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Utilities for distribution strategy with NumPy backend."""


def process_id():
"""Return the current process ID for the distribution setting."""
return 0
1 change: 1 addition & 0 deletions keras/src/backend/openvino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from keras.src.backend.openvino.core import random_seed_dtype
from keras.src.backend.openvino.core import shape
from keras.src.backend.openvino.core import vectorized_map
from keras.src.backend.openvino.distribution_lib import process_id
from keras.src.backend.openvino.rnn import cudnn_ok
from keras.src.backend.openvino.rnn import gru
from keras.src.backend.openvino.rnn import lstm
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/openvino/distribution_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Utilities for distribution strategy with OpenVINO backend."""


def process_id():
"""Return the current process ID for the distribution setting."""
return 0
2 changes: 1 addition & 1 deletion keras/src/backend/tensorflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from keras.src.backend.tensorflow import core
from keras.src.backend.tensorflow import distribution_lib
from keras.src.backend.tensorflow import image
from keras.src.backend.tensorflow import linalg
from keras.src.backend.tensorflow import math
Expand All @@ -24,6 +23,7 @@
from keras.src.backend.tensorflow.core import shape
from keras.src.backend.tensorflow.core import stop_gradient
from keras.src.backend.tensorflow.core import vectorized_map
from keras.src.backend.tensorflow.distribution_lib import process_id
from keras.src.backend.tensorflow.rnn import cudnn_ok
from keras.src.backend.tensorflow.rnn import gru
from keras.src.backend.tensorflow.rnn import lstm
Expand Down
10 changes: 10 additions & 0 deletions keras/src/backend/tensorflow/distribution_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,13 @@ def _to_backend_layout(tensor_layout):
]
dtensor_mesh = tensor_layout.device_mesh.backend_mesh
return dtensor.Layout(sharding_specs=sharding_specs, mesh=dtensor_mesh)


def process_id():
"""Return the current process ID for the distribution setting."""
try:
import tensorflow as tf

return tf.distribute.get_replica_context().replica_id_in_sync_group
except (ImportError, AttributeError, RuntimeError):
return 0
1 change: 1 addition & 0 deletions keras/src/backend/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from keras.src.backend.torch.core import stop_gradient
from keras.src.backend.torch.core import to_torch_dtype
from keras.src.backend.torch.core import vectorized_map
from keras.src.backend.torch.distribution_lib import process_id
from keras.src.backend.torch.rnn import cudnn_ok
from keras.src.backend.torch.rnn import gru
from keras.src.backend.torch.rnn import lstm
Expand Down
13 changes: 13 additions & 0 deletions keras/src/backend/torch/distribution_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Utilities for distribution strategy with PyTorch backend."""


def process_id():
"""Return the current process ID for the distribution setting."""
try:
import torch.distributed as dist

if dist.is_available() and dist.is_initialized():
return dist.get_rank()
return 0
except (ImportError, AttributeError):
return 0
1 change: 1 addition & 0 deletions keras/src/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
from keras.src.callbacks.model_checkpoint import ModelCheckpoint
from keras.src.callbacks.monitor_callback import MonitorCallback
from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
from keras.src.callbacks.progbar_logger import ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
from keras.src.callbacks.remote_monitor import RemoteMonitor
Expand Down
Loading