Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/ray/train/v2/_internal/execution/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ def _upload_checkpoint(
persisted_checkpoint = checkpoint_upload_fn(
checkpoint, checkpoint_dir_name
)
if not persisted_checkpoint:
raise ValueError("checkpoint_upload_fn must return a checkpoint")
else:
persisted_checkpoint = self.storage_context.persist_current_checkpoint(
checkpoint, checkpoint_dir_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,28 @@ def train_fn():
}


def test_checkpoint_upload_fn_returns_checkpoint():
    """A ``checkpoint_upload_fn`` that returns ``None`` must fail the run.

    The rank-0 worker reports a checkpoint together with an upload callback
    that returns ``None``; every other worker reports without a checkpoint.
    ``trainer.fit()`` is expected to raise ``WorkerGroupError`` whose message
    matches "checkpoint_upload_fn must return a checkpoint".
    """

    def upload_returning_nothing(_checkpoint, _checkpoint_dir_name):
        # Deliberately violates the contract by not returning a checkpoint.
        return None

    def train_fn():
        rank = ray.train.get_context().get_world_rank()

        # Non-zero ranks only report metrics, with no checkpoint attached.
        if rank != 0:
            ray.train.report(metrics={}, checkpoint=None)
            return

        # Rank 0 reports a checkpoint through the faulty upload callback.
        with create_dict_checkpoint({}) as checkpoint:
            ray.train.report(
                metrics={},
                checkpoint=checkpoint,
                checkpoint_upload_fn=upload_returning_nothing,
            )

    scaling_config = ScalingConfig(num_workers=2)
    trainer = DataParallelTrainer(train_fn, scaling_config=scaling_config)

    expected_message = "checkpoint_upload_fn must return a checkpoint"
    with pytest.raises(WorkerGroupError, match=expected_message):
        trainer.fit()


def test_get_all_reported_checkpoints_all_consistency_modes():
signal_actor = create_remote_signal_actor(ray).remote()

Expand Down