feat: add partial_rollout recipe#96
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces the Partial Rollout (PRv3) pipeline, a system designed for synchronous RL training that mitigates GPU idle time caused by long-tail response lengths through mid-generation interruption and cross-step resumption. The implementation includes a new trainer, a Ray-based prompt manager for queue coordination, specialized agent loop workers, and an extended vLLM HTTP server supporting batch cancellation. Feedback highlighted a potential key collision issue in the RolloutPromptManager when performing a DataProto.union, which could lead to a crash if overlapping keys like uid are not filtered.
| # alias before union mutates anything. | ||
| repeated = rp.batch.repeat(repeat_times=self.n, interleave=True) | ||
| repeated.meta_info = dict(repeated.meta_info) | ||
| full_batch = repeated.union(rp.gen_batch_output) |
There was a problem hiding this comment.
The repeated.union(rp.gen_batch_output) call will likely raise a KeyError or collision error because both DataProto objects contain the uid key (and potentially others like agent_name). In verl, DataProto.union is designed for disjoint sets of keys. Since rp.gen_batch_output is the result of AgentLoopWorker._postprocess, it already includes the input non-tensor fields (including uid). You should filter out common keys from one side before performing the union to avoid this crash.
# Filter out keys from gen_batch_output that are already in repeated to avoid union collisions.
# uid and other common fields are expected to be identical.
for key in list(rp.gen_batch_output.batch.keys()):
if key in repeated.batch.keys():
del rp.gen_batch_output.batch[key]
for key in list(rp.gen_batch_output.non_tensor_batch.keys()):
if key in repeated.non_tensor_batch.keys():
del rp.gen_batch_output.non_tensor_batch[key]
full_batch = repeated.union(rp.gen_batch_output)There was a problem hiding this comment.
This won't crash — DataProto.union is not a strict-disjoint-keys union. The implementation in verl/protocol.py:
union_tensor_dict(L109–122):for key in tensor_dict2.keys(): if key not in tensor_dict1.keys(): tensor_dict1[key] = tensor_dict2[key] else: assert tensor_dict1[key].equal(tensor_dict2[key]), ...
union_numpy_dict(L188–198): same pattern, using_deep_equalfor non-tensor (incl. object-dtype) overlaps.
Overlapping keys are allowed; the union only fails when the values disagree. Here the values are guaranteed equal:
RayPPOTrainer._get_gen_batch(verl/trainer/ppo/ray_trainer.py:490) does not popreward_keys = {"data_source", "reward_model", "extra_info", "uid"}— it keeps them on the originalbatchand re-adds them ontogen_batch.non_tensor_batch, so both sides share identicaluidarrays.AgentLoopWorker._postprocessis called byPRv3AgentLoopWorker._generate_sequences_for_promptwithinput_non_tensor_batch=batch.non_tensor_batch(recipeagent_loop/agent_loop.py:230–234), sogen_batch_outputcarries the same numpy object foruidasrp.batch._deep_equalreturnsTrue.
Empirically this code path runs end-to-end without a union assertion failure — see the test plan in the PR description (Qwen3-0.6B + gsm8k_tool, 58/58 steps to val-core/openai/gsm8k/acc/mean@1 = 0.6566).
If union were strict-disjoint, it would also break upstream RayPPOTrainer.fit itself, since it does the same batch.union(gen_batch_output) after generation.
Marking as not actionable.
|
@ArronHZG @mamazi0131 If anyone has time, please take a look and give it a review. |
|
1 Please describe which Verl commit it's based on.
|
|
The shell file can be placed in a separate folder. |
| from verl.workers.rollout.llm_server import LLMServerManager | ||
|
|
||
|
|
||
| class PRv3LLMServerManager(LLMServerManager): |
There was a problem hiding this comment.
PRv3 is not a good name; it could be changed.
Update to current state:
- VERL_COMMIT: f400fb76 → 8ebccd4468 (current verl HEAD).
- RECIPE_FORK_BRANCH: partial_rollout → partial_rollout-simplify (the
active work branch; the old `partial_rollout` is the PR verl-project#96 base
stuck at f0fb722).
- RECIPE_SUBMODULE_COMMIT / RECIPE_FOLDER_LAST_COMMIT: da625f2 →
1ea9f50 (latest commit on partial_rollout-simplify that touched
this folder).
- GIT_SETUP fetch arg + checkout updated accordingly.
Pins point to the commit immediately before this one; the txt file
self-update lands on top.
Co-authored-by: Claude <noreply@anthropic.com>
4de95ee to
e49a129
Compare
|
@ArronHZG Thanks for the review — both points landed in the latest revision.
|
b0d1363 to
3ea9cd1
Compare
| self.inflight_traj: int = 0 | ||
| self._slot_event: asyncio.Event = asyncio.Event() | ||
|
|
||
| async def generate_for_prompt(self, batch: DataProto) -> DataProto: |
There was a problem hiding this comment.
I think it could be further simplified. If there are issues with refactoring or inheritance here, I understand that the main code could be made more abstract.
bd91a32 to
4e72af8
Compare
APRIL-style (https://arxiv.org/pdf/2509.18521) synchronous RL with cross-step rollout interruption + resume to reclaim long-tail GPU bubbles. Aborted gens carry their conversation state across the step boundary and resume on the next step while KV cache may still be live. Components: - ray_trainer.PartialRolloutRayPPOTrainer: wires a RolloutPromptManager actor in front of the rollout pool; monkey-patches verl.trainer.ppo.ray_trainer.LLMServerManager during init_workers (upstream hardcodes the class with no FQN knob). - prompt_manager.RolloutPromptManager: Ray actor with pending/ongoing/ done queues, asyncio.Event-based pull_batch (replaces the per-step ~10k-RPC poll loop), peek-then-commit batch assembly, and LIFO pending-side push for aborted prompts so the next pull keeps KV cache locality. - llm_server.PartialRolloutLLMServerManager: swaps replica class to PartialRolloutvLLMReplica, forces get_client(fully_async=True) so every caller gets FullyLLMServerClient (its abort-then-retry loop is what makes cancel/resume invisible to AgentLoop), and fans cancel()/resume() to replicas. - vllm_rollout.PartialRolloutvLLMHttpServer / PartialRolloutvLLMReplica: Python-side _resume_event gate around generate() + abort_all_requests drain loop, because vLLM <0.12 lacks pause_generation. cancel() clears the event and loops abort+drain until inflight=0; resume() sets the event, releasing any callers parked on it. The whole layer can be deleted once verl moves to vLLM >=0.12. - agent_loop.PartialRolloutAgentLoopWorker / Manager: worker runs a persistent run_continuous loop sharing one asyncio.wait across the pull RPC, the _run_one rollout tasks, and a slot-wait sentinel. Instead of overriding upstream generate_sequences (left intact for the validation path), the worker adds generate_for_prompt: upstream's generate_sequences with the trailing asyncio.gather replaced by an asyncio.wait(FIRST_COMPLETED) loop that decrements self.inflight_traj and sets self._slot_event after each trajectory completion. The outer loop pulls the next prompt as soon as inflight_traj + n <= max_inflight_prompts * n — so a long-tail trajectory doesn't block other in-flight prompts from making room for the next pull. No tool_agent_loop fork; per-trajectory abort/retry is still absorbed inside a single FullyLLMServerClient.generate() call, so the worker is oblivious to the cancel/resume cycle. Manager wraps the trainer-side generate_sequences with resume()/cancel() so continuously-running workers don't generate across a weight update. Gated by '+async_training.partial_rollout=True' (the '+' is required because the upstream ppo_trainer config doesn't pre-declare the key); the flag toggles the retry-on-abort branch inside FullyLLMServerClient. Plus gsm8k_tool_config.yaml (recovered after upstream PR #6126), run_*.sh runner scripts for 0.6B/4B x gsm8k/dapomath x pr/baseline, and README / README_zh / REFERENCE docs. Co-authored-by: Claude
|
@wuxibin89 need review |
Summary
Adds
partial_rollout/to the recipe submodule: APRIL-style (paper) synchronous RL with cross-step rollout interruption + resume to reclaim long-tail GPU bubbles. Aborted gens carry their conversation state across the step boundary and resume on the next step while their KV cache may still be live on the rollout server.Based on upstream
verl-project/verl@8ebccd44(full pin inrecipe/partial_rollout/REQUIRED_VERL.txt).Relationship to #58
Open PR #58 (
mamazi0131:main, 2026-03-01) lands the same recipe directory and was the starting point for this work. This PR is materially different on four architectural axes:LLMServerManager/AgentLoopManagersplit (verl#6117). Current upstream separates the rollout server manager from the agent-loop manager. This recipe shipsllm_server.py(PartialRolloutLLMServerManager) so cancel/resume fan-out lives on the new server-manager surface; a small symbol-swap inray_trainer.init_workersinjects it because upstreamRayPPOTrainer.init_workershardcodesLLMServerManagerwith no FQN config knob (unlike the parallelagent_loop_manager_classknob it does have). feat: async partial rollout trainer with sample supplementation and caching #58's tree doesn't import on currentmain.Cancel/retry path absorbed inside
FullyLLMServerClient(upstream verl#5631). Current upstream'sFullyLLMServerClient.generate()already has an abort-then-retry loop — when a generate is aborted mid-flight bycancel(), it parks and resumes against the next weight version's accumulated context without ever returning to AgentLoop. This recipe gates that retry branch with+async_training.partial_rollout=Trueand forcesget_client(fully_async=True)for every caller. Net effect: the recipe'sAgentLooponly handles pull/push and trajectory-grained pull pacing (see axis 3); validation goes through upstream's untouchedAgentLoopWorker.generate_sequences. feat: async partial rollout trainer with sample supplementation and caching #58 instead returns an ABORT sentinel to the agent loop, which then re-enqueues the prompt intopending_queue— a structurally heavier path that requires a forkedtool_agent_loopfor state snapshot/restore.Trajectory-grained pull pacing.
PartialRolloutAgentLoopWorker.generate_for_promptreplaces upstream's trailingoutputs = await asyncio.gather(*tasks)with anasyncio.wait(FIRST_COMPLETED)loop that decrementsself.inflight_trajand signalsself._slot_eventafter every per-trajectory completion. Therun_continuousouter loop then pulls the next prompt as soon asinflight_traj + n <= max_inflight_prompts * n— long-tail trajectories inside one prompt don't block new prompts from entering across the budget freed by other in-flight prompts' completions. Pull RPC,_run_onetasks, and the slot-wait sentinel share a singleasyncio.wait(running)set; identity checks dispatch the three task kinds. Validation keeps the simpler upstream gather path (we addgenerate_for_promptas a new method rather than overridinggenerate_sequences). feat: async partial rollout trainer with sample supplementation and caching #58 reaches a similar trajectory-level effect viapending_queuere-enqueue +last_agent_loop_outputsnapshot/restore — heavier mechanism, requires forked agent loops.Engine-level cancel via Python
_resume_event+abort_all_requestsdrain (vLLM 0.11 stopgap). vLLM <0.12 doesn't exposepause_generation, soPartialRolloutvLLMHttpServeradds a Python-side_resume_eventgate aroundgenerate()plus aninflight-countedabort_all_requests(reset_prefix_cache=False)drain loop incancel(). Deletable once verl moves to vLLM ≥0.12. feat: async partial rollout trainer with sample supplementation and caching #58 uses per-requestasyncio.Event+Lock— every in-flight generate awaits its own cancel handle; PartialRollout substitutes one engine-core batch call.In addition this PR adds:
gsm8k_tool_config.yaml(recovered after upstream #6126 deletion ofexamples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml).run_qwen3-0.6b_gsm8k_grpo_tool{,_baseline}.shplus non-tool 0.6B variants for laptop-scale repro, underrecipe/partial_rollout/run/.README_zh.md,REFERENCE.md,REQUIRED_VERL.txt.Happy to fold these into #58 if @mamazi0131 prefers — opening separately because #58 does not apply to current
verl-project/verl@mainand the rebase is nontrivial.Test plan
Full 1-epoch chain on 2× RTX 3090, Qwen3-0.6B, gsm8k, GRPO + token-level rollout-IS,
max_response_length=4096,max_model_len=4608, batch=8, TP=1.Single-turn, PR vs baseline (completed, 934 steps each)
934 steps each, no OOM, no hang
PR
timing_s/genavg 18.5s vs baseline 26.1s — −29% gen timePR
perf/throughputavg 1060 tok/s/GPU vs baseline 851 — +24%Learning curves overlap — no learning regression
pre-commit run --all-files(ruff, ruff-format) clean.Test Result
Single-turn 934-step run on 2× RTX 3090, Qwen3-0.6B, gsm8k, GRPO + token-level rollout-IS,
max_response_length=4096, batch=8, 1 epoch. Six panels — green = baseline, pink = partial_rollout.Learning-quality panels (top-left, bottom-left, bottom-right)
These three panels exist to falsify "PR breaks the algorithm." If PR shifted training dynamics, one of these would diverge.
critic/rewards/mean— both runs climb from ~0 to ~0.8 by step ~200, then track together at 0.7–0.9 for the rest of training. No reward divergence; PR's cross-step interrupt + resume does not introduce bias into the policy gradient.response_length/mean— both rise from ~600 to ~1100–1200 over the run, near-overlapping. PR is marginally higher (matches the 1142 vs 1068 averages reported in the comparison comment), within noise.actor/entropy— both decay from ~0.4 to ~0.15 along the same trajectory. Same exploration / collapse rate.→ PR is policy-correctness-neutral. The cancel-resume mechanism doesn't perturb the optimization.
Performance panels (top-middle, top-right, bottom-middle)
These show the actual speedup.
timing_s/gen⭐ — the panel that matters most. baseline sits at ~25–35s, PR sits at ~15–20s, consistently and across the entire 934 steps. The two curves almost never cross. Spikes at multiples of 50 aretest_freq=50validation steps (validation goes through upstream's no-PR path, so both runs pay the same validation cost — those spikes overlap). Excluding warmup, PR averages 18.5s, baseline 26.1s — −29%.perf/throughput(tok/s/GPU) — mirror of gen timing. PR ~1100–1400, baseline ~800–1000, sustained gap, +30% throughput. Both curves are noisy step-to-step (batch=8 means high per-step variance — a single long-tail prompt dominates), but the bands clearly separate.timing_s/step— total step wall time. PR ~30–40s, baseline ~40–50s. Same direction as gen but smaller relative gap (≈ −16%) because non-gen phases (update_actor~12s,ref+old_log_prob~5s, etc.) are unchanged by PR. PR's win is concentrated entirely inside the gen phase; the rest is identical work.Why the chart is convincing
timing_s/gen: any single step PR ≤ baseline (modulo the shared validation spikes). System-level effect, not statistical noise.Headline
29% faster gen, 30% higher throughput, no learning-curve regression. At this scale (
max_response=4096, batch=8, long-tail driven) PR is in its design sweet spot.AI-assistance disclosure
This PR was drafted with AI assistance (Claude Opus 4.7, 1M context window). The commit carries a
Co-authored-by: Claudetrailer. The submitting human (@startju) reviewed every changed line, ran the test above, and is the accountable owner of this change end-to-end.