diff --git a/skyrl/train/generators/utils.py b/skyrl/train/generators/utils.py index 39410b9d6b..7eadbc86cc 100644 --- a/skyrl/train/generators/utils.py +++ b/skyrl/train/generators/utils.py @@ -225,6 +225,14 @@ def get_metrics_from_generator_output(generator_output: GeneratorOutput, uids: L ) +def _flatten_field(generator_outputs: List[GeneratorOutput], key: str) -> list: + """Concatenate a per-sample list-valued field across generator outputs in O(N_total).""" + flat = [] + for go in generator_outputs: + flat.extend(go[key]) + return flat + + def concatenate_generator_outputs(generator_outputs: List[GeneratorOutput]) -> GeneratorOutput: """ Concatenate the generator outputs of multiple batches. @@ -238,31 +246,26 @@ def concatenate_generator_outputs(generator_outputs: List[GeneratorOutput]) -> G raise ValueError( "generator outputs are expected to all have null rollout_logprobs or all non-null, but received a mix" ) + first = generator_outputs[0] result: GeneratorOutput = { - "prompt_token_ids": sum([output["prompt_token_ids"] for output in generator_outputs], []), - "response_ids": sum([output["response_ids"] for output in generator_outputs], []), - "rewards": sum([output["rewards"] for output in generator_outputs], []), - "loss_masks": sum([output["loss_masks"] for output in generator_outputs], []), + "prompt_token_ids": _flatten_field(generator_outputs, "prompt_token_ids"), + "response_ids": _flatten_field(generator_outputs, "response_ids"), + "rewards": _flatten_field(generator_outputs, "rewards"), + "loss_masks": _flatten_field(generator_outputs, "loss_masks"), "stop_reasons": ( - sum([output["stop_reasons"] for output in generator_outputs], []) - if "stop_reasons" in generator_outputs[0] and generator_outputs[0]["stop_reasons"] is not None - else None + _flatten_field(generator_outputs, "stop_reasons") if first.get("stop_reasons") is not None else None ), "rollout_logprobs": ( - sum([output["rollout_logprobs"] for output in generator_outputs], []) - if generator_outputs[0]["rollout_logprobs"] is not None - else None + _flatten_field(generator_outputs, "rollout_logprobs") if first.get("rollout_logprobs") is not None else None ), } # propagate additional keys with list values as-is - additional_keys = [ - key for key in generator_outputs[0] if key not in result and isinstance(generator_outputs[0][key], list) - ] + additional_keys = [key for key in first if key not in result and isinstance(first[key], list)] if len(additional_keys): logger.info(f"Attempting to concatenate values for additional keys {additional_keys}") for key in additional_keys: - result[key] = sum([generator_output[key] for generator_output in generator_outputs], []) + result[key] = _flatten_field(generator_outputs, key) # Re-aggregate rollout metrics rollout_metrics = get_rollout_metrics(result["response_ids"], result["rewards"]) diff --git a/tests/train/generators/test_generator_output_utils.py b/tests/train/generators/test_generator_output_utils.py new file mode 100644 index 0000000000..53c68a378b --- /dev/null +++ b/tests/train/generators/test_generator_output_utils.py @@ -0,0 +1,118 @@ +""" +uv run --extra dev --isolated pytest tests/train/generators/test_generator_output_utils.py +""" + +import numpy as np + +from skyrl.train.generators.base import GeneratorOutput +from skyrl.train.generators.utils import ( + concatenate_generator_outputs, + get_metrics_from_generator_output, +) + + +def test_generator_output_concatenation(): + # First ensure that the GeneratorOutput fields are what we expect + expected_fields = [ + "prompt_token_ids", + "response_ids", + "rewards", + "loss_masks", + "stop_reasons", + "rollout_metrics", + "rollout_logprobs", + "rollout_expert_indices", + # optional but present in the signature + "trajectory_ids", + "is_last_step", + "pixel_values", + "image_grid_thw", + ] + assert set(GeneratorOutput.__annotations__.keys()) == set(expected_fields), ( + "GeneratorOutput fields are not what we expect. " + "Please update the test and `concatenate_generator_outputs()` to reflect the new fields." + "It is needed to help Trainer.eval() record the full GeneratorOutput information." + ) + + generator_output_1: GeneratorOutput = { + "prompt_token_ids": [[1, 2], [3, 4]], + "response_ids": [[1, 2], [3, 4]], + "rewards": [1.0, 2.0], + "loss_masks": [[1, 1], [1, 1]], + "stop_reasons": ["stop", "stop"], + "rollout_logprobs": [[0.1, 0.2], [0.3, 0.4]], + } + + generator_output_2: GeneratorOutput = { + "prompt_token_ids": [[5, 6, 7], [8]], + "response_ids": [[5, 6, 7], [8]], + "rewards": [2.0, 3.0], + "loss_masks": [[1, 1, 1], [1]], + "stop_reasons": ["stop", "stop"], + "rollout_logprobs": [[0.5, 0.6, 0.7], [0.8]], + } + + generator_outputs = [generator_output_1, generator_output_2] + concatenated_output = concatenate_generator_outputs(generator_outputs) + + assert concatenated_output["prompt_token_ids"] == [[1, 2], [3, 4], [5, 6, 7], [8]] + assert concatenated_output["response_ids"] == [[1, 2], [3, 4], [5, 6, 7], [8]] + assert concatenated_output["rewards"] == [1.0, 2.0, 2.0, 3.0] + assert concatenated_output["loss_masks"] == [[1, 1], [1, 1], [1, 1, 1], [1]] + assert concatenated_output["stop_reasons"] == ["stop", "stop", "stop", "stop"] + assert concatenated_output["rollout_logprobs"] == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6, 0.7], [0.8]] + + # Validate rollout metrics + expected_rollout_metrics = { + "generate/min_num_tokens": 1, + "generate/max_num_tokens": 3, + "generate/avg_num_tokens": 2.0, + "generate/std_num_tokens": np.std([2, 2, 3, 1]).item(), + "generate/avg_tokens_non_zero_rewards": 2.0, + "generate/avg_tokens_zero_rewards": 0, + } + assert concatenated_output["rollout_metrics"].keys() == expected_rollout_metrics.keys() + for key, value in expected_rollout_metrics.items(): + np.testing.assert_allclose(concatenated_output["rollout_metrics"][key], value) + + +def test_get_metrics_from_generator_output(): + # Per trajectory rewards, where rewards are List[float] + generator_output: GeneratorOutput = { + "prompt_token_ids": [[1, 2], [3, 4]], + "response_ids": [[1, 2], [3, 4]], + "rewards": [1.0, 2.0], + "loss_masks": [[1, 1], [1, 1]], + "stop_reasons": ["stop", "stop"], + "rollout_logprobs": None, + } + uids = ["a", "b"] + metrics = get_metrics_from_generator_output(generator_output, uids) + assert metrics["avg_score"] == 1.5 + assert metrics["pass_at_n"] == 1.0 + assert metrics["mean_positive_reward"] == 1.5 + + # Per token rewards, where rewards are List[List[float]], so for pass_at_n we use the last + # token's reward to signify the trajectory's reward + generator_output["rewards"] = [[1.0, 0.0], [0.0, 1.0]] + uids = ["a", "b"] + metrics = get_metrics_from_generator_output(generator_output, uids) + assert metrics["avg_score"] == 1.0 + assert metrics["pass_at_n"] == 0.5 + assert metrics["mean_positive_reward"] == 1.0 + + # Mixed rewards with some negative rewards + generator_output["rewards"] = [-1.0, 2.0] + uids = ["a", "b"] + metrics = get_metrics_from_generator_output(generator_output, uids) + assert metrics["avg_score"] == 0.5 + assert metrics["pass_at_n"] == 0.5 + assert metrics["mean_positive_reward"] == 1.0 + + # Mixed per-token rewards with negatives - per-token rewards + generator_output["rewards"] = [[1.0, -1.0], [-0.5, 0.5]] + uids = ["a", "b"] + metrics = get_metrics_from_generator_output(generator_output, uids) + assert metrics["avg_score"] == 0.0 + assert metrics["pass_at_n"] == 0.5 + assert metrics["mean_positive_reward"] == 0.75 diff --git a/tests/train/generators/test_skyrl_gym_generator.py b/tests/train/generators/test_skyrl_gym_generator.py index 803c0157f7..cf973457ae 100644 --- a/tests/train/generators/test_skyrl_gym_generator.py +++ b/tests/train/generators/test_skyrl_gym_generator.py @@ -5,7 +5,6 @@ from typing import Any, Dict, List from unittest.mock import AsyncMock, MagicMock, patch -import numpy as np import pytest from skyrl.train.config import ChatTemplateConfig, GeneratorConfig @@ -15,10 +14,6 @@ GeneratorOutput, ) from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator -from skyrl.train.generators.utils import ( - concatenate_generator_outputs, - get_metrics_from_generator_output, -) from skyrl_gym.envs.base_text_env import BaseTextEnv, BaseTextEnvStepOutput # Mock constants, where 4 is the eos token id @@ -361,113 +356,6 @@ async def test_generate_batched(mock_make, mock_tokenizer, mock_llm, mock_env, g assert generator_output["loss_masks"][0] == [1] * len(MOCK_LLM_OUTPUT_IDS) -def test_generator_output_concatenation(): - # First ensure that the GeneratorOutput fields are what we expect - expected_fields = [ - "prompt_token_ids", - "response_ids", - "rewards", - "loss_masks", - "stop_reasons", - "rollout_metrics", - "rollout_logprobs", - "rollout_expert_indices", - # optional but present in the signature - "trajectory_ids", - "is_last_step", - "pixel_values", - "image_grid_thw", - ] - assert set(GeneratorOutput.__annotations__.keys()) == set(expected_fields), ( - "GeneratorOutput fields are not what we expect. " - "Please update the test and `concatenate_generator_outputs()` to reflect the new fields." - "It is needed to help Trainer.eval() record the full GeneratorOutput information." - ) - - generator_output_1: GeneratorOutput = { - "prompt_token_ids": [[1, 2], [3, 4]], - "response_ids": [[1, 2], [3, 4]], - "rewards": [1.0, 2.0], - "loss_masks": [[1, 1], [1, 1]], - "stop_reasons": ["stop", "stop"], - "rollout_logprobs": [[0.1, 0.2], [0.3, 0.4]], - } - - generator_output_2: GeneratorOutput = { - "prompt_token_ids": [[5, 6, 7], [8]], - "response_ids": [[5, 6, 7], [8]], - "rewards": [2.0, 3.0], - "loss_masks": [[1, 1, 1], [1]], - "stop_reasons": ["stop", "stop"], - "rollout_logprobs": [[0.5, 0.6, 0.7], [0.8]], - } - - generator_outputs = [generator_output_1, generator_output_2] - concatenated_output = concatenate_generator_outputs(generator_outputs) - - assert concatenated_output["prompt_token_ids"] == [[1, 2], [3, 4], [5, 6, 7], [8]] - assert concatenated_output["response_ids"] == [[1, 2], [3, 4], [5, 6, 7], [8]] - assert concatenated_output["rewards"] == [1.0, 2.0, 2.0, 3.0] - assert concatenated_output["loss_masks"] == [[1, 1], [1, 1], [1, 1, 1], [1]] - assert concatenated_output["stop_reasons"] == ["stop", "stop", "stop", "stop"] - assert concatenated_output["rollout_logprobs"] == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6, 0.7], [0.8]] - - # Validate rollout metrics - expected_rollout_metrics = { - "generate/min_num_tokens": 1, - "generate/max_num_tokens": 3, - "generate/avg_num_tokens": 2.0, - "generate/std_num_tokens": np.std([2, 2, 3, 1]).item(), - "generate/avg_tokens_non_zero_rewards": 2.0, - "generate/avg_tokens_zero_rewards": 0, - } - assert concatenated_output["rollout_metrics"].keys() == expected_rollout_metrics.keys() - for key, value in expected_rollout_metrics.items(): - np.testing.assert_allclose(concatenated_output["rollout_metrics"][key], value) - - -def test_get_metrics_from_generator_output(): - # Per trajectory rewards, where rewards are List[float] - generator_output: GeneratorOutput = { - "prompt_token_ids": [[1, 2], [3, 4]], - "response_ids": [[1, 2], [3, 4]], - "rewards": [1.0, 2.0], - "loss_masks": [[1, 1], [1, 1]], - "stop_reasons": ["stop", "stop"], - "rollout_logprobs": None, - } - uids = ["a", "b"] - metrics = get_metrics_from_generator_output(generator_output, uids) - assert metrics["avg_score"] == 1.5 - assert metrics["pass_at_n"] == 1.0 - assert metrics["mean_positive_reward"] == 1.5 - - # Per token rewards, where rewards are List[List[float]], so for pass_at_n we use the last - # token's reward to signify the trajectory's reward - generator_output["rewards"] = [[1.0, 0.0], [0.0, 1.0]] - uids = ["a", "b"] - metrics = get_metrics_from_generator_output(generator_output, uids) - assert metrics["avg_score"] == 1.0 - assert metrics["pass_at_n"] == 0.5 - assert metrics["mean_positive_reward"] == 1.0 - - # Mixed rewards with some negative rewards - generator_output["rewards"] = [-1.0, 2.0] - uids = ["a", "b"] - metrics = get_metrics_from_generator_output(generator_output, uids) - assert metrics["avg_score"] == 0.5 - assert metrics["pass_at_n"] == 0.5 - assert metrics["mean_positive_reward"] == 1.0 - - # Mixed per-token rewards with negatives - per-token rewards - generator_output["rewards"] = [[1.0, -1.0], [-0.5, 0.5]] - uids = ["a", "b"] - metrics = get_metrics_from_generator_output(generator_output, uids) - assert metrics["avg_score"] == 0.0 - assert metrics["pass_at_n"] == 0.5 - assert metrics["mean_positive_reward"] == 0.75 - - @pytest.mark.asyncio @pytest.mark.parametrize("batched", [True, False]) @patch("skyrl_gym.make")