diff --git a/python/ray/train/v2/_internal/execution/context.py b/python/ray/train/v2/_internal/execution/context.py
index 37905d8d7be8..337e6cadef2c 100644
--- a/python/ray/train/v2/_internal/execution/context.py
+++ b/python/ray/train/v2/_internal/execution/context.py
@@ -258,6 +258,11 @@ def _upload_checkpoint(
             persisted_checkpoint = checkpoint_upload_fn(
                 checkpoint, checkpoint_dir_name
             )
+            # isinstance(None, ...) is always False, so this single check
+            # also rejects a None return — no separate `is None` test needed.
+            if not isinstance(persisted_checkpoint, ray.train.Checkpoint):
+                raise ValueError(
+                    "checkpoint_upload_fn must return a `ray.train.Checkpoint`."
+                )
         else:
             persisted_checkpoint = self.storage_context.persist_current_checkpoint(
                 checkpoint, checkpoint_dir_name
diff --git a/python/ray/train/v2/tests/test_async_checkpointing_validation.py b/python/ray/train/v2/tests/test_async_checkpointing_validation.py
index d438274984ee..a5dc9f896eb8 100644
--- a/python/ray/train/v2/tests/test_async_checkpointing_validation.py
+++ b/python/ray/train/v2/tests/test_async_checkpointing_validation.py
@@ -369,6 +369,26 @@ def train_fn():
     }
 
 
+def test_checkpoint_upload_fn_returns_checkpoint():
+    def train_fn():
+        with create_dict_checkpoint({}) as checkpoint:
+            ray.train.report(
+                metrics={},
+                checkpoint=checkpoint,
+                checkpoint_upload_fn=lambda x, y: None,
+            )
+
+    trainer = DataParallelTrainer(
+        train_fn,
+        scaling_config=ScalingConfig(num_workers=1),
+    )
+    with pytest.raises(
+        WorkerGroupError,
+        match="checkpoint_upload_fn must return a `ray.train.Checkpoint`",
+    ):
+        trainer.fit()
+
+
 def test_get_all_reported_checkpoints_all_consistency_modes():
     signal_actor = create_remote_signal_actor(ray).remote()