Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions skyrl/train/generators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,14 @@ def get_metrics_from_generator_output(generator_output: GeneratorOutput, uids: L
)


def _flatten_field(generator_outputs: List[GeneratorOutput], key: str) -> list:
    """Flatten the list stored under ``key`` from every output into one list.

    A single comprehension keeps the total work O(N_total), avoiding the
    quadratic cost of repeated list concatenation.
    """
    return [item for output in generator_outputs for item in output[key]]


def concatenate_generator_outputs(generator_outputs: List[GeneratorOutput]) -> GeneratorOutput:
"""
Concatenate the generator outputs of multiple batches.
Expand All @@ -238,31 +246,26 @@ def concatenate_generator_outputs(generator_outputs: List[GeneratorOutput]) -> G
raise ValueError(
"generator outputs are expected to all have null rollout_logprobs or all non-null, but received a mix"
)
first = generator_outputs[0]
result: GeneratorOutput = {
"prompt_token_ids": sum([output["prompt_token_ids"] for output in generator_outputs], []),
"response_ids": sum([output["response_ids"] for output in generator_outputs], []),
"rewards": sum([output["rewards"] for output in generator_outputs], []),
"loss_masks": sum([output["loss_masks"] for output in generator_outputs], []),
"prompt_token_ids": _flatten_field(generator_outputs, "prompt_token_ids"),
"response_ids": _flatten_field(generator_outputs, "response_ids"),
"rewards": _flatten_field(generator_outputs, "rewards"),
"loss_masks": _flatten_field(generator_outputs, "loss_masks"),
"stop_reasons": (
sum([output["stop_reasons"] for output in generator_outputs], [])
if "stop_reasons" in generator_outputs[0] and generator_outputs[0]["stop_reasons"] is not None
else None
_flatten_field(generator_outputs, "stop_reasons") if first.get("stop_reasons") is not None else None
),
"rollout_logprobs": (
sum([output["rollout_logprobs"] for output in generator_outputs], [])
if generator_outputs[0]["rollout_logprobs"] is not None
else None
_flatten_field(generator_outputs, "rollout_logprobs") if first.get("rollout_logprobs") is not None else None
),
}

# propagate additional keys with list values as-is
additional_keys = [
key for key in generator_outputs[0] if key not in result and isinstance(generator_outputs[0][key], list)
]
additional_keys = [key for key in first if key not in result and isinstance(first[key], list)]
if len(additional_keys):
logger.info(f"Attempting to concatenate values for additional keys {additional_keys}")
for key in additional_keys:
result[key] = sum([generator_output[key] for generator_output in generator_outputs], [])
result[key] = _flatten_field(generator_outputs, key)

# Re-aggregate rollout metrics
rollout_metrics = get_rollout_metrics(result["response_ids"], result["rewards"])
Expand Down
118 changes: 118 additions & 0 deletions tests/train/generators/test_generator_output_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""
uv run --extra dev --isolated pytest tests/train/generators/test_generator_output_utils.py
"""

import numpy as np

from skyrl.train.generators.base import GeneratorOutput
from skyrl.train.generators.utils import (
concatenate_generator_outputs,
get_metrics_from_generator_output,
)


def test_generator_output_concatenation():
    """Concatenating two batches preserves per-field order and re-derives metrics."""
    # Guard against schema drift: every field of GeneratorOutput must be
    # accounted for here and in `concatenate_generator_outputs()`.
    expected_fields = [
        "prompt_token_ids",
        "response_ids",
        "rewards",
        "loss_masks",
        "stop_reasons",
        "rollout_metrics",
        "rollout_logprobs",
        "rollout_expert_indices",
        # optional but present in the signature
        "trajectory_ids",
        "is_last_step",
        "pixel_values",
        "image_grid_thw",
    ]
    assert set(GeneratorOutput.__annotations__.keys()) == set(expected_fields), (
        "GeneratorOutput fields are not what we expect. "
        "Please update the test and `concatenate_generator_outputs()` to reflect the new fields."
        "It is needed to help Trainer.eval() record the full GeneratorOutput information."
    )

    batch_a: GeneratorOutput = {
        "prompt_token_ids": [[1, 2], [3, 4]],
        "response_ids": [[1, 2], [3, 4]],
        "rewards": [1.0, 2.0],
        "loss_masks": [[1, 1], [1, 1]],
        "stop_reasons": ["stop", "stop"],
        "rollout_logprobs": [[0.1, 0.2], [0.3, 0.4]],
    }
    batch_b: GeneratorOutput = {
        "prompt_token_ids": [[5, 6, 7], [8]],
        "response_ids": [[5, 6, 7], [8]],
        "rewards": [2.0, 3.0],
        "loss_masks": [[1, 1, 1], [1]],
        "stop_reasons": ["stop", "stop"],
        "rollout_logprobs": [[0.5, 0.6, 0.7], [0.8]],
    }

    combined = concatenate_generator_outputs([batch_a, batch_b])

    # Every list-valued field should be batch_a's entries followed by batch_b's.
    expected_concatenation = {
        "prompt_token_ids": [[1, 2], [3, 4], [5, 6, 7], [8]],
        "response_ids": [[1, 2], [3, 4], [5, 6, 7], [8]],
        "rewards": [1.0, 2.0, 2.0, 3.0],
        "loss_masks": [[1, 1], [1, 1], [1, 1, 1], [1]],
        "stop_reasons": ["stop", "stop", "stop", "stop"],
        "rollout_logprobs": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6, 0.7], [0.8]],
    }
    for field, expected in expected_concatenation.items():
        assert combined[field] == expected

    # Rollout metrics are recomputed over the merged batch
    # (response lengths are [2, 2, 3, 1]).
    expected_rollout_metrics = {
        "generate/min_num_tokens": 1,
        "generate/max_num_tokens": 3,
        "generate/avg_num_tokens": 2.0,
        "generate/std_num_tokens": np.std([2, 2, 3, 1]).item(),
        "generate/avg_tokens_non_zero_rewards": 2.0,
        "generate/avg_tokens_zero_rewards": 0,
    }
    observed_metrics = combined["rollout_metrics"]
    assert observed_metrics.keys() == expected_rollout_metrics.keys()
    for metric_name, expected_value in expected_rollout_metrics.items():
        np.testing.assert_allclose(observed_metrics[metric_name], expected_value)


