From f99d7a1709294a46d210a0a678812abd45e0430d Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Tue, 11 Nov 2025 10:22:14 +0000 Subject: [PATCH 1/4] [rllib, air, train] Add support for nested metrics for `Result.get_best_checkpoint` Signed-off-by: Mark Towers --- python/ray/air/result.py | 7 +- python/ray/train/v2/tests/test_result.py | 84 ++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 2 deletions(-) diff --git a/python/ray/air/result.py b/python/ray/air/result.py index 9b911a563233..44d8207453ed 100644 --- a/python/ray/air/result.py +++ b/python/ray/air/result.py @@ -10,6 +10,7 @@ import pyarrow import ray +from ray._private.dict import unflattened_lookup from ray.air.constants import ( EXPR_ERROR_PICKLE_FILE, EXPR_PROGRESS_FILE, @@ -272,7 +273,9 @@ def get_best_checkpoint( op = max if mode == "max" else min valid_checkpoints = [ - ckpt_info for ckpt_info in self.best_checkpoints if metric in ckpt_info[1] + ckpt_info + for ckpt_info in self.best_checkpoints + if unflattened_lookup(metric, ckpt_info[1], default=None) is not None ] if not valid_checkpoints: @@ -281,4 +284,4 @@ def get_best_checkpoint( f"You may choose from the following metrics: {self.metrics.keys()}." ) - return op(valid_checkpoints, key=lambda x: x[1][metric])[0] + return op(valid_checkpoints, key=lambda x: unflattened_lookup(metric, x[1]))[0] diff --git a/python/ray/train/v2/tests/test_result.py b/python/ray/train/v2/tests/test_result.py index d6c05008660d..73b0f94d2de8 100644 --- a/python/ray/train/v2/tests/test_result.py +++ b/python/ray/train/v2/tests/test_result.py @@ -116,6 +116,90 @@ def test_get_best_checkpoint(): ) +def test_get_best_checkpoint_nested_metrics(): + """Test that get_best_checkpoint works with nested metric dictionaries. + + Rllib uses nested metric structure like {"env_runners": {"episode_return_mean": value}} + """ + # Test with nested metric structure + res = Result( + metrics={}, + checkpoint=None, + error=None, + path="/bucket/path", + best_checkpoints=[ + ( + Checkpoint("/bucket/path/ckpt0"), + { + "iter": 0, + "env_runners": {"episode_return_mean": 100.0, "num_episodes": 10}, + }, + ), + ( + Checkpoint("/bucket/path/ckpt1"), + { + "iter": 1, + "env_runners": {"episode_return_mean": 200.0, "num_episodes": 10}, + }, + ), + ( + Checkpoint("/bucket/path/ckpt2"), + { + "iter": 2, + "env_runners": {"episode_return_mean": 300.0, "num_episodes": 10}, + }, + ), + ( + Checkpoint("/bucket/path/ckpt3"), + { + "iter": 3, + "env_runners": {"episode_return_mean": 400.0, "num_episodes": 10}, + }, + ), + ], + ) + + # Test max mode with nested metric + assert ( + res.get_best_checkpoint( + metric="env_runners/episode_return_mean", mode="max" + ).path + == "/bucket/path/ckpt3" + ) + + # Test min mode with nested metric + assert ( + res.get_best_checkpoint( + metric="env_runners/episode_return_mean", mode="min" + ).path + == "/bucket/path/ckpt0" + ) + + # Test that flat keys still work (backwards compatibility) + res_flat = Result( + metrics={}, + checkpoint=None, + error=None, + path="/bucket/path", + best_checkpoints=[ + ( + Checkpoint("/bucket/path/ckpt0"), + {"iter": 0, "env_runners/episode_return_mean": 100.0}, + ), + ( + Checkpoint("/bucket/path/ckpt1"), + {"iter": 1, "env_runners/episode_return_mean": 200.0}, + ), + ], + ) + assert ( + res_flat.get_best_checkpoint( + metric="env_runners/episode_return_mean", mode="max" + ).path + == "/bucket/path/ckpt1" + ) + + @pytest.mark.parametrize("path_type", ["str", "PathLike"]) @pytest.mark.parametrize("pass_storage_filesystem", [True, False]) @pytest.mark.parametrize("trailing_slash", [False, True]) From e3027e76fc598cf37adf04c08fb682218eab6494 Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Tue, 11 Nov 2025 11:38:47 +0000 Subject: [PATCH 2/4] Fix RLlib spelling Signed-off-by: Mark Towers --- python/ray/train/v2/tests/test_result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/train/v2/tests/test_result.py b/python/ray/train/v2/tests/test_result.py index 73b0f94d2de8..260258af5970 100644 --- a/python/ray/train/v2/tests/test_result.py +++ b/python/ray/train/v2/tests/test_result.py @@ -119,7 +119,7 @@ def test_get_best_checkpoint(): def test_get_best_checkpoint_nested_metrics(): """Test that get_best_checkpoint works with nested metric dictionaries. - Rllib uses nested metric structure like {"env_runners": {"episode_return_mean": value}} + RLlib uses nested metric structure like {"env_runners": {"episode_return_mean": value}} """ # Test with nested metric structure res = Result( From 354900450745d32ea6287aa2847ccce8f8af000c Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Mon, 17 Nov 2025 20:48:46 +0000 Subject: [PATCH 3/4] Update python/ray/train/v2/tests/test_result.py Co-authored-by: Justin Yu Signed-off-by: Mark Towers --- python/ray/train/v2/tests/test_result.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/ray/train/v2/tests/test_result.py b/python/ray/train/v2/tests/test_result.py index 260258af5970..f4ebc95f7cb5 100644 --- a/python/ray/train/v2/tests/test_result.py +++ b/python/ray/train/v2/tests/test_result.py @@ -118,8 +118,6 @@ def test_get_best_checkpoint(): def test_get_best_checkpoint_nested_metrics(): """Test that get_best_checkpoint works with nested metric dictionaries. - - RLlib uses nested metric structure like {"env_runners": {"episode_return_mean": value}} """ # Test with nested metric structure res = Result( From d3813bb809490929186f425c7c1fc843c0be010a Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Mon, 17 Nov 2025 23:02:56 +0000 Subject: [PATCH 4/4] pre-commit Signed-off-by: Mark Towers --- python/ray/train/v2/tests/test_result.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/ray/train/v2/tests/test_result.py b/python/ray/train/v2/tests/test_result.py index f4ebc95f7cb5..dbc335261bd8 100644 --- a/python/ray/train/v2/tests/test_result.py +++ b/python/ray/train/v2/tests/test_result.py @@ -117,8 +117,7 @@ def test_get_best_checkpoint(): def test_get_best_checkpoint_nested_metrics(): - """Test that get_best_checkpoint works with nested metric dictionaries. - """ + """Test that get_best_checkpoint works with nested metric dictionaries.""" # Test with nested metric structure res = Result( metrics={},