Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions skyrl/train/fully_async_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,11 +623,13 @@ def convert_generation_group_mini_batch_to_training_input(
uids = []
stalenesses = []
staleness_violation_count = 0
group_size = len(cur_generation_group_mini_batch[0].generator_output["response_ids"])
for cur_generated_output_group in cur_generation_group_mini_batch:
cur_staleness = self.global_step - cur_generated_output_group.global_step_when_scheduled
stalenesses.append(cur_staleness)
generator_outputs.append(cur_generated_output_group.generator_output)
# NOTE(Charlie): for step-wise training each group can contain a variable number of entries
# (n_samples_per_prompt * variable turns_per_trajectory), so the uid fanout is per-group.
group_size = len(cur_generated_output_group.generator_output["response_ids"])
uids.extend([cur_generated_output_group.uid] * group_size)

# Check staleness violation.
Expand All @@ -642,7 +644,9 @@ def convert_generation_group_mini_batch_to_training_input(
)
staleness_violation_count += 1

generator_output = concatenate_generator_outputs(generator_outputs)
generator_output = concatenate_generator_outputs(
generator_outputs, step_wise=self.cfg.generator.step_wise_trajectories
)
assert generator_output["rollout_metrics"] is not None, "Rollout metrics should be non-null."
self.all_metrics.update(generator_output["rollout_metrics"])

Expand Down
11 changes: 8 additions & 3 deletions skyrl/train/generators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,17 @@ def _flatten_field(generator_outputs: List[GeneratorOutput], key: str) -> list:
return flat


def concatenate_generator_outputs(generator_outputs: List[GeneratorOutput]) -> GeneratorOutput:
def concatenate_generator_outputs(generator_outputs: List[GeneratorOutput], step_wise: bool = False) -> GeneratorOutput:
"""
Concatenate the generator outputs of multiple batches.
Concatenate the generator outputs of multiple batches. Then validate the concatenated result.

We only aggregate rollout metrics that can be deduced from responses and rewards, but not
those that use `env_metrics` or `env_classes`.

Args:
generator_outputs: Per-batch generator outputs to concatenate.
step_wise: If True, validate step-wise specific fields on the concatenated result
(e.g. `is_last_step`, `trajectory_ids`, contiguous trajectory ordering).
"""
assert len(generator_outputs) > 0
has_rollout_logprobs = [output.get("rollout_logprobs") is not None for output in generator_outputs]
Expand Down Expand Up @@ -276,7 +281,7 @@ def concatenate_generator_outputs(generator_outputs: List[GeneratorOutput]) -> G
from skyrl.train.utils.trainer_utils import validate_generator_output

num_prompts = len(result["prompt_token_ids"])
validate_generator_output(num_prompts, result)
validate_generator_output(num_prompts, result, step_wise=step_wise)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While the `step_wise` flag is now correctly passed to `validate_generator_output`, the `rollout_metrics` re-calculation performed just above (at line 276 in the source) does not account for step-wise trajectories. In step-wise mode, `result["response_ids"]` contains individual turns, meaning `get_rollout_metrics` will compute turn-level statistics (e.g., average tokens per turn) instead of trajectory-level statistics.

Consider leveraging the new `step_wise` flag to filter or aggregate metrics appropriately (e.g., by filtering for `is_last_step` or grouping by `trajectory_ids`) to ensure that the reported `rollout_metrics` are consistent with trajectory-level expectations.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will leave this as a future TODO.


return result

Expand Down
Loading