def test_get_metrics_from_generator_output():
    """Metric aggregation handles scalar, per-token, and negative rewards."""
    generator_output: GeneratorOutput = {
        "prompt_token_ids": [[1, 2], [3, 4]],
        "response_ids": [[1, 2], [3, 4]],
        "rewards": [1.0, 2.0],
        "loss_masks": [[1, 1], [1, 1]],
        "stop_reasons": ["stop", "stop"],
        "rollout_logprobs": None,
    }
    uids = ["a", "b"]

    # Each case: (rewards, avg_score, pass_at_n, mean_positive_reward).
    # When rewards are List[List[float]] (per-token), pass_at_n uses the last
    # token's reward to signify the trajectory's reward.
    cases = [
        # Per trajectory rewards, where rewards are List[float]
        ([1.0, 2.0], 1.5, 1.0, 1.5),
        # Per token rewards, where rewards are List[List[float]]
        ([[1.0, 0.0], [0.0, 1.0]], 1.0, 0.5, 1.0),
        # Mixed rewards with some negative rewards
        ([-1.0, 2.0], 0.5, 0.5, 1.0),
        # Mixed per-token rewards with negatives - per-token rewards
        ([[1.0, -1.0], [-0.5, 0.5]], 0.0, 0.5, 0.75),
    ]
    for rewards, avg_score, pass_at_n, mean_positive in cases:
        generator_output["rewards"] = rewards
        metrics = get_metrics_from_generator_output(generator_output, uids)
        assert metrics["avg_score"] == avg_score
        assert metrics["pass_at_n"] == pass_at_n
        assert metrics["mean_positive_reward"] == mean_positive
112 changes: 0 additions & 112 deletions tests/train/generators/test_skyrl_gym_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Any, Dict, List
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest

from skyrl.train.config import ChatTemplateConfig, GeneratorConfig
Expand All @@ -15,10 +14,6 @@
GeneratorOutput,
)
from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator
from skyrl.train.generators.utils import (
concatenate_generator_outputs,
get_metrics_from_generator_output,
)
from skyrl_gym.envs.base_text_env import BaseTextEnv, BaseTextEnvStepOutput

