diff --git a/skyrl/train/fully_async_trainer.py b/skyrl/train/fully_async_trainer.py index a04bfd6cf4..d1bc6776b1 100644 --- a/skyrl/train/fully_async_trainer.py +++ b/skyrl/train/fully_async_trainer.py @@ -623,11 +623,13 @@ def convert_generation_group_mini_batch_to_training_input( uids = [] stalenesses = [] staleness_violation_count = 0 - group_size = len(cur_generation_group_mini_batch[0].generator_output["response_ids"]) for cur_generated_output_group in cur_generation_group_mini_batch: cur_staleness = self.global_step - cur_generated_output_group.global_step_when_scheduled stalenesses.append(cur_staleness) generator_outputs.append(cur_generated_output_group.generator_output) + # NOTE(Charlie): for step-wise training each group can contain a variable number of entries + # (n_samples_per_prompt * variable turns_per_trajectory), so the uid fanout is per-group. + group_size = len(cur_generated_output_group.generator_output["response_ids"]) uids.extend([cur_generated_output_group.uid] * group_size) # Check staleness violation. @@ -642,7 +644,9 @@ def convert_generation_group_mini_batch_to_training_input( ) staleness_violation_count += 1 - generator_output = concatenate_generator_outputs(generator_outputs) + generator_output = concatenate_generator_outputs( + generator_outputs, step_wise=self.cfg.generator.step_wise_trajectories + ) assert generator_output["rollout_metrics"] is not None, "Rollout metrics should be non-null." self.all_metrics.update(generator_output["rollout_metrics"]) diff --git a/skyrl/train/generators/utils.py b/skyrl/train/generators/utils.py index 7eadbc86cc..331b890c15 100644 --- a/skyrl/train/generators/utils.py +++ b/skyrl/train/generators/utils.py @@ -233,12 +233,17 @@ def _flatten_field(generator_outputs: List[GeneratorOutput], key: str) -> list: return flat -def concatenate_generator_outputs(generator_outputs: List[GeneratorOutput]) -> GeneratorOutput: +def concatenate_generator_outputs(generator_outputs: List[GeneratorOutput], step_wise: bool = False) -> GeneratorOutput: """ - Concatenate the generator outputs of multiple batches. + Concatenate the generator outputs of multiple batches. Then validate the concatenated result. We only aggregate rollout metrics the can deduced by responses and rewards, but not those that use `env_metrics` or `env_classes`. + + Args: + generator_outputs: Per-batch generator outputs to concatenate. + step_wise: If True, validate step-wise specific fields on the concatenated result + (e.g. `is_last_step`, `trajectory_ids`, contiguous trajectory ordering). """ assert len(generator_outputs) > 0 has_rollout_logprobs = [output.get("rollout_logprobs") is not None for output in generator_outputs] @@ -276,7 +281,7 @@ def concatenate_generator_outputs(generator_outputs: List[GeneratorOutput]) -> G from skyrl.train.utils.trainer_utils import validate_generator_output num_prompts = len(result["prompt_token_ids"]) - validate_generator_output(num_prompts, result) + validate_generator_output(num_prompts, result, step_wise=step_wise) return result