Add prefix-aware merging for step-wise training #1479
@@ -744,6 +744,198 @@ def _validate_step_wise_fields(generator_output: GeneratorOutput, num_responses:

```python
    )


def _is_prefix(maybe_prefix: List[int], candidate: List[int]) -> bool:
    """Check if maybe_prefix is a prefix of candidate (or equal)."""
    if len(maybe_prefix) > len(candidate):
        return False
    return maybe_prefix == candidate[: len(maybe_prefix)]


def _slice_generator_output(generator_output: GeneratorOutput, indices: List[int]) -> GeneratorOutput:
    """Slice a GeneratorOutput to keep only the entries at the given indices."""
    return {
        "prompt_token_ids": [generator_output["prompt_token_ids"][i] for i in indices],
        "response_ids": [generator_output["response_ids"][i] for i in indices],
        "rewards": [generator_output["rewards"][i] for i in indices],
        "loss_masks": [generator_output["loss_masks"][i] for i in indices],
        "stop_reasons": (
            [generator_output["stop_reasons"][i] for i in indices]
            if generator_output.get("stop_reasons") is not None
            else None
        ),
        "rollout_metrics": generator_output.get("rollout_metrics"),
        "rollout_logprobs": (
            [generator_output["rollout_logprobs"][i] for i in indices]
            if generator_output.get("rollout_logprobs") is not None
            else None
        ),
        "trajectory_ids": (
            [generator_output["trajectory_ids"][i] for i in indices]
            if generator_output.get("trajectory_ids") is not None
            else None
        ),
        "rollout_expert_indices": (
            [generator_output["rollout_expert_indices"][i] for i in indices]
            if generator_output.get("rollout_expert_indices") is not None
            else None
        ),
        "is_last_step": (
            [generator_output["is_last_step"][i] for i in indices]
            if generator_output.get("is_last_step") is not None
            else None
        ),
    }


def _merge_single_trajectory(gen_out: GeneratorOutput) -> GeneratorOutput:
    """Greedily merge turns of a single trajectory using prefix matching.

    Takes a GeneratorOutput whose entries all belong to the same trajectory
    and returns a merged GeneratorOutput (potentially fewer entries).
    """
    n = len(gen_out["response_ids"])
    token_level_rewards = n > 0 and isinstance(gen_out["rewards"][0], list)
    has_logprobs = gen_out.get("rollout_logprobs") is not None
    has_stop_reasons = gen_out.get("stop_reasons") is not None

    # Per-field output accumulators (lists of per-entry values)
    out_prompt_ids: List[List[int]] = []
    out_response_ids: List[List[int]] = []
    out_loss_masks: List[List[int]] = []
    out_logprobs: Optional[List[List[float]]] = [] if has_logprobs else None
    out_rewards: list = []
    out_stop_reasons: Optional[List[str]] = [] if has_stop_reasons else None
    out_trajectory_ids: list = []
    out_is_last_step: List[bool] = []

    # Accumulator for the current merge group
    acc_prompt: List[int] = list(gen_out["prompt_token_ids"][0])
    acc_response: List[int] = list(gen_out["response_ids"][0])
    acc_loss_mask: List[int] = list(gen_out["loss_masks"][0])
    acc_logprobs: Optional[List[float]] = list(gen_out["rollout_logprobs"][0]) if has_logprobs else None
    acc_rewards_tokens: Optional[List[float]] = list(gen_out["rewards"][0]) if token_level_rewards else None
    last = 0
```
Comment on lines +811 to +817

Contributor: The logic to initialize the accumulators for a merge group (lines 812-817) is nearly identical to the logic for resetting them upon a prefix mismatch (lines 853-858). To reduce code duplication and improve maintainability, consider extracting this logic into a nested helper function that can be called in both places.
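To make the suggestion concrete, here is a minimal sketch of such a nested helper. It is illustrative only (the name `_reset_group` is not from this PR), it is meant to sit inside `_merge_single_trajectory`, and the accumulator names must still be bound in the enclosing scope for `nonlocal` to resolve.

```python
    def _reset_group(idx: int) -> None:
        # (Re)initialize the merge-group accumulators from turn `idx`: called once
        # before the loop (idx=0) and again whenever the prefix check fails (idx=i).
        nonlocal acc_prompt, acc_response, acc_loss_mask, acc_logprobs, acc_rewards_tokens, last
        acc_prompt = list(gen_out["prompt_token_ids"][idx])
        acc_response = list(gen_out["response_ids"][idx])
        acc_loss_mask = list(gen_out["loss_masks"][idx])
        acc_logprobs = list(gen_out["rollout_logprobs"][idx]) if has_logprobs else None
        acc_rewards_tokens = list(gen_out["rewards"][idx]) if token_level_rewards else None
        last = idx
```

With this helper, `_reset_group(0)` would replace the inline initialization above and `_reset_group(i)` would replace the reset block in the prefix-mismatch branch of the loop below.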
```python
    def flush():
        nonlocal acc_prompt, acc_response, acc_loss_mask, acc_logprobs, acc_rewards_tokens, last
        out_prompt_ids.append(acc_prompt)
        out_response_ids.append(acc_response)
        out_loss_masks.append(acc_loss_mask)
        if has_logprobs:
            out_logprobs.append(acc_logprobs)
        out_rewards.append(acc_rewards_tokens if token_level_rewards else gen_out["rewards"][last])
        if has_stop_reasons:
            out_stop_reasons.append(gen_out["stop_reasons"][last])
        out_trajectory_ids.append(gen_out["trajectory_ids"][last])
        out_is_last_step.append(gen_out["is_last_step"][last])

    prefix_failures_logged = 0
    for i in range(1, n):
        full_merged = acc_prompt + acc_response

        if not _is_prefix(full_merged, list(gen_out["prompt_token_ids"][i])):
            if prefix_failures_logged < 3:
                next_prompt = list(gen_out["prompt_token_ids"][i])
                # Find the first divergence point
                diverge_idx = next(
                    (j for j in range(min(len(full_merged), len(next_prompt))) if full_merged[j] != next_prompt[j]),
                    min(len(full_merged), len(next_prompt)),
                )
                logger.warning(
                    f"Prefix mismatch at turn {i}: "
                    f"len(full_merged)={len(full_merged)}, len(next_prompt)={len(next_prompt)}, "
                    f"first_diverge_idx={diverge_idx}, "
                    f"full_merged[{diverge_idx-2}:{diverge_idx+3}]={full_merged[max(0,diverge_idx-2):diverge_idx+3]}, "
                    f"next_prompt[{diverge_idx-2}:{diverge_idx+3}]={next_prompt[max(0,diverge_idx-2):diverge_idx+3]}"
                )
                prefix_failures_logged += 1
            flush()
            acc_prompt = list(gen_out["prompt_token_ids"][i])
            acc_response = list(gen_out["response_ids"][i])
            acc_loss_mask = list(gen_out["loss_masks"][i])
            acc_logprobs = list(gen_out["rollout_logprobs"][i]) if has_logprobs else None
            acc_rewards_tokens = list(gen_out["rewards"][i]) if token_level_rewards else None
            last = i
            continue

        obs_delta = list(gen_out["prompt_token_ids"][i][len(full_merged):])

        acc_response.extend(obs_delta)
        acc_loss_mask.extend([0] * len(obs_delta))
        if acc_logprobs is not None:
            acc_logprobs.extend([0.0] * len(obs_delta))
        if acc_rewards_tokens is not None:
            acc_rewards_tokens.extend([0.0] * len(obs_delta))

        acc_response.extend(gen_out["response_ids"][i])
        acc_loss_mask.extend(gen_out["loss_masks"][i])
        if acc_logprobs is not None:
            acc_logprobs.extend(gen_out["rollout_logprobs"][i])
        if acc_rewards_tokens is not None:
            acc_rewards_tokens.extend(list(gen_out["rewards"][i]))

        last = i

    flush()

    return {
        "prompt_token_ids": out_prompt_ids,
        "response_ids": out_response_ids,
        "rewards": out_rewards,
        "loss_masks": out_loss_masks,
        "stop_reasons": out_stop_reasons,
        "rollout_metrics": gen_out.get("rollout_metrics"),
        "rollout_logprobs": out_logprobs,
        "trajectory_ids": out_trajectory_ids,
        "rollout_expert_indices": None,
        "is_last_step": out_is_last_step,
    }


def merge_stepwise_output(generator_output: GeneratorOutput) -> GeneratorOutput:
    """Merge step-wise GeneratorOutput entries using prefix-aware merging.

    For consecutive turns within the same trajectory where
    prompt[i] + response[i] is a prefix of prompt[i+1],
    merges them into a single entry with:
    - prompt from the first turn in the merge group
    - response tokens concatenated with observation deltas (obs_delta) in between
    - Per-token fields (loss_masks, rewards, logprobs) concatenated, with default
      values (0) for obs_delta positions
    - Per-turn fields (stop_reason, is_last_step, trajectory_id) taken from the
      last turn in the merge group

    When the prefix condition fails between two consecutive turns, the current
    merge group is flushed and a new group starts (greedy merging).

    Args:
        generator_output: Step-wise GeneratorOutput with trajectory_ids and is_last_step.

    Returns:
        Merged GeneratorOutput with one entry per merged group.
    """
    num_samples = len(generator_output["response_ids"])
    # step_wise=True validates trajectory_ids, is_last_step, and contiguous ordering
    validate_generator_output(num_samples, generator_output, step_wise=True)
    assert generator_output.get("rollout_expert_indices") is None, (
        "rollout_expert_indices not supported for prefix-aware merging"
    )

    # Split into per-trajectory GeneratorOutputs using is_last_step boundaries
    # (contiguity is guaranteed by validate_generator_output with step_wise=True)
    is_last_step = generator_output["is_last_step"]
    trajectory_slices: List[GeneratorOutput] = []
    start = 0
    for i in range(num_samples):
        if is_last_step[i]:
            trajectory_slices.append(_slice_generator_output(generator_output, list(range(start, i + 1))))
            start = i + 1

    merged_slices = [_merge_single_trajectory(s) for s in trajectory_slices]
    # concatenate_generator_outputs re-aggregates rollout_metrics and validates
    return concatenate_generator_outputs(merged_slices)


def build_dataloader(
    cfg: SkyRLTrainConfig, dataset: PromptDataset, is_train=True, is_fully_async=False
) -> StatefulDataLoader:
```
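For intuition about the merge, here is a small illustrative call of `merge_stepwise_output` (a sketch, not a test from this PR: token ids and rewards are made up, and the `trajectory_ids` placeholders would need to match whatever type `validate_generator_output` expects).

```python
# Two steps of one trajectory; the second prompt extends the first prompt+response
# by one observation token (5), so the prefix condition holds.
step_wise = {
    "prompt_token_ids": [[1, 2], [1, 2, 3, 4, 5]],
    "response_ids": [[3, 4], [6]],
    "rewards": [0.0, 1.0],              # scalar (per-turn) rewards
    "loss_masks": [[1, 1], [1]],
    "stop_reasons": ["length", "stop"],
    "rollout_metrics": None,
    "rollout_logprobs": None,
    "trajectory_ids": ["traj-0", "traj-0"],   # placeholder values
    "rollout_expert_indices": None,
    "is_last_step": [False, True],
}

merged = merge_stepwise_output(step_wise)
# Expected single merged entry:
#   prompt_token_ids -> [[1, 2]]
#   response_ids     -> [[3, 4, 5, 6]]   (observation-delta token 5 spliced in)
#   loss_masks       -> [[1, 1, 0, 1]]   (0 at the observation-delta position)
#   rewards          -> [1.0]            (per-turn reward taken from the last turn)
#   stop_reasons     -> ["stop"], is_last_step -> [True]
```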
🔴 **`pad_batch` uses `max()` instead of `lcm()` for the step-wise pad target, breaking the divisibility guarantee**

When `step_wise_trajectories=True`, `pad_target` is computed as `max(policy_mini_batch_size * n, critic_mini_batch_size * n, dp_size)`. Padding the batch to a multiple of `max(A, B, C)` does not guarantee divisibility by all three values; only `lcm(A, B, C)` would. This causes `stage_chunks` at `skyrl/backends/skyrl_train/distributed/dispatch.py:190-192` to assert-fail (`len(data) % mini_batch_size == 0`) when `policy_mini_batch_size != critic_mini_batch_size`.

Concrete failure example: with `policy_mini_batch_size=256`, `critic_mini_batch_size=384`, `n_samples=5`:

- `pad_target = max(1280, 1920, dp_size) = 1920`
- `1920 % 1280 = 640 ≠ 0`, so the policy `stage_chunks` assertion fails

Additionally, `critic_mini_batch_size` is unconditionally included even when `self.has_critic` is False, which can unnecessarily inflate `pad_target` above `policy_mini_batch_size * n` and break divisibility for the policy training step.
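One possible shape of the fix described above, sketched under the assumption that a helper like the hypothetical `compute_pad_target` below replaces the `max()`-based computation inside `pad_batch` (the actual `pad_batch` code is not part of this diff):

```python
import math

def compute_pad_target(policy_mini_batch_size: int, critic_mini_batch_size: int,
                       n_samples: int, dp_size: int, has_critic: bool) -> int:
    # Pad to a common multiple of every size the batch must be divisible by,
    # and only include the critic's mini-batch size when a critic is in use.
    targets = [policy_mini_batch_size * n_samples, dp_size]
    if has_critic:
        targets.append(critic_mini_batch_size * n_samples)
    return math.lcm(*targets)

# With the reviewer's numbers (dp_size=8 chosen only for illustration):
#   max-based target: max(1280, 1920, 8) = 1920, and 1920 % 1280 != 0  -> assert fails
#   lcm-based target: math.lcm(1280, 1920, 8) = 3840, divisible by both 1280 and 1920
```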