# Mock constants, where 4 is the eos token id
Expand Down Expand Up @@ -361,113 +356,6 @@ async def test_generate_batched(mock_make, mock_tokenizer, mock_llm, mock_env, g
assert generator_output["loss_masks"][0] == [1] * len(MOCK_LLM_OUTPUT_IDS)


def test_generator_output_concatenation():
    """Exercise `concatenate_generator_outputs` on two small batches and check
    that every field is concatenated in order and rollout metrics are rebuilt.
    """
    # First ensure that the GeneratorOutput fields are what we expect
    expected_fields = [
        "prompt_token_ids",
        "response_ids",
        "rewards",
        "loss_masks",
        "stop_reasons",
        "rollout_metrics",
        "rollout_logprobs",
        "rollout_expert_indices",
        # optional but present in the signature
        "trajectory_ids",
        "is_last_step",
        "pixel_values",
        "image_grid_thw",
    ]
    assert set(GeneratorOutput.__annotations__.keys()) == set(expected_fields), (
        "GeneratorOutput fields are not what we expect. "
        "Please update the test and `concatenate_generator_outputs()` to reflect the new fields."
        "It is needed to help Trainer.eval() record the full GeneratorOutput information."
    )

    generator_output_1: GeneratorOutput = {
        "prompt_token_ids": [[1, 2], [3, 4]],
        "response_ids": [[1, 2], [3, 4]],
        "rewards": [1.0, 2.0],
        "loss_masks": [[1, 1], [1, 1]],
        "stop_reasons": ["stop", "stop"],
        "rollout_logprobs": [[0.1, 0.2], [0.3, 0.4]],
    }

    generator_output_2: GeneratorOutput = {
        "prompt_token_ids": [[5, 6, 7], [8]],
        "response_ids": [[5, 6, 7], [8]],
        "rewards": [2.0, 3.0],
        "loss_masks": [[1, 1, 1], [1]],
        "stop_reasons": ["stop", "stop"],
        "rollout_logprobs": [[0.5, 0.6, 0.7], [0.8]],
    }

    generator_outputs = [generator_output_1, generator_output_2]
    concatenated_output = concatenate_generator_outputs(generator_outputs)

    # Concatenation preserves batch order: output 1's samples first, then output 2's.
    assert concatenated_output["prompt_token_ids"] == [[1, 2], [3, 4], [5, 6, 7], [8]]
    assert concatenated_output["response_ids"] == [[1, 2], [3, 4], [5, 6, 7], [8]]
    assert concatenated_output["rewards"] == [1.0, 2.0, 2.0, 3.0]
    assert concatenated_output["loss_masks"] == [[1, 1], [1, 1], [1, 1, 1], [1]]
    assert concatenated_output["stop_reasons"] == ["stop", "stop", "stop", "stop"]
    assert concatenated_output["rollout_logprobs"] == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6, 0.7], [0.8]]

    # Validate rollout metrics
    # Response lengths across the merged batch are [2, 2, 3, 1].
    expected_rollout_metrics = {
        "generate/min_num_tokens": 1,
        "generate/max_num_tokens": 3,
        "generate/avg_num_tokens": 2.0,
        "generate/std_num_tokens": np.std([2, 2, 3, 1]).item(),
        "generate/avg_tokens_non_zero_rewards": 2.0,
        "generate/avg_tokens_zero_rewards": 0,
    }
    assert concatenated_output["rollout_metrics"].keys() == expected_rollout_metrics.keys()
    for key, value in expected_rollout_metrics.items():
        np.testing.assert_allclose(concatenated_output["rollout_metrics"][key], value)


def test_get_metrics_from_generator_output():
    """Check avg_score / pass_at_n / mean_positive_reward across scalar,
    per-token, and negative reward layouts.
    """
    # Per trajectory rewards, where rewards are List[float]
    generator_output: GeneratorOutput = {
        "prompt_token_ids": [[1, 2], [3, 4]],
        "response_ids": [[1, 2], [3, 4]],
        "rewards": [1.0, 2.0],
        "loss_masks": [[1, 1], [1, 1]],
        "stop_reasons": ["stop", "stop"],
        "rollout_logprobs": None,
    }
    uids = ["a", "b"]
    metrics = get_metrics_from_generator_output(generator_output, uids)
    assert metrics["avg_score"] == 1.5
    assert metrics["pass_at_n"] == 1.0
    assert metrics["mean_positive_reward"] == 1.5

    # Per token rewards, where rewards are List[List[float]], so for pass_at_n we use the last
    # token's reward to signify the trajectory's reward
    generator_output["rewards"] = [[1.0, 0.0], [0.0, 1.0]]
    uids = ["a", "b"]
    metrics = get_metrics_from_generator_output(generator_output, uids)
    assert metrics["avg_score"] == 1.0
    assert metrics["pass_at_n"] == 0.5
    assert metrics["mean_positive_reward"] == 1.0

    # Mixed rewards with some negative rewards
    # (only positive rewards contribute to mean_positive_reward)
    generator_output["rewards"] = [-1.0, 2.0]
    uids = ["a", "b"]
    metrics = get_metrics_from_generator_output(generator_output, uids)
    assert metrics["avg_score"] == 0.5
    assert metrics["pass_at_n"] == 0.5
    assert metrics["mean_positive_reward"] == 1.0

    # Mixed per-token rewards with negatives - per-token rewards
    generator_output["rewards"] = [[1.0, -1.0], [-0.5, 0.5]]
    uids = ["a", "b"]
    metrics = get_metrics_from_generator_output(generator_output, uids)
    assert metrics["avg_score"] == 0.0
    assert metrics["pass_at_n"] == 0.5
    assert metrics["mean_positive_reward"] == 0.75


@pytest.mark.asyncio
@pytest.mark.parametrize("batched", [True, False])
@patch("skyrl_gym.make")
Expand Down
Loading