diff --git a/docs/content/docs/algorithms/custom_algorithms.mdx b/docs/content/docs/algorithms/custom_algorithms.mdx index b00402716b..ce7d93036d 100644 --- a/docs/content/docs/algorithms/custom_algorithms.mdx +++ b/docs/content/docs/algorithms/custom_algorithms.mdx @@ -98,7 +98,9 @@ We show the outline of creating a custom trainer below, and you can find a full ```python class CustomTrainer(RayPPOTrainer): @torch.no_grad() - def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput: + def postprocess_generator_output( + self, generator_output: GeneratorOutput, uids: List[str] + ) -> Tuple[GeneratorOutput, List[str]]: # apply custom reward penalties ... # use base class impl for metrics and per-token reward conversion diff --git a/docs/content/docs/algorithms/dapo.mdx b/docs/content/docs/algorithms/dapo.mdx index 22e9b86065..3a95d8e50e 100644 --- a/docs/content/docs/algorithms/dapo.mdx +++ b/docs/content/docs/algorithms/dapo.mdx @@ -80,7 +80,9 @@ We provide an example of this in `examples/train/algorithms/dapo/main_dapo.py`, ```python title="examples/train/algorithms/dapo/main_dapo.py" class DAPOTrainer(RayPPOTrainer): @torch.no_grad() - def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput: + def postprocess_generator_output( + self, generator_output: GeneratorOutput, uids: List[str] + ) -> Tuple[GeneratorOutput, List[str]]: # apply soft overlong punishment overlong_buffer_len = self.cfg.trainer.algorithm.overlong_buffer.len overlong_buffer_penalty_factor = self.cfg.trainer.algorithm.overlong_buffer.penalty_factor diff --git a/docs/content/docs/tutorials/agent-integration.mdx b/docs/content/docs/tutorials/agent-integration.mdx index d869a62279..551be8a18a 100644 --- a/docs/content/docs/tutorials/agent-integration.mdx +++ b/docs/content/docs/tutorials/agent-integration.mdx @@ -84,6 +84,6 @@ Your agent harness can still use `/chat/completions` with tool call parsing, sin **Cons:** - Training time can grow: O(T^2) vs O(T), since each trajectory of T turns becomes T sequences to forward (each with a growing prefix), as opposed to 1 sequence. - - SkyRL will support prefix-aware merging of per-step sequences when the prefix matches (WIP), which brings the cost back to O(T) in the common case. + - SkyRL supports prefix-aware merging of per-step sequences when the prefix matches, enabled via the config flag `generator.merge_stepwise_output`. This can reduce the O(T^2) cost when the chat history grows by linear appending across turns and there is no token mismatch. See https://github.com/NovaSky-AI/SkyRL/pull/1532 For the full details on how to structure the `GeneratorOutput` for step-wise training, including the required fields, invariants, and a concrete example, see: [Step-Wise Training](step-wise-training). diff --git a/docs/content/docs/tutorials/step-wise-training.mdx b/docs/content/docs/tutorials/step-wise-training.mdx index 071f222f86..83a34ba6f5 100644 --- a/docs/content/docs/tutorials/step-wise-training.mdx +++ b/docs/content/docs/tutorials/step-wise-training.mdx @@ -28,7 +28,7 @@ When step-wise is enabled, a batch of T trajectories with an average of M turns - **Each mini-batch contains the sequences for exactly `policy_mini_batch_size` prompts**, regardless of how many turns those prompts produced. This means the number of mini-batches (and hence optimizer steps) per training batch is always `train_batch_size / policy_mini_batch_size`, independent of the number of turns. 
This also means that the actual mini-batch size (number of sequences) trained in each mini-batch can vary. Each mini-batch always leads to a single optimizer step. - **Advantages are computed on last steps only**, then broadcast to all steps of the same trajectory. This is mathematically equivalent to non-step-wise advantage computation for GRPO. -- **Training time grows as O(T²) vs O(T)**, since each trajectory of T turns becomes T sequences to forward (each with a growing prompt prefix), as opposed to 1 sequence. SkyRL will support prefix-aware merging of per-step sequences when the prefix matches (WIP), which brings the cost back to O(T) in the common case. +- **Training time grows as O(T²) vs O(T)**, since each trajectory of T turns becomes T sequences to forward (each with a growing prompt prefix), as opposed to 1 sequence. SkyRL supports prefix-aware merging of per-step sequences when the prefix matches, enabled via the config flag `generator.merge_stepwise_output`. This can reduce the O(T²) cost when the chat history grows by linear appending across turns and there is no token mismatch. See https://github.com/NovaSky-AI/SkyRL/pull/1532 - **Metrics** like `generate/avg_num_tokens` and `generate/avg_response_length` are per-turn rather than per-trajectory, since each training sample is a single turn. Some algorithms have their behavior altered by step-wise decomposition, since each turn is now treated as its own sequence: diff --git a/examples/train/algorithms/dapo/README.md b/examples/train/algorithms/dapo/README.md index 6870c51fd9..76e6047872 100644 --- a/examples/train/algorithms/dapo/README.md +++ b/examples/train/algorithms/dapo/README.md @@ -74,7 +74,9 @@ To enable soft overlong punishment, you can create a custom trainer class and ov ```python class DAPOTrainer(RayPPOTrainer): @torch.no_grad() - def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput: + def postprocess_generator_output( + self, generator_output: GeneratorOutput, uids: List[str] + ) -> Tuple[GeneratorOutput, List[str]]: # apply soft overlong punishment overlong_buffer_len = self.cfg.trainer.algorithm.overlong_buffer_len overlong_buffer_penalty_factor = self.cfg.trainer.algorithm.overlong_buffer_penalty_factor diff --git a/examples/train/algorithms/dapo/main_dapo.py b/examples/train/algorithms/dapo/main_dapo.py index 2b846e7c29..4ae78074aa 100644 --- a/examples/train/algorithms/dapo/main_dapo.py +++ b/examples/train/algorithms/dapo/main_dapo.py @@ -7,7 +7,7 @@ import ray import torch from dataclasses import dataclass -from typing import List +from typing import List, Tuple from skyrl.train.config import AlgorithmConfig, make_config from skyrl.train.trainer import RayPPOTrainer @@ -36,7 +36,9 @@ class DAPOTrainer(RayPPOTrainer): """ @torch.no_grad() - def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput: + def postprocess_generator_output( + self, generator_output: GeneratorOutput, uids: List[str] + ) -> Tuple[GeneratorOutput, List[str]]: # NOTE (sumanthrh): Given the usage of `make_config`, the algorithm config subclass for DAPO is # created dynamically and thus IDEs will not be able to resolve the attributes # For better typing, you can always define a custom subclass of DAPOConfig manually. 
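As a reading aid for the signature change threaded through this diff, here is a minimal sketch of an override that propagates the new `(GeneratorOutput, uids)` tuple (illustrative only; `MyTrainer` is a hypothetical subclass, and the imports mirror those used elsewhere in this PR):

```python
from typing import List, Tuple

import torch

from skyrl.train.generators.base import GeneratorOutput
from skyrl.train.trainer import RayPPOTrainer


class MyTrainer(RayPPOTrainer):
    @torch.no_grad()
    def postprocess_generator_output(
        self, generator_output: GeneratorOutput, uids: List[str]
    ) -> Tuple[GeneratorOutput, List[str]]:
        # ... apply custom reward penalties to generator_output["rewards"] here ...
        # Delegate to the base class for metrics and per-token reward conversion.
        # Note: the base implementation now returns a tuple, and `uids` may be
        # shortened when step-wise merging collapses turns into fewer sequences.
        return super().postprocess_generator_output(generator_output, uids)
```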
diff --git a/examples/train/algorithms/dapo/main_dapo_fully_async.py b/examples/train/algorithms/dapo/main_dapo_fully_async.py index 3fcb82c44e..4540355a92 100644 --- a/examples/train/algorithms/dapo/main_dapo_fully_async.py +++ b/examples/train/algorithms/dapo/main_dapo_fully_async.py @@ -6,7 +6,7 @@ import ray import torch -from typing import List +from typing import List, Tuple from skyrl.train.fully_async_trainer import FullyAsyncRayPPOTrainer from skyrl.train.utils import initialize_ray, validate_cfg @@ -19,7 +19,9 @@ class FullyAsyncDAPOTrainer(FullyAsyncRayPPOTrainer): @torch.no_grad() - def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput: + def postprocess_generator_output( + self, generator_output: GeneratorOutput, uids: List[str] + ) -> Tuple[GeneratorOutput, List[str]]: """ Overrides the postprocess_generator_output method to additionally apply DAPO-specific soft overlong punishment to rewards. diff --git a/examples/train/async/async_trainer.py b/examples/train/async/async_trainer.py index a80edd730f..72fc666802 100644 --- a/examples/train/async/async_trainer.py +++ b/examples/train/async/async_trainer.py @@ -175,7 +175,7 @@ async def _run_generate_loop(self, generation_buffer: asyncio.Queue): # generation phase async with Timer("generate", self.all_timings): generator_output: GeneratorOutput = await self.generate(generator_input) - generator_output = self.postprocess_generator_output(generator_output, uids) + generator_output, uids = self.postprocess_generator_output(generator_output, uids) # Add to generation buffer await generation_buffer.put((generator_output, uids)) diff --git a/examples/train/flash_rl/main_dapo_flashrl.py b/examples/train/flash_rl/main_dapo_flashrl.py index e5a0d729e4..5d9e142601 100644 --- a/examples/train/flash_rl/main_dapo_flashrl.py +++ b/examples/train/flash_rl/main_dapo_flashrl.py @@ -7,7 +7,7 @@ import ray import torch from dataclasses import dataclass -from typing import List +from typing import List, Tuple from skyrl.train.config import SkyRLTrainConfig, AlgorithmConfig, make_config from skyrl.train.trainer import RayPPOTrainer @@ -64,7 +64,9 @@ class DAPOTrainer(RayPPOTrainer): """ @torch.no_grad() - def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput: + def postprocess_generator_output( + self, generator_output: GeneratorOutput, uids: List[str] + ) -> Tuple[GeneratorOutput, List[str]]: """ Overrides the postprocess_generator_output method to additionally apply DAPO-specific soft overlong punishment to rewards. @@ -73,7 +75,7 @@ def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: uids: List[str] Returns: - GeneratorOutput + (GeneratorOutput, uids); uids may be shortened if the base class applies step-wise merging. 
""" overlong_buffer_len = self.cfg.trainer.algorithm.overlong_buffer_len overlong_buffer_penalty_factor = self.cfg.trainer.algorithm.overlong_buffer_penalty_factor diff --git a/examples/train/search/run_search.sh b/examples/train/search/run_search.sh index 92714d0d05..08f0859ef2 100755 --- a/examples/train/search/run_search.sh +++ b/examples/train/search/run_search.sh @@ -48,6 +48,8 @@ else MULTI_TURN_ARGS="generator.use_conversation_multi_turn=false" fi +: "${MERGE_STEPWISE:=false}" + STEP_WISE_ARGS="" if [ "$STEP_WISE" = "true" ]; then STEP_WISE_ARGS="generator.step_wise_trajectories=true" @@ -56,6 +58,9 @@ if [ "$STEP_WISE" = "true" ]; then echo "WARNING: STEP_WISE=true requires USE_CONVERSATION_MULTI_TURN=true. Enabling it automatically." MULTI_TURN_ARGS="generator.use_conversation_multi_turn=true generator.append_eos_token_after_stop_str_in_multi_turn=true" fi + if [ "$MERGE_STEPWISE" = "true" ]; then + STEP_WISE_ARGS="$STEP_WISE_ARGS generator.merge_stepwise_output=true" + fi fi uv run --isolated --frozen --extra fsdp -m skyrl.train.entrypoints.main_base \ diff --git a/examples/train/tis_correction/main_tis_dapo.py b/examples/train/tis_correction/main_tis_dapo.py index e8356222a0..157d4bce23 100644 --- a/examples/train/tis_correction/main_tis_dapo.py +++ b/examples/train/tis_correction/main_tis_dapo.py @@ -7,7 +7,7 @@ import ray import torch from dataclasses import dataclass -from typing import List +from typing import List, Tuple from skyrl.train.config import AlgorithmConfig, make_config from skyrl.train.trainer import RayPPOTrainer @@ -36,7 +36,9 @@ class DAPOTrainer(RayPPOTrainer): """ @torch.no_grad() - def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput: + def postprocess_generator_output( + self, generator_output: GeneratorOutput, uids: List[str] + ) -> Tuple[GeneratorOutput, List[str]]: """ Overrides the postprocess_generator_output method to additionally apply DAPO specific soft overlong punishment to rewards. @@ -45,7 +47,7 @@ def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: uids: List[str] Returns: - GeneratorOutput + (GeneratorOutput, uids) — uids may be shortened if base class applies step-wise merging. """ overlong_buffer_len = self.cfg.trainer.algorithm.overlong_buffer_len overlong_buffer_penalty_factor = self.cfg.trainer.algorithm.overlong_buffer_penalty_factor diff --git a/skyrl-agent/skyrl_agent/integrations/skyrl_train/trainer.py b/skyrl-agent/skyrl_agent/integrations/skyrl_train/trainer.py index 884f99f701..15d362287b 100644 --- a/skyrl-agent/skyrl_agent/integrations/skyrl_train/trainer.py +++ b/skyrl-agent/skyrl_agent/integrations/skyrl_train/trainer.py @@ -354,7 +354,7 @@ async def train(self): # 1.2 postprocess rewards with Timer("postprocess_generator_output", self.all_timings): - generator_output = self.postprocess_generator_output(generator_output, uids) + generator_output, uids = self.postprocess_generator_output(generator_output, uids) # 2. 
print example just for debugging vis = self.tokenizer.decode(generator_output["response_ids"][0]) diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index 85edfd01d7..6c62469844 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -536,6 +536,9 @@ class GeneratorConfig(BaseConfig): """Can differ from the trainer's ``rope_scaling``, useful for thinking models.""" rope_theta: Optional[float] = None step_wise_trajectories: bool = False + merge_stepwise_output: bool = False + """When True (and step_wise_trajectories is True), apply prefix-aware merging + to collapse multi-turn step-wise sequences into single sequences before training.""" def __post_init__(self): diff --git a/skyrl/train/config/ppo_base_config.yaml b/skyrl/train/config/ppo_base_config.yaml index 22396bb907..ed4d027daf 100644 --- a/skyrl/train/config/ppo_base_config.yaml +++ b/skyrl/train/config/ppo_base_config.yaml @@ -382,6 +382,7 @@ generator: rope_theta: ${trainer.rope_theta} step_wise_trajectories: false + merge_stepwise_output: false environment: env_class: "gsm8k" diff --git a/skyrl/train/fully_async_trainer.py b/skyrl/train/fully_async_trainer.py index d1bc6776b1..63284da242 100644 --- a/skyrl/train/fully_async_trainer.py +++ b/skyrl/train/fully_async_trainer.py @@ -662,7 +662,7 @@ def convert_generation_group_mini_batch_to_training_input( ) # Convert rewards to per-token form and compute reward metrics before training conversion - generator_output = self.postprocess_generator_output(generator_output, uids) + generator_output, uids = self.postprocess_generator_output(generator_output, uids) # print example just for debugging vis = self.tokenizer.decode(generator_output["response_ids"][0]) diff --git a/skyrl/train/generators/utils.py b/skyrl/train/generators/utils.py index 331b890c15..ea908cbe9a 100644 --- a/skyrl/train/generators/utils.py +++ b/skyrl/train/generators/utils.py @@ -582,3 +582,188 @@ def get_response_ids_and_loss_mask_from_messages( assert len(rollout_logprobs) == len(response_ids) if rollout_logprobs is not None else True return response_ids, loss_mask, rollout_logprobs + + +# ------------------------------------------- +# Prefix-aware merging for step-wise training +# ------------------------------------------- + + +def _is_prefix(maybe_prefix: List[int], candidate: List[int]) -> bool: + """Check if maybe_prefix is a prefix of candidate (or equal).""" + if len(maybe_prefix) > len(candidate): + return False + return maybe_prefix == candidate[: len(maybe_prefix)] + + +def _slice_generator_output(generator_output: GeneratorOutput, indices: List[int]) -> GeneratorOutput: + """Slice a GeneratorOutput to keep only the entries at the given indices. + + All sliced entries must share the same TrajectoryID — this helper is used by + prefix-aware merging which operates on one trajectory at a time. + """ + assert len(indices) > 0, "indices must be non-empty" + # Every key except `rollout_metrics` is either a per-entry list to slice, or None. + sliced: GeneratorOutput = {} + for key, value in generator_output.items(): + if key == "rollout_metrics": + sliced[key] = value + elif value is None: + sliced[key] = None + else: + sliced[key] = [value[i] for i in indices] + return sliced + + +def _merge_single_trajectory(gen_out: GeneratorOutput) -> GeneratorOutput: + """Greedily merge turns of a single trajectory using prefix matching. 
+ + Takes a GeneratorOutput whose entries all belong to the same trajectory_id + and returns a merged GeneratorOutput (potentially fewer entries). + """ + # Make sure all entries in the trajectory have the same trajectory_id + trajectory_ids = gen_out.get("trajectory_ids") + assert trajectory_ids is not None, "trajectory_ids is required for prefix-aware merging" + for i in range(0, len(trajectory_ids)): + assert ( + trajectory_ids[i] == trajectory_ids[0] + ), "all entries in a single trajectory must have the same trajectory_id" + + n = len(gen_out["response_ids"]) + assert n > 0, "Expect non-empty GeneratorOutput." + is_token_level_rewards = isinstance(gen_out["rewards"][0], list) + has_logprobs = gen_out.get("rollout_logprobs") is not None + has_stop_reasons = gen_out.get("stop_reasons") is not None + + # Per-field output accumulators. + # Fields that we take from all the entries in the merge group + out_prompt_ids: List[List[int]] = [] + out_response_ids: List[List[int]] = [] + out_loss_masks: List[List[int]] = [] + out_logprobs: Optional[List[List[float]]] = [] if has_logprobs else None + # If per-token rewards, we keep appending. If per-turn rewards, we only take from the last turn. + out_rewards: list = [] + + # Fields that we only take from the last turn in the merge group + out_stop_reasons: Optional[List[str]] = [] if has_stop_reasons else None + out_trajectory_ids: list = [] + out_is_last_step: List[bool] = [] + + # Accumulator for the current merge group + acc_prompt: List[int] = list(gen_out["prompt_token_ids"][0]) + acc_response: List[int] = list(gen_out["response_ids"][0]) + acc_loss_mask: List[int] = list(gen_out["loss_masks"][0]) + acc_logprobs: Optional[List[float]] = list(gen_out["rollout_logprobs"][0]) if has_logprobs else None + acc_rewards_tokens: Optional[List[float]] = list(gen_out["rewards"][0]) if is_token_level_rewards else None + last = 0 + + def flush(): + nonlocal acc_prompt, acc_response, acc_loss_mask, acc_logprobs, acc_rewards_tokens, last + out_prompt_ids.append(acc_prompt) + out_response_ids.append(acc_response) + out_loss_masks.append(acc_loss_mask) + if has_logprobs: + out_logprobs.append(acc_logprobs) + out_rewards.append(acc_rewards_tokens if is_token_level_rewards else gen_out["rewards"][last]) + if has_stop_reasons: + out_stop_reasons.append(gen_out["stop_reasons"][last]) + out_trajectory_ids.append(gen_out["trajectory_ids"][last]) + out_is_last_step.append(gen_out["is_last_step"][last]) + + for i in range(1, n): + full_merged = acc_prompt + acc_response + + # If prompt[i-1] + response[i-1] is not a prefix of prompt[i], flush the current merge group + # and start a new group to merge. + if not _is_prefix(full_merged, gen_out["prompt_token_ids"][i]): + flush() + acc_prompt = list(gen_out["prompt_token_ids"][i]) + acc_response = list(gen_out["response_ids"][i]) + acc_loss_mask = list(gen_out["loss_masks"][i]) + acc_logprobs = list(gen_out["rollout_logprobs"][i]) if has_logprobs else None + acc_rewards_tokens = list(gen_out["rewards"][i]) if is_token_level_rewards else None + last = i + continue + + # prompt[i-1] + response[i-1] is a prefix of prompt[i], so we can merge the two turns. + # obs_delta is the newly prefilled tokens not generated by the assistant, so we need to + # properly loss mask them. + obs_delta = gen_out["prompt_token_ids"][i][len(full_merged) :] + + # Merge obs_delta to the fields, assigning zeros since it is not generated by the assistant. 
+ acc_response.extend(obs_delta) + acc_loss_mask.extend([0] * len(obs_delta)) + if acc_logprobs is not None: + acc_logprobs.extend([0.0] * len(obs_delta)) + if acc_rewards_tokens is not None: + acc_rewards_tokens.extend([0.0] * len(obs_delta)) + + # Extend the current merge group with the next turn's fields, preserved exactly as they are. + acc_response.extend(gen_out["response_ids"][i]) + acc_loss_mask.extend(gen_out["loss_masks"][i]) + if acc_logprobs is not None: + acc_logprobs.extend(gen_out["rollout_logprobs"][i]) + if acc_rewards_tokens is not None: + acc_rewards_tokens.extend(gen_out["rewards"][i]) + + last = i + + flush() + + return { + "prompt_token_ids": out_prompt_ids, + "response_ids": out_response_ids, + "rewards": out_rewards, + "loss_masks": out_loss_masks, + "stop_reasons": out_stop_reasons, + "rollout_metrics": gen_out.get("rollout_metrics", None), + "rollout_logprobs": out_logprobs, + "trajectory_ids": out_trajectory_ids, + "rollout_expert_indices": None, + "is_last_step": out_is_last_step, + } + + +def merge_stepwise_output(generator_output: GeneratorOutput) -> GeneratorOutput: + """Merge step-wise GeneratorOutput entries using prefix-aware merging. + + Consecutive turns within the same trajectory where + prompt[i] + response[i] is a prefix of prompt[i+1] + are merged into a single entry with: + - Prompt taken from the first turn in the merge group + - Response tokens concatenated, with observation deltas (obs_delta) in between + - Per-token fields (loss_masks, rewards, logprobs) concatenated, with default + values (0) for obs_delta positions + - Per-turn fields (stop_reason, is_last_step, trajectory_id) taken from the + last turn in the merge group + + When the prefix condition fails between two consecutive turns, the current + merge group is flushed and a new group starts (greedy merging). + + Args: + generator_output: Step-wise GeneratorOutput with trajectory_ids and is_last_step. + + Returns: + Merged GeneratorOutput with one entry per merged group. 
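+
+    Example (illustrative; mirrors Case 1 of the unit tests below):
+        turn 1: prompt=[O1], response=[A1]
+        turn 2: prompt=[O1, A1, O2], response=[A2]
+        merged: prompt=[O1], response=[A1, O2, A2], loss_mask=[1, 0, 1]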
+ """ + num_samples = len(generator_output["response_ids"]) + assert ( + generator_output.get("rollout_expert_indices") is None + ), "rollout_expert_indices not supported for prefix-aware merging" + assert ( + generator_output.get("pixel_values") is None and generator_output.get("image_grid_thw") is None + ), "pixel_values and image_grid_thw not supported for step-wise training merging" + + # Split into per-trajectory GeneratorOutputs using is_last_step boundaries + # (contiguity is guaranteed by validate_generator_output with step_wise=True) + is_last_step = generator_output["is_last_step"] + trajectory_slices: List[GeneratorOutput] = [] + start = 0 + for i in range(num_samples): + if is_last_step[i]: + trajectory_slices.append(_slice_generator_output(generator_output, list(range(start, i + 1)))) + start = i + 1 + + merged_slices = [_merge_single_trajectory(s) for s in trajectory_slices] + # concatenate_generator_outputs re-aggregates rollout_metrics and validates + return concatenate_generator_outputs(merged_slices, step_wise=True) diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 104c108df2..af1b65ed11 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -59,6 +59,7 @@ ) from skyrl.train.generators.utils import ( get_metrics_from_generator_output, + merge_stepwise_output, prepare_generator_input, ) from skyrl.train.utils import ( @@ -248,9 +249,9 @@ async def train(self): # if we are not continuing sampling, we sleep the inference engine await self.inference_engine_client.sleep() - # 1.2 postprocess rewards + # 1.2 postprocess rewards (and merge step-wise turns if enabled) with Timer("postprocess_generator_output", self.all_timings): - generator_output = self.postprocess_generator_output(generator_output, uids) + generator_output, uids = self.postprocess_generator_output(generator_output, uids) # 2. print example just for debugging vis = self.tokenizer.decode(generator_output["response_ids"][0]) @@ -752,11 +753,20 @@ async def generate( return generator_output @torch.no_grad() - def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput: + def postprocess_generator_output( + self, generator_output: GeneratorOutput, uids: List[str] + ) -> Tuple[GeneratorOutput, List[str]]: """ Converts to per token rewards and computes pass@N. + For step-wise training with ``merge_stepwise_output=true``, also collapses + consecutive turns sharing a common prefix into a single sequence; ``uids`` + is shortened to match. + In the future algorithm specific reward or loss mask post processing should be done here. + + Returns: + (generator_output, uids) — uids may be shorter than the input when merging. """ generator_output_for_metrics = generator_output uids_for_metrics = uids @@ -780,6 +790,21 @@ def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: uids_for_metrics, ) + # Prefix-aware merging of step-wise turns. 
+ if self.cfg.generator.merge_stepwise_output: + assert self.cfg.generator.step_wise_trajectories, "merge_stepwise_output requires step-wise training" + num_seq_before_merge = len(generator_output["response_ids"]) + generator_output = merge_stepwise_output(generator_output) + num_seq_after_merge = len(generator_output["response_ids"]) + logger.info(f"Merged step-wise output: {num_seq_before_merge} sequences -> {num_seq_after_merge} sequences") + self.all_metrics.update( + { + "generate/num_seq_before_merge": num_seq_before_merge, + "generate/num_seq_after_merge": num_seq_after_merge, + } + ) + uids = [tid.instance_id for tid in generator_output["trajectory_ids"]] + # these use the full generator output rewards: Union[List[float], List[List[float]]] = generator_output["rewards"] responses: List[List[int]] = generator_output["response_ids"] @@ -815,7 +840,7 @@ def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: ) # re-assign rewards, now in per-token form generator_output["rewards"] = per_token_rewards - return generator_output + return generator_output, uids @torch.no_grad() def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingInputBatch: diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index 43d40f0f75..1e6c7013e1 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -303,6 +303,13 @@ def validate_cfg(cfg: SkyRLTrainConfig): "`token_mean_legacy` loss reduction is not supported with step-wise training. Use `token_mean` instead." ) + if cfg.generator.merge_stepwise_output and not cfg.generator.step_wise_trajectories: + raise ValueError( + "`generator.merge_stepwise_output=True` requires `generator.step_wise_trajectories=True`. " + "Prefix-aware merging operates on step-wise GeneratorOutput entries (trajectory_ids + " + "is_last_step), which only exist when step-wise training is enabled." 
+ ) + assert cfg.trainer.algorithm.loss_reduction in ( "token_mean", "token_mean_legacy", diff --git a/tests/train/generators/test_generator_output_utils.py b/tests/train/generators/test_generator_output_utils.py index 53c68a378b..c383a3e5c9 100644 --- a/tests/train/generators/test_generator_output_utils.py +++ b/tests/train/generators/test_generator_output_utils.py @@ -2,13 +2,19 @@ uv run --extra dev --isolated pytest tests/train/generators/test_generator_output_utils.py """ +from unittest.mock import patch + import numpy as np +import pytest -from skyrl.train.generators.base import GeneratorOutput +from skyrl.train.generators.base import GeneratorOutput, TrajectoryID from skyrl.train.generators.utils import ( concatenate_generator_outputs, get_metrics_from_generator_output, + merge_stepwise_output, ) +from skyrl.train.utils.utils import validate_cfg +from tests.train.util import example_dummy_config def test_generator_output_concatenation(): @@ -116,3 +122,531 @@ def test_get_metrics_from_generator_output(): assert metrics["avg_score"] == 0.0 assert metrics["pass_at_n"] == 0.5 assert metrics["mean_positive_reward"] == 0.75 + + +# ─────────────────────────────────────────────────────────────────── +# merge_stepwise_output (prefix-aware merging) tests +# ─────────────────────────────────────────────────────────────────── + + +def _make_tid(instance_id: str, rep: int = 0) -> TrajectoryID: + return TrajectoryID(instance_id=instance_id, repetition_id=rep) + + +class TestMergeStepwiseOutput: + """CPU-only tests for ``merge_stepwise_output`` (prefix-aware merging).""" + + # ─── Core merging cases ──────────────────────────────────────── + + def test_case1_response_only_assistant(self): + """Case 1: response only contains assistant-generated tokens. + + Turn 1: prompt=[O1], response=[A1] + Turn 2: prompt=[O1, A1, O2], response=[A2] + + Merged: prompt=[O1], response=[A1, O2, A2] + obs_delta = [O2] + """ + tid = _make_tid("traj_1") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[10], [10, 20, 30]], + "response_ids": [[20], [40, 41]], + "rewards": [[1.0], [0.0, 5.0]], + "loss_masks": [[1], [1, 1]], + "stop_reasons": ["continue", "eos"], + "rollout_metrics": None, + "rollout_logprobs": [[-0.5], [-0.3, -0.4]], + "trajectory_ids": [tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, True], + } + + merged = merge_stepwise_output(gen_out) + + assert len(merged["response_ids"]) == 1 + assert merged["prompt_token_ids"] == [[10]] + # response = A1 + obs_delta(O2) + A2(two tokens) + assert merged["response_ids"] == [[20, 30, 40, 41]] + # loss_mask: A1=1, O2=0, A2_tok1=1, A2_tok2=1 + assert merged["loss_masks"] == [[1, 0, 1, 1]] + # logprobs: A1=-0.5, O2=0.0, A2_tok1=-0.3, A2_tok2=-0.4 + assert merged["rollout_logprobs"] == [[-0.5, 0.0, -0.3, -0.4]] + # rewards: A1=1.0, O2=0.0, A2_tok1=0.0, A2_tok2=5.0 + assert merged["rewards"] == [[1.0, 0.0, 0.0, 5.0]] + assert merged["stop_reasons"] == ["eos"] + assert merged["trajectory_ids"] == [tid] + assert merged["is_last_step"] == [True] + + def test_case2_response_contains_observation(self): + """Case 2: response contains both assistant and observation tokens. 
+ + Turn 1: prompt=[O1], response=[A1, O2] + Turn 2: prompt=[O1, A1, O2], response=[A2] + + Merged: prompt=[O1], response=[A1, O2, A2] + obs_delta = [] (empty, since prompt+response of turn 1 == prompt of turn 2) + """ + tid = _make_tid("traj_2") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[10], [10, 20, 30]], + "response_ids": [[20, 30], [40]], + "rewards": [[0.0, 0.0], [7.0]], + "loss_masks": [[1, 0], [1]], + "stop_reasons": ["continue", "eos"], + "rollout_metrics": None, + "rollout_logprobs": [[-0.1, -0.2], [-0.3]], + "trajectory_ids": [tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, True], + } + + merged = merge_stepwise_output(gen_out) + + assert len(merged["response_ids"]) == 1 + assert merged["prompt_token_ids"] == [[10]] + # response = [A1, O2] + [] + [A2] = [A1, O2, A2] + assert merged["response_ids"] == [[20, 30, 40]] + assert merged["loss_masks"] == [[1, 0, 1]] + assert merged["rollout_logprobs"] == [[-0.1, -0.2, -0.3]] + assert merged["rewards"] == [[0.0, 0.0, 7.0]] + assert merged["stop_reasons"] == ["eos"] + assert merged["is_last_step"] == [True] + + def test_case3_combination(self): + """Case 3: response has obs tokens AND there's extra obs_delta in prompt. + + Turn 1: prompt=[O1], response=[A1, O2] + Turn 2: prompt=[O1, A1, O2, O2_5], response=[A2] + + Merged: prompt=[O1], response=[A1, O2, O2_5, A2] + obs_delta = [O2_5] + """ + tid = _make_tid("traj_3") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[10], [10, 20, 30, 35]], + "response_ids": [[20, 30], [40]], + "rewards": [[0.0, 0.0], [9.0]], + "loss_masks": [[1, 0], [1]], + "stop_reasons": ["continue", "eos"], + "rollout_metrics": None, + "rollout_logprobs": [[-0.1, -0.2], [-0.5]], + "trajectory_ids": [tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, True], + } + + merged = merge_stepwise_output(gen_out) + + assert len(merged["response_ids"]) == 1 + assert merged["prompt_token_ids"] == [[10]] + # response = [A1, O2] + obs_delta[O2_5] + [A2] + assert merged["response_ids"] == [[20, 30, 35, 40]] + assert merged["loss_masks"] == [[1, 0, 0, 1]] + assert merged["rollout_logprobs"] == [[-0.1, -0.2, 0.0, -0.5]] + assert merged["rewards"] == [[0.0, 0.0, 0.0, 9.0]] + assert merged["is_last_step"] == [True] + + # ─── Multi-turn and multi-trajectory ─────────────────────────── + + def test_three_turns(self): + """Three-turn trajectory merging (Case 1 pattern repeated). 
+ + Turn 1: prompt=[1], response=[2] + Turn 2: prompt=[1,2,3], response=[4] + Turn 3: prompt=[1,2,3,4,5], response=[6] + + Merged: prompt=[1], response=[2, 3, 4, 5, 6] + """ + tid = _make_tid("traj_multi") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[1], [1, 2, 3], [1, 2, 3, 4, 5]], + "response_ids": [[2], [4], [6]], + "rewards": [[0.0], [0.0], [10.0]], + "loss_masks": [[1], [1], [1]], + "stop_reasons": ["continue", "continue", "eos"], + "rollout_metrics": None, + "rollout_logprobs": [[-1.0], [-2.0], [-3.0]], + "trajectory_ids": [tid, tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, False, True], + } + + merged = merge_stepwise_output(gen_out) + + assert len(merged["response_ids"]) == 1 + assert merged["prompt_token_ids"] == [[1]] + # resp[0]=2, obs_delta=3, resp[1]=4, obs_delta=5, resp[2]=6 + assert merged["response_ids"] == [[2, 3, 4, 5, 6]] + assert merged["loss_masks"] == [[1, 0, 1, 0, 1]] + assert merged["rollout_logprobs"] == [[-1.0, 0.0, -2.0, 0.0, -3.0]] + assert merged["rewards"] == [[0.0, 0.0, 0.0, 0.0, 10.0]] + assert merged["is_last_step"] == [True] + assert merged["stop_reasons"] == ["eos"] + + def test_multiple_trajectories(self): + """Two separate trajectories in the same batch, each with 2 turns.""" + tid_a = _make_tid("A") + tid_b = _make_tid("B") + gen_out: GeneratorOutput = { + "prompt_token_ids": [ + [10], # A turn 1 + [10, 20, 30], # A turn 2 + [100], # B turn 1 + [100, 200, 300], # B turn 2 + ], + "response_ids": [ + [20], # A turn 1 + [40], # A turn 2 + [200], # B turn 1 + [400], # B turn 2 + ], + "rewards": [[0.0], [1.0], [0.0], [2.0]], + "loss_masks": [[1], [1], [1], [1]], + "stop_reasons": ["continue", "eos", "continue", "eos"], + "rollout_metrics": None, + "rollout_logprobs": None, + "trajectory_ids": [tid_a, tid_a, tid_b, tid_b], + "rollout_expert_indices": None, + "is_last_step": [False, True, False, True], + } + + merged = merge_stepwise_output(gen_out) + + assert len(merged["response_ids"]) == 2 + # Trajectory A merged + assert merged["prompt_token_ids"][0] == [10] + assert merged["response_ids"][0] == [20, 30, 40] + assert merged["loss_masks"][0] == [1, 0, 1] + # Trajectory B merged + assert merged["prompt_token_ids"][1] == [100] + assert merged["response_ids"][1] == [200, 300, 400] + assert merged["loss_masks"][1] == [1, 0, 1] + assert merged["is_last_step"] == [True, True] + assert merged["stop_reasons"] == ["eos", "eos"] + + def test_mixed_trajectories_and_turns(self): + """Mix of single-turn and multi-turn trajectories in one batch.""" + tid_a = _make_tid("A") + tid_b = _make_tid("B") # single turn + tid_c = _make_tid("C") + + gen_out: GeneratorOutput = { + "prompt_token_ids": [ + [1], # A turn 1 + [1, 2, 3], # A turn 2 + [1, 2, 3, 4, 5], # A turn 3 + [50], # B single turn + [60], # C turn 1 + [60, 70, 80], # C turn 2 + ], + "response_ids": [ + [2], # A + [4], # A + [6], # A + [51], # B + [70], # C + [90], # C + ], + "rewards": [[0.0], [0.0], [10.0], [3.0], [0.0], [7.0]], + "loss_masks": [[1], [1], [1], [1], [1], [1]], + "stop_reasons": ["c", "c", "eos", "eos", "c", "eos"], + "rollout_metrics": None, + "rollout_logprobs": [[-1.0], [-2.0], [-3.0], [-4.0], [-5.0], [-6.0]], + "trajectory_ids": [tid_a, tid_a, tid_a, tid_b, tid_c, tid_c], + "rollout_expert_indices": None, + "is_last_step": [False, False, True, True, False, True], + } + + merged = merge_stepwise_output(gen_out) + + # 3 entries: A merged, B single, C merged + assert len(merged["response_ids"]) == 3 + + # A: prompt=[1], response=[2,3,4,5,6] + assert 
merged["prompt_token_ids"][0] == [1] + assert merged["response_ids"][0] == [2, 3, 4, 5, 6] + assert merged["loss_masks"][0] == [1, 0, 1, 0, 1] + assert merged["rollout_logprobs"][0] == [-1.0, 0.0, -2.0, 0.0, -3.0] + assert merged["rewards"][0] == [0.0, 0.0, 0.0, 0.0, 10.0] + + # B: unchanged + assert merged["prompt_token_ids"][1] == [50] + assert merged["response_ids"][1] == [51] + + # C: prompt=[60], response=[70,80,90] + assert merged["prompt_token_ids"][2] == [60] + assert merged["response_ids"][2] == [70, 80, 90] + assert merged["loss_masks"][2] == [1, 0, 1] + + assert merged["is_last_step"] == [True, True, True] + assert merged["stop_reasons"] == ["eos", "eos", "eos"] + + # ─── Edge cases / optional fields ────────────────────────────── + + def test_single_turn_passthrough(self): + """Single-turn trajectory is passed through unchanged.""" + tid = _make_tid("single") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[1, 2, 3]], + "response_ids": [[4, 5]], + "rewards": [[0.5, 0.6]], + "loss_masks": [[1, 1]], + "stop_reasons": ["eos"], + "rollout_metrics": {"some_metric": 1.0}, + "rollout_logprobs": [[-0.1, -0.2]], + "trajectory_ids": [tid], + "rollout_expert_indices": None, + "is_last_step": [True], + } + + merged = merge_stepwise_output(gen_out) + + assert merged["prompt_token_ids"] == [[1, 2, 3]] + assert merged["response_ids"] == [[4, 5]] + assert merged["loss_masks"] == [[1, 1]] + assert merged["rollout_logprobs"] == [[-0.1, -0.2]] + assert merged["rewards"] == [[0.5, 0.6]] + assert merged["is_last_step"] == [True] + # rollout_metrics is re-aggregated by concatenate_generator_outputs + assert merged["rollout_metrics"] is not None + + def test_per_trajectory_scalar_rewards(self): + """Per-trajectory scalar rewards (List[float]) are handled correctly.""" + tid = _make_tid("scalar_rew") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[10], [10, 20, 30]], + "response_ids": [[20], [40]], + "rewards": [0.0, 5.0], # scalar per turn + "loss_masks": [[1], [1]], + "stop_reasons": None, + "rollout_metrics": None, + "rollout_logprobs": None, + "trajectory_ids": [tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, True], + } + + merged = merge_stepwise_output(gen_out) + + assert len(merged["response_ids"]) == 1 + assert merged["response_ids"] == [[20, 30, 40]] + # Scalar reward: use the last turn's value + assert merged["rewards"] == [5.0] + assert merged["rollout_logprobs"] is None + + def test_no_logprobs_no_stop_reasons(self): + """Works correctly when rollout_logprobs and stop_reasons are None.""" + tid = _make_tid("no_lp") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[1], [1, 2, 3]], + "response_ids": [[2], [4]], + "rewards": [[0.0], [1.0]], + "loss_masks": [[1], [1]], + "stop_reasons": None, + "rollout_metrics": None, + "rollout_logprobs": None, + "trajectory_ids": [tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, True], + } + + merged = merge_stepwise_output(gen_out) + + assert merged["rollout_logprobs"] is None + assert merged["stop_reasons"] is None + assert merged["response_ids"] == [[2, 3, 4]] + assert merged["loss_masks"] == [[1, 0, 1]] + + def test_prefix_mismatch_no_merge(self): + """If prefix condition fails, turns are kept separate.""" + tid = _make_tid("no_merge") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[10], [99, 88]], # second prompt doesn't share prefix + "response_ids": [[20], [40]], + "rewards": [[0.0], [1.0]], + "loss_masks": [[1], [1]], + "stop_reasons": ["continue", "eos"], + "rollout_metrics": 
None, + "rollout_logprobs": None, + "trajectory_ids": [tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, True], + } + + merged = merge_stepwise_output(gen_out) + + # No merging happened, output has same number of entries + assert len(merged["response_ids"]) == 2 + assert merged["prompt_token_ids"] == [[10], [99, 88]] + assert merged["response_ids"] == [[20], [40]] + assert merged["is_last_step"] == [False, True] + + def test_prefix_of_prompt_plus_response_but_not_prompt_alone_no_merge(self): + """prompt[i]+response[i] is a prefix of prompt[i+1]+response[i+1] but NOT + of prompt[i+1] alone. This must NOT merge. + + Turn 1: prompt=[10], response=[20, 30] + → full = [10, 20, 30] + Turn 2: prompt=[10, 20], response=[30, 40] + → prompt[1]+response[1] = [10, 20, 30, 40] + + full ([10,20,30]) IS a prefix of prompt[1]+response[1] ([10,20,30,40]), + but is NOT a prefix of prompt[1] ([10,20]) since full is longer. + This would imply response tokens from turn 1 overlap with response tokens + of turn 2, which is malformed — we intentionally refuse to merge. + """ + tid = _make_tid("overlap") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[10], [10, 20]], + "response_ids": [[20, 30], [30, 40]], + "rewards": [[0.0, 0.0], [1.0, 2.0]], + "loss_masks": [[1, 0], [1, 1]], + "stop_reasons": ["continue", "eos"], + "rollout_metrics": None, + "rollout_logprobs": None, + "trajectory_ids": [tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, True], + } + + merged = merge_stepwise_output(gen_out) + + # No merging: kept as 2 separate entries + assert len(merged["response_ids"]) == 2 + assert merged["prompt_token_ids"] == [[10], [10, 20]] + assert merged["response_ids"] == [[20, 30], [30, 40]] + assert merged["is_last_step"] == [False, True] + + def test_partial_merge_within_trajectory(self): + """4 turns where prefix breaks mid-trajectory, producing 2 merged sequences. + + Turn 1: prompt=[1], response=[2] → prefix OK for turn 2 + Turn 2: prompt=[1,2,3], response=[4] → prefix BREAKS for turn 3 + (re-tokenization changed [2,3] to [23]) + Turn 3: prompt=[1,23,4,5], response=[6] → prefix OK for turn 4 + Turn 4: prompt=[1,23,4,5,6,7], response=[8] + + Turns 1+2 merge into one sequence, turns 3+4 merge into another. + Result: 4 turns → 2 sequences. 
+ """ + tid = _make_tid("partial") + gen_out: GeneratorOutput = { + "prompt_token_ids": [ + [1], # turn 1 + [1, 2, 3], # turn 2: prompt[0]+resp[0]=[1,2] is prefix of [1,2,3] ✓ + [1, 23, 4, 5], # turn 3: prompt[1]+resp[1]=[1,2,3,4] is NOT prefix of [1,23,4,5] ✗ + [1, 23, 4, 5, 6, 7], # turn 4: prompt[2]+resp[2]=[1,23,4,5,6] is prefix ✓ + ], + "response_ids": [ + [2], # turn 1 + [4], # turn 2 + [6], # turn 3 + [8], # turn 4 + ], + "rewards": [[0.0], [0.0], [0.0], [10.0]], + "loss_masks": [[1], [1], [1], [1]], + "stop_reasons": ["c", "c", "c", "eos"], + "rollout_metrics": None, + "rollout_logprobs": [[-1.0], [-2.0], [-3.0], [-4.0]], + "trajectory_ids": [tid, tid, tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, False, False, True], + } + + merged = merge_stepwise_output(gen_out) + + # 4 turns → 2 merged sequences + assert len(merged["response_ids"]) == 2 + + # First merged group: turns 1+2 + # prompt=[1], response=[2] + obs_delta=[3] + [4] = [2, 3, 4] + assert merged["prompt_token_ids"][0] == [1] + assert merged["response_ids"][0] == [2, 3, 4] + assert merged["loss_masks"][0] == [1, 0, 1] + assert merged["rollout_logprobs"][0] == [-1.0, 0.0, -2.0] + assert merged["rewards"][0] == [0.0, 0.0, 0.0] + assert merged["is_last_step"][0] is False + + # Second merged group: turns 3+4 + # prompt=[1,23,4,5], response=[6] + obs_delta=[7] + [8] = [6, 7, 8] + assert merged["prompt_token_ids"][1] == [1, 23, 4, 5] + assert merged["response_ids"][1] == [6, 7, 8] + assert merged["loss_masks"][1] == [1, 0, 1] + assert merged["rollout_logprobs"][1] == [-3.0, 0.0, -4.0] + assert merged["rewards"][1] == [0.0, 0.0, 10.0] + assert merged["is_last_step"][1] is True + + assert merged["stop_reasons"] == ["c", "eos"] + assert merged["trajectory_ids"] == [tid, tid] + + # ─── Assertion / validation on GeneratorOutput shape ─────────── + + def test_asserts_trajectory_ids_required(self): + """Raises when trajectory_ids is None.""" + gen_out: GeneratorOutput = { + "prompt_token_ids": [[1]], + "response_ids": [[2]], + "rewards": [[1.0]], + "loss_masks": [[1]], + "stop_reasons": None, + "rollout_metrics": None, + "rollout_logprobs": None, + "trajectory_ids": None, + "rollout_expert_indices": None, + "is_last_step": [True], + } + with pytest.raises(AssertionError, match="trajectory_ids"): + merge_stepwise_output(gen_out) + + def test_asserts_is_last_step_required(self): + """Raises when is_last_step is None.""" + tid = _make_tid("x") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[1]], + "response_ids": [[2]], + "rewards": [[1.0]], + "loss_masks": [[1]], + "stop_reasons": None, + "rollout_metrics": None, + "rollout_logprobs": None, + "trajectory_ids": [tid], + "rollout_expert_indices": None, + "is_last_step": None, + } + with pytest.raises(TypeError): + merge_stepwise_output(gen_out) + + def test_asserts_no_expert_indices(self): + """Raises when rollout_expert_indices is present.""" + tid = _make_tid("x") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[1]], + "response_ids": [[2]], + "rewards": [[1.0]], + "loss_masks": [[1]], + "stop_reasons": None, + "rollout_metrics": None, + "rollout_logprobs": None, + "trajectory_ids": [tid], + "rollout_expert_indices": [[[[1, 2]]]], + "is_last_step": [True], + } + with pytest.raises(AssertionError, match="rollout_expert_indices not supported"): + merge_stepwise_output(gen_out) + + # ─── Config validation ───────────────────────────────────────── + + @patch("skyrl.train.utils.utils.validate_batch_sizes", new=lambda cfg: None) + 
@patch("skyrl.train.utils.utils.validate_generator_cfg", new=lambda cfg: None) + def test_validate_cfg_merge_stepwise_requires_step_wise(self): + """`merge_stepwise_output=True` without `step_wise_trajectories=True` must fail validation. + + Prefix-aware merging operates on step-wise-only fields (trajectory_ids, is_last_step), so + enabling it without step-wise training would crash later with a confusing assertion inside + `merge_stepwise_output`. `validate_cfg` should reject the combination up front. + """ + cfg = example_dummy_config() + cfg.generator.merge_stepwise_output = True + cfg.generator.step_wise_trajectories = False + with pytest.raises(ValueError, match="merge_stepwise_output.*requires.*step_wise_trajectories"): + validate_cfg(cfg) diff --git a/tests/train/test_generator_postprocess.py b/tests/train/test_generator_postprocess.py index 576265e37b..079b5bbe72 100644 --- a/tests/train/test_generator_postprocess.py +++ b/tests/train/test_generator_postprocess.py @@ -58,7 +58,8 @@ def test_response_level_rewards(): "rollout_metrics": None, } - result = trainer.postprocess_generator_output(generator_output, ["uid1"]) + result, result_uids = trainer.postprocess_generator_output(generator_output, ["uid1"]) + assert result_uids == ["uid1"] # Verify conversion to per-token rewards assert result["rewards"] == [[0.0, 0.0, 1.0]] @@ -84,7 +85,8 @@ def test_response_level_rewards(): "rollout_metrics": None, } - result = trainer.postprocess_generator_output(generator_output, ["uid1", "uid2"]) + result, result_uids = trainer.postprocess_generator_output(generator_output, ["uid1", "uid2"]) + assert result_uids == ["uid1", "uid2"] # Verify conversion to per-token rewards assert result["rewards"] == [[0.0, 1.0], [0.0, 0.0, 0.5]] @@ -115,7 +117,8 @@ def test_token_level_rewards(): "rollout_metrics": None, } - result = trainer.postprocess_generator_output(generator_output, ["uid1"]) + result, result_uids = trainer.postprocess_generator_output(generator_output, ["uid1"]) + assert result_uids == ["uid1"] # Verify token-level rewards are unchanged assert result["rewards"] == per_token_rewards @@ -142,7 +145,8 @@ def test_token_level_rewards(): "rollout_metrics": None, } - result = trainer.postprocess_generator_output(generator_output, ["uid1", "uid2"]) + result, result_uids = trainer.postprocess_generator_output(generator_output, ["uid1", "uid2"]) + assert result_uids == ["uid1", "uid2"] # Verify token-level rewards are unchanged assert result["rewards"] == per_token_rewards