Skip to content

feat: add partial_rollout recipe#96

Open
startju wants to merge 1 commit into
verl-project:mainfrom
startju:partial_rollout
Open

feat: add partial_rollout recipe#96
startju wants to merge 1 commit into
verl-project:mainfrom
startju:partial_rollout

Conversation

@startju
Copy link
Copy Markdown

@startju startju commented May 6, 2026

Summary

Adds partial_rollout/ to the recipe submodule: APRIL-style (paper) synchronous RL with cross-step rollout interruption + resume to reclaim long-tail GPU bubbles. Aborted gens carry their conversation state across the step boundary and resume on the next step while their KV cache may still be live on the rollout server.

Based on upstream verl-project/verl@8ebccd44 (full pin in recipe/partial_rollout/REQUIRED_VERL.txt).

Relationship to #58

Open PR #58 (mamazi0131:main, 2026-03-01) lands the same recipe directory and was the starting point for this work. This PR is materially different on four architectural axes:

  1. LLMServerManager / AgentLoopManager split (verl#6117). Current upstream separates the rollout server manager from the agent-loop manager. This recipe ships llm_server.py (PartialRolloutLLMServerManager) so cancel/resume fan-out lives on the new server-manager surface; a small symbol-swap in ray_trainer.init_workers injects it because upstream RayPPOTrainer.init_workers hardcodes LLMServerManager with no FQN config knob (unlike the parallel agent_loop_manager_class knob it does have). feat: async partial rollout trainer with sample supplementation and caching #58's tree doesn't import on current main.

  2. Cancel/retry path absorbed inside FullyLLMServerClient (upstream verl#5631). Current upstream's FullyLLMServerClient.generate() already has an abort-then-retry loop — when a generate is aborted mid-flight by cancel(), it parks and resumes against the next weight version's accumulated context without ever returning to AgentLoop. This recipe gates that retry branch with +async_training.partial_rollout=True and forces get_client(fully_async=True) for every caller. Net effect: the recipe's AgentLoop only handles pull/push and trajectory-grained pull pacing (see axis 3); validation goes through upstream's untouched AgentLoopWorker.generate_sequences. feat: async partial rollout trainer with sample supplementation and caching #58 instead returns an ABORT sentinel to the agent loop, which then re-enqueues the prompt into pending_queue — a structurally heavier path that requires a forked tool_agent_loop for state snapshot/restore.

  3. Trajectory-grained pull pacing. PartialRolloutAgentLoopWorker.generate_for_prompt replaces upstream's trailing outputs = await asyncio.gather(*tasks) with an asyncio.wait(FIRST_COMPLETED) loop that decrements self.inflight_traj and signals self._slot_event after every per-trajectory completion. The run_continuous outer loop then pulls the next prompt as soon as inflight_traj + n <= max_inflight_prompts * n — long-tail trajectories inside one prompt don't block new prompts from entering across the budget freed by other in-flight prompts' completions. Pull RPC, _run_one tasks, and the slot-wait sentinel share a single asyncio.wait(running) set; identity checks dispatch the three task kinds. Validation keeps the simpler upstream gather path (we add generate_for_prompt as a new method rather than overriding generate_sequences). feat: async partial rollout trainer with sample supplementation and caching #58 reaches a similar trajectory-level effect via pending_queue re-enqueue + last_agent_loop_output snapshot/restore — heavier mechanism, requires forked agent loops.

  4. Engine-level cancel via Python _resume_event + abort_all_requests drain (vLLM 0.11 stopgap). vLLM <0.12 doesn't expose pause_generation, so PartialRolloutvLLMHttpServer adds a Python-side _resume_event gate around generate() plus an inflight-counted abort_all_requests(reset_prefix_cache=False) drain loop in cancel(). Deletable once verl moves to vLLM ≥0.12. feat: async partial rollout trainer with sample supplementation and caching #58 uses per-request asyncio.Event + Lock — every in-flight generate awaits its own cancel handle; PartialRollout substitutes one engine-core batch call.

In addition this PR adds:

  • gsm8k_tool_config.yaml (recovered after upstream #6126 deletion of examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml).
  • run_qwen3-0.6b_gsm8k_grpo_tool{,_baseline}.sh plus non-tool 0.6B variants for laptop-scale repro, under recipe/partial_rollout/run/.
  • README_zh.md, REFERENCE.md, REQUIRED_VERL.txt.

Happy to fold these into #58 if @mamazi0131 prefers — opening separately because #58 does not apply to current verl-project/verl@main and the rebase is nontrivial.

Test plan

Full 1-epoch chain on 2× RTX 3090, Qwen3-0.6B, gsm8k, GRPO + token-level rollout-IS, max_response_length=4096, max_model_len=4608, batch=8, TP=1.

Single-turn, PR vs baseline (completed, 934 steps each)

bash recipe/partial_rollout/run/run_qwen3-0.6b_gsm8k_grpo.sh
bash recipe/partial_rollout/run/run_qwen3-0.6b_gsm8k_grpo_baseline.sh
  • 934 steps each, no OOM, no hang

  • PR timing_s/gen avg 18.5s vs baseline 26.1s−29% gen time

  • PR perf/throughput avg 1060 tok/s/GPU vs baseline 851+24%

  • Learning curves overlap — no learning regression

  • pre-commit run --all-files (ruff, ruff-format) clean.

Test Result

image

Single-turn 934-step run on 2× RTX 3090, Qwen3-0.6B, gsm8k, GRPO + token-level rollout-IS, max_response_length=4096, batch=8, 1 epoch. Six panels — green = baseline, pink = partial_rollout.

Learning-quality panels (top-left, bottom-left, bottom-right)

These three panels exist to falsify "PR breaks the algorithm." If PR shifted training dynamics, one of these would diverge.

  • critic/rewards/mean — both runs climb from ~0 to ~0.8 by step ~200, then track together at 0.7–0.9 for the rest of training. No reward divergence; PR's cross-step interrupt + resume does not introduce bias into the policy gradient.
  • response_length/mean — both rise from ~600 to ~1100–1200 over the run, near-overlapping. PR is marginally higher (matches the 1142 vs 1068 averages reported in the comparison comment), within noise.
  • actor/entropy — both decay from ~0.4 to ~0.15 along the same trajectory. Same exploration / collapse rate.

→ PR is policy-correctness-neutral. The cancel-resume mechanism doesn't perturb the optimization.

Performance panels (top-middle, top-right, bottom-middle)

These show the actual speedup.

  • timing_s/gen ⭐ — the panel that matters most. baseline sits at ~25–35s, PR sits at ~15–20s, consistently and across the entire 934 steps. The two curves almost never cross. Spikes at multiples of 50 are test_freq=50 validation steps (validation goes through upstream's no-PR path, so both runs pay the same validation cost — those spikes overlap). Excluding warmup, PR averages 18.5s, baseline 26.1s — −29%.
  • perf/throughput (tok/s/GPU) — mirror of gen timing. PR ~1100–1400, baseline ~800–1000, sustained gap, +30% throughput. Both curves are noisy step-to-step (batch=8 means high per-step variance — a single long-tail prompt dominates), but the bands clearly separate.
  • timing_s/step — total step wall time. PR ~30–40s, baseline ~40–50s. Same direction as gen but smaller relative gap (≈ −16%) because non-gen phases (update_actor ~12s, ref + old_log_prob ~5s, etc.) are unchanged by PR. PR's win is concentrated entirely inside the gen phase; the rest is identical work.

Why the chart is convincing

  1. Sustained, not warmup-bounded: the gap shows up by step 5 and stays for 900 more steps. Not an outlier of a particular batch.
  2. Two curves never cross on timing_s/gen: any single step PR ≤ baseline (modulo the shared validation spikes). System-level effect, not statistical noise.
  3. Learning curves overlap pixel-for-pixel: the speedup is not "PR took shortcuts and generated less / worse." Reward, length, entropy match.

Headline

29% faster gen, 30% higher throughput, no learning-curve regression. At this scale (max_response=4096, batch=8, long-tail driven) PR is in its design sweet spot.

AI-assistance disclosure

This PR was drafted with AI assistance (Claude Opus 4.7, 1M context window). The commit carries a Co-authored-by: Claude trailer. The submitting human (@startju) reviewed every changed line, ran the test above, and is the accountable owner of this change end-to-end.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the Partial Rollout (PRv3) pipeline, a system designed for synchronous RL training that mitigates GPU idle time caused by long-tail response lengths through mid-generation interruption and cross-step resumption. The implementation includes a new trainer, a Ray-based prompt manager for queue coordination, specialized agent loop workers, and an extended vLLM HTTP server supporting batch cancellation. Feedback highlighted a potential key collision issue in the RolloutPromptManager when performing a DataProto.union, which could lead to a crash if overlapping keys like uid are not filtered.

# alias before union mutates anything.
repeated = rp.batch.repeat(repeat_times=self.n, interleave=True)
repeated.meta_info = dict(repeated.meta_info)
full_batch = repeated.union(rp.gen_batch_output)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The repeated.union(rp.gen_batch_output) call will likely raise a KeyError or collision error because both DataProto objects contain the uid key (and potentially others like agent_name). In verl, DataProto.union is designed for disjoint sets of keys. Since rp.gen_batch_output is the result of AgentLoopWorker._postprocess, it already includes the input non-tensor fields (including uid). You should filter out common keys from one side before performing the union to avoid this crash.

            # Filter out keys from gen_batch_output that are already in repeated to avoid union collisions.
            # uid and other common fields are expected to be identical.
            for key in list(rp.gen_batch_output.batch.keys()):
                if key in repeated.batch.keys():
                    del rp.gen_batch_output.batch[key]
            for key in list(rp.gen_batch_output.non_tensor_batch.keys()):
                if key in repeated.non_tensor_batch.keys():
                    del rp.gen_batch_output.non_tensor_batch[key]
            full_batch = repeated.union(rp.gen_batch_output)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't crash — DataProto.union is not a strict-disjoint-keys union. The implementation in verl/protocol.py:

  • union_tensor_dict (L109–122):
    for key in tensor_dict2.keys():
        if key not in tensor_dict1.keys():
            tensor_dict1[key] = tensor_dict2[key]
        else:
            assert tensor_dict1[key].equal(tensor_dict2[key]), ...
  • union_numpy_dict (L188–198): same pattern, using _deep_equal for non-tensor (incl. object-dtype) overlaps.

Overlapping keys are allowed; the union only fails when the values disagree. Here the values are guaranteed equal:

  • RayPPOTrainer._get_gen_batch (verl/trainer/ppo/ray_trainer.py:490) does not pop reward_keys = {"data_source", "reward_model", "extra_info", "uid"} — it keeps them on the original batch and re-adds them onto gen_batch.non_tensor_batch, so both sides share identical uid arrays.
  • AgentLoopWorker._postprocess is called by PRv3AgentLoopWorker._generate_sequences_for_prompt with input_non_tensor_batch=batch.non_tensor_batch (recipe agent_loop/agent_loop.py:230–234), so gen_batch_output carries the same numpy object for uid as rp.batch. _deep_equal returns True.

Empirically this code path runs end-to-end without a union assertion failure — see the test plan in the PR description (Qwen3-0.6B + gsm8k_tool, 58/58 steps to val-core/openai/gsm8k/acc/mean@1 = 0.6566).

If union were strict-disjoint, it would also break upstream RayPPOTrainer.fit itself, since it does the same batch.union(gen_batch_output) after generation.

Marking as not actionable.

@startju
Copy link
Copy Markdown
Author

startju commented May 6, 2026

@ArronHZG @mamazi0131 If anyone has time, please take a look and give it a review.

@startju startju force-pushed the partial_rollout branch from 1162816 to f0fb722 Compare May 11, 2026 02:20
@ArronHZG
Copy link
Copy Markdown

1 Please describe which Verl commit it's based on.

  1. Does AgentLoop need modification? You can take a look at the logic for automatic partial and resume operations.
    https://github.com/verl-project/verl/blob/main/verl/workers/rollout/llm_server.py#L160

@ArronHZG
Copy link
Copy Markdown

The shell file can be placed in a separate folder.

Comment thread partial_rollout/llm_server.py Outdated
from verl.workers.rollout.llm_server import LLMServerManager


class PRv3LLMServerManager(LLMServerManager):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PRv3 is not a good name; it could be changed.

startju added a commit to startju/verl-recipe that referenced this pull request May 11, 2026
Update to current state:
  - VERL_COMMIT: f400fb76 → 8ebccd4468 (current verl HEAD).
  - RECIPE_FORK_BRANCH: partial_rollout → partial_rollout-simplify (the
    active work branch; the old `partial_rollout` is the PR verl-project#96 base
    stuck at f0fb722).
  - RECIPE_SUBMODULE_COMMIT / RECIPE_FOLDER_LAST_COMMIT: da625f21ea9f50 (latest commit on partial_rollout-simplify that touched
    this folder).
  - GIT_SETUP fetch arg + checkout updated accordingly.

Pins point to the commit immediately before this one; the txt file
self-update lands on top.

Co-authored-by: Claude <noreply@anthropic.com>
@startju startju force-pushed the partial_rollout branch 7 times, most recently from 4de95ee to e49a129 Compare May 11, 2026 15:11
@startju startju changed the title feat: add partial_rollout recipe (PRv3, current main-compatible) feat: add partial_rollout recipe May 11, 2026
@startju
Copy link
Copy Markdown
Author

startju commented May 11, 2026

@ArronHZG Thanks for the review — both points landed in the latest revision.

  1. Verl commit basis: verl-project/verl@8ebccd44 (same commit used to generate the chart in the PR body). Full pin (incl. pip install / git checkout recipe) in recipe/partial_rollout/REQUIRED_VERL.txt; we re-pin on every recipe-side amend.

  2. AgentLoop modification — great catch. That link (FullyLLMServerClient.generate() at llm_server.py:L160) is exactly the path we settled on. The recipe no longer forks AgentLoop logic at all: PartialRolloutAgentLoopWorker.run_continuous is just the prompt-manager pull/push loop, and per-prompt rollout delegates to upstream's unchanged AgentLoopWorker.generate_sequences. The abort/retry cycle is fully absorbed inside one FullyLLMServerClient.generate() call — gated by +async_training.partial_rollout=True with get_client(fully_async=True) (both forced by PartialRolloutLLMServerManager). The PR body's "Relationship to feat: async partial rollout trainer with sample supplementation and caching #58" section has been rewritten around this; the earlier tool_agent_loop fork (with _restore_agent_data etc.) has been removed entirely.

  3. Shell files in a separate folder: Done — all run_*.sh and gsm8k_tool_config.yaml live under recipe/partial_rollout/run/.

@startju startju force-pushed the partial_rollout branch 2 times, most recently from b0d1363 to 3ea9cd1 Compare May 11, 2026 23:37
self.inflight_traj: int = 0
self._slot_event: asyncio.Event = asyncio.Event()

async def generate_for_prompt(self, batch: DataProto) -> DataProto:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it could be further simplified. If there are issues with refactoring or inheritance here, I understand that the main code could be made more abstract.

@startju startju force-pushed the partial_rollout branch 5 times, most recently from bd91a32 to 4e72af8 Compare May 18, 2026 06:58
APRIL-style (https://arxiv.org/pdf/2509.18521) synchronous RL with
cross-step rollout interruption + resume to reclaim long-tail GPU
bubbles. Aborted gens carry their conversation state across the step
boundary and resume on the next step while KV cache may still be live.

Components:
- ray_trainer.PartialRolloutRayPPOTrainer: wires a RolloutPromptManager
  actor in front of the rollout pool; monkey-patches
  verl.trainer.ppo.ray_trainer.LLMServerManager during init_workers
  (upstream hardcodes the class with no FQN knob).
- prompt_manager.RolloutPromptManager: Ray actor with pending/ongoing/
  done queues, asyncio.Event-based pull_batch (replaces the per-step
  ~10k-RPC poll loop), peek-then-commit batch assembly, and LIFO
  pending-side push for aborted prompts so the next pull keeps KV
  cache locality.
- llm_server.PartialRolloutLLMServerManager: swaps replica class to
  PartialRolloutvLLMReplica, forces get_client(fully_async=True) so
  every caller gets FullyLLMServerClient (its abort-then-retry loop
  is what makes cancel/resume invisible to AgentLoop), and fans
  cancel()/resume() to replicas.
- vllm_rollout.PartialRolloutvLLMHttpServer / PartialRolloutvLLMReplica:
  Python-side _resume_event gate around generate() + abort_all_requests
  drain loop, because vLLM <0.12 lacks pause_generation. cancel() clears
  the event and loops abort+drain until inflight=0; resume() sets the
  event, releasing any callers parked on it. The whole layer can be
  deleted once verl moves to vLLM >=0.12.
- agent_loop.PartialRolloutAgentLoopWorker / Manager: worker runs a
  persistent run_continuous loop sharing one asyncio.wait across the
  pull RPC, the _run_one rollout tasks, and a slot-wait sentinel.
  Instead of overriding upstream generate_sequences (left intact for
  the validation path), the worker adds generate_for_prompt: upstream's
  generate_sequences with the trailing asyncio.gather replaced by an
  asyncio.wait(FIRST_COMPLETED) loop that decrements
  self.inflight_traj and sets self._slot_event after each trajectory
  completion. The outer loop pulls the next prompt as soon as
  inflight_traj + n <= max_inflight_prompts * n — so a long-tail
  trajectory doesn't block other in-flight prompts from making room
  for the next pull. No tool_agent_loop fork; per-trajectory
  abort/retry is still absorbed inside a single
  FullyLLMServerClient.generate() call, so the worker is oblivious to
  the cancel/resume cycle. Manager wraps the trainer-side
  generate_sequences with resume()/cancel() so continuously-running
  workers don't generate across a weight update.

Gated by '+async_training.partial_rollout=True' (the '+' is required
because the upstream ppo_trainer config doesn't pre-declare the key);
the flag toggles the retry-on-abort branch inside FullyLLMServerClient.

Plus gsm8k_tool_config.yaml (recovered after upstream PR #6126),
run_*.sh runner scripts for 0.6B/4B x gsm8k/dapomath x pr/baseline,
and README / README_zh / REFERENCE docs.

Co-authored-by: Claude
@startju startju force-pushed the partial_rollout branch from 4e72af8 to 7cabb86 Compare May 18, 2026 07:09
@ArronHZG
Copy link
Copy Markdown

@wuxibin89 need review

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants