diff --git a/python/ray/air/result.py b/python/ray/air/result.py index 9b911a563233..44d8207453ed 100644 --- a/python/ray/air/result.py +++ b/python/ray/air/result.py @@ -10,6 +10,7 @@ import pyarrow import ray +from ray._private.dict import unflattened_lookup from ray.air.constants import ( EXPR_ERROR_PICKLE_FILE, EXPR_PROGRESS_FILE, @@ -272,7 +273,9 @@ def get_best_checkpoint( op = max if mode == "max" else min valid_checkpoints = [ - ckpt_info for ckpt_info in self.best_checkpoints if metric in ckpt_info[1] + ckpt_info + for ckpt_info in self.best_checkpoints + if unflattened_lookup(metric, ckpt_info[1], default=None) is not None ] if not valid_checkpoints: @@ -281,4 +284,4 @@ def get_best_checkpoint( f"You may choose from the following metrics: {self.metrics.keys()}." ) - return op(valid_checkpoints, key=lambda x: x[1][metric])[0] + return op(valid_checkpoints, key=lambda x: unflattened_lookup(metric, x[1]))[0] diff --git a/python/ray/train/v2/tests/test_result.py b/python/ray/train/v2/tests/test_result.py index d6c05008660d..dbc335261bd8 100644 --- a/python/ray/train/v2/tests/test_result.py +++ b/python/ray/train/v2/tests/test_result.py @@ -116,6 +116,87 @@ def test_get_best_checkpoint(): ) +def test_get_best_checkpoint_nested_metrics(): + """Test that get_best_checkpoint works with nested metric dictionaries.""" + # Test with nested metric structure + res = Result( + metrics={}, + checkpoint=None, + error=None, + path="/bucket/path", + best_checkpoints=[ + ( + Checkpoint("/bucket/path/ckpt0"), + { + "iter": 0, + "env_runners": {"episode_return_mean": 100.0, "num_episodes": 10}, + }, + ), + ( + Checkpoint("/bucket/path/ckpt1"), + { + "iter": 1, + "env_runners": {"episode_return_mean": 200.0, "num_episodes": 10}, + }, + ), + ( + Checkpoint("/bucket/path/ckpt2"), + { + "iter": 2, + "env_runners": {"episode_return_mean": 300.0, "num_episodes": 10}, + }, + ), + ( + Checkpoint("/bucket/path/ckpt3"), + { + "iter": 3, + "env_runners": {"episode_return_mean": 400.0, "num_episodes": 10}, + }, + ), + ], + ) + + # Test max mode with nested metric + assert ( + res.get_best_checkpoint( + metric="env_runners/episode_return_mean", mode="max" + ).path + == "/bucket/path/ckpt3" + ) + + # Test min mode with nested metric + assert ( + res.get_best_checkpoint( + metric="env_runners/episode_return_mean", mode="min" + ).path + == "/bucket/path/ckpt0" + ) + + # Test that flat keys still work (backwards compatibility) + res_flat = Result( + metrics={}, + checkpoint=None, + error=None, + path="/bucket/path", + best_checkpoints=[ + ( + Checkpoint("/bucket/path/ckpt0"), + {"iter": 0, "env_runners/episode_return_mean": 100.0}, + ), + ( + Checkpoint("/bucket/path/ckpt1"), + {"iter": 1, "env_runners/episode_return_mean": 200.0}, + ), + ], + ) + assert ( + res_flat.get_best_checkpoint( + metric="env_runners/episode_return_mean", mode="max" + ).path + == "/bucket/path/ckpt1" + ) + + @pytest.mark.parametrize("path_type", ["str", "PathLike"]) @pytest.mark.parametrize("pass_storage_filesystem", [True, False]) @pytest.mark.parametrize("trailing_slash", [False, True])