tmp: fix pp scheduling #243

Merged
rebel-jaehwang merged 5 commits into dev-0.12 from fix-pp-scheduling
Jan 13, 2026

Conversation

@huijjj (Collaborator) commented Dec 29, 2025

🚀 Summary of Changes

In upstream vLLM v1, pipeline parallelism (PP) was implemented in a subsequent PR. At a high level, the scheduler divides up to max_num_seqs requests into pp_size groups, and requests within each group are batched and executed together. Since there are pp_size such batches, they are scheduled back-to-back so that the pipeline runs without bubbles.

One important detail is how requests are assigned to groups. Requests are not explicitly grouped in a round-robin or load-balanced manner. Instead, a request is implicitly assigned to a group based on which batch happens to be scheduled at the time the request is first admitted. Once a request is scheduled and executes a single inference step, it must pass through all pipeline stages before it can be scheduled again. As a result, requests that were initially batched together naturally continue to remain in the same group across subsequent iterations.
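The implicit assignment can be pictured with a toy model (the names below are illustrative, not vLLM's actual data structures): a request is pinned to whichever batch slot is being scheduled at the step it is first admitted, and because it then spends pp_size steps traversing the pipeline, it can only ever be rescheduled with that same slot.

```python
def toy_group_assignment(admission_steps, pp_size):
    """Toy model of implicit grouping: slot (step % pp_size) is the batch
    being scheduled at a given step; a request joins the slot active when
    it is first admitted and stays in it for all later iterations."""
    return {req: step % pp_size for req, step in enumerate(admission_steps)}

# Requests admitted at consecutive steps spread evenly across 2 groups:
print(toy_group_assignment(admission_steps=[0, 1, 2, 3], pp_size=2))
# {0: 0, 1: 1, 2: 0, 3: 1}

# Requests all admitted at the same step collapse into a single group:
print(toy_group_assignment(admission_steps=[4, 4, 4, 4], pp_size=2))
# {0: 0, 1: 0, 2: 0, 3: 0}
```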

Because of this behavior, group sizes can become uneven. In the worst case it is possible for all requests to end up in a single group while the remaining groups are empty, effectively reducing pipeline utilization to 1 / pp_size. However, this worst-case behavior is unlikely in practice. As long as max_num_seqs is configured sufficiently large and there are enough concurrent requests to keep the pipeline filled, you can generally expect meaningful throughput improvements from PP.

If it’s still unclear, please take a close look at the code below and trace through the execution flow.

Unfortunately, this is not the case for our vllm-rbln. Since we do not support mixed batching and we also force the prefill batch size to 1, decode requests cannot be scheduled while a prefill is in progress and they remain runnable but effectively blocked until the prefill finishes. Eventually, prefill completes (typically because the running set reaches max_num_seqs or the KV cache can no longer accommodate new requests), and then we move on to scheduling decode. At that point, all pending decode requests become schedulable at the same time, and we end up packing them into a single batch. As a result, the “worst case” on the GPU side becomes essentially inevitable for us, which is why PP utilization collapses and performance can degrade severely.
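A hedged sketch of why the worst case becomes inevitable here (illustrative, not the actual scheduler): with no mixed batching and a prefill batch size of 1, decode-ready requests pile up while prefills run one at a time; when prefill finally finishes, every waiting request becomes schedulable at the same step, so they all receive the same implicit group and only 1 of pp_size pipeline slots carries work.

```python
def utilization_after_prefill_phase(num_requests, pp_size):
    """Fraction of pipeline slots occupied once decode scheduling starts,
    under the simplifying assumption that all prefills complete before
    any decode is scheduled (one prefill per step, batch size 1)."""
    first_decode_step = num_requests                       # one step per prefill
    # Every decode request is admitted at the same step -> same implicit group.
    groups = {r: first_decode_step % pp_size for r in range(num_requests)}
    occupied_slots = len(set(groups.values()))
    return occupied_slots / pp_size

print(utilization_after_prefill_phase(num_requests=8, pp_size=2))
# 0.5 -> only 1 of 2 slots has work; utilization degrades to 1 / pp_size
```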

In the long run, the right solution is clearly to lift the prefill batch-size constraint and allow prefill and decode to be scheduled in the same batch. If we could do that, we wouldn’t need to be having this discussion in the first place.
Unfortunately, that does not seem feasible in the near term. So, as a temporary workaround, this PR introduces the following changes (per the commit list below):

  • Limit the decode batch size to max_num_seqs // pp_size
  • Pad decode inputs to max_num_seqs // pp_size
  • Add a simple offline benchmark script (examples/experimental/simple_offline_bench.py)
  • Consolidate self.max_batch_size and decode_max_batch_size
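The commit messages suggest the workaround caps the decode batch at max_num_seqs // pp_size and pads decode inputs to that size. A minimal sketch of that idea (function and parameter names are hypothetical, not the actual implementation):

```python
def split_decode_batches(decode_reqs, max_num_seqs, pp_size, pad=None):
    """Cap each decode batch at max_num_seqs // pp_size so decode work is
    spread across pp_size pipeline slots instead of packed into one
    oversized batch, and pad each batch to a fixed size (useful when the
    compiled graph expects a fixed batch shape)."""
    decode_max_batch_size = max_num_seqs // pp_size
    batches = []
    for i in range(0, len(decode_reqs), decode_max_batch_size):
        batch = decode_reqs[i:i + decode_max_batch_size]
        batch += [pad] * (decode_max_batch_size - len(batch))  # pad to fixed size
        batches.append(batch)
    return batches

print(split_decode_batches(list(range(6)), max_num_seqs=8, pp_size=2))
# [[0, 1, 2, 3], [4, 5, None, None]] -> two slots now both have decode work
```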


📌 Related Issues / Tickets

  • Resolves #
  • Related to #

✅ Type of Change

  • ✨ Feature (feature)
  • 🧠 Model support (model)
  • 🧬 Core engine changes (core)
  • 🛠 Bug fix (bug-fix)
  • ⚙️ Performance improvement (perf)
  • 🔁 Refactor or code cleanup (refactor)
  • 📄 Documentation (docs)
  • ❓ Other (other): please describe

🧪 How to Test

  1. Run
    • python examples/experimental/simple_offline_bench.py
    • python examples/experimental/simple_offline_bench.py --pipeline-parallel-size 2
  2. Verify output: compare throughput
  3. Edge case tested: ...

📸 Screenshots / Logs (if applicable)


📋 Checklist

  • PR title follows Conventional Commits format
  • This PR is linked to an existing issue
  • The test method is described, and the expected result is clearly stated
  • Relevant documentation has been updated (if applicable)

💬 Notes


@huijjj huijjj self-assigned this Dec 30, 2025
@rebel-jiwoopark rebel-jiwoopark added the torch.compile (torch.compile based implementation) label Jan 5, 2026
@huijjj huijjj force-pushed the fix-pp-scheduling branch from 3b13bfc to ce6e62e on January 7, 2026
@huijjj huijjj changed the base branch from dev to dev-0.12 on January 7, 2026
@huijjj huijjj force-pushed the fix-pp-scheduling branch from ce6e62e to 06f3418 on January 12, 2026
@huijjj huijjj changed the title from "Fix pp scheduling" to "tmp: fix pp scheduling" on January 12, 2026
@huijjj huijjj marked this pull request as ready for review on January 12, 2026
@rebel-jaehwang (Contributor) commented

With examples/experimental/simple_offline_bench.py, I have observed that --pipeline-parallel-size 2 --max-num-seqs 16 improves throughput by 1.5× over --pipeline-parallel-size 1 --max-num-seqs 8. The sub-optimal scaling probably comes from send/recv, which are not quite optimized yet. We will be improving this in the near future.

@rebel-jaehwang rebel-jaehwang merged commit 3f3ff24 into dev-0.12 Jan 13, 2026
1 of 2 checks passed
@rebel-jaehwang rebel-jaehwang deleted the fix-pp-scheduling branch January 13, 2026 08:35
rebel-jaehwang added a commit that referenced this pull request Jan 30, 2026
* fix: limit decode bs to (max num seqs // pp size)

* tmp: pad decode inputs to max_num_seqs // pp_size

* add: simple offline benchmark script

* refac: consolidate self.max_batch_size and decode_max_batch_size

* fix: clearer perf report

---------

Co-authored-by: Jaehwang Jung <jaehwang.jung@rebellions.ai>
rebel-jiwoopark pushed a commit that referenced this pull request Feb 4, 2026
3 participants