Skip to content

fix(worker): optimize swap_states to copy only active token prefixes#34733

Merged
njhill merged 6 commits intovllm-project:mainfrom
pjo256:swap-states-active-prefix
Mar 18, 2026
Merged

fix(worker): optimize swap_states to copy only active token prefixes#34733
njhill merged 6 commits intovllm-project:mainfrom
pjo256:swap-states-active-prefix

Conversation

@pjo256
Copy link
Copy Markdown
Contributor

@pjo256 pjo256 commented Feb 17, 2026

Purpose

This PR optimizes InputBatch.swap_states() in vllm/v1/worker/gpu_input_batch.py by swapping only the active token
prefix instead of full max_model_len rows.

Fixes #34731.

Changes

  • Compute active lengths with a new _get_active_token_count for i1 and i2
  • Swap only:
    • token_ids_cpu[..., :max_active_token_count]
    • is_token_ids[..., :max_active_token_count]

Test Plan

lm_eval 
pytest tests/v1/worker/test_gpu_input_batch.py -v

Test Result

tests/v1/worker/test_gpu_input_batch.py::test_sampling_metadata_in_input_batch[1-cuda:0] PASSED
tests/v1/worker/test_gpu_input_batch.py::test_sampling_metadata_in_input_batch[2-cuda:0] PASSED
tests/v1/worker/test_gpu_input_batch.py::test_sampling_metadata_in_input_batch[32-cuda:0] PASSED
tests/v1/worker/test_gpu_input_batch.py::test_sampling_metadata_in_input_batch[64-cuda:0] PASSED
tests/v1/worker/test_gpu_input_batch.py::test_swap_states_in_input_batch[swap_list0-32-cuda:0] PASSED

Performance

Benchmark script

Click to expand benchmark script
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0

import argparse
import gc
import random
import statistics
import time

import numpy as np
import torch

from vllm.sampling_params import SamplingParams
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch


def build_batch(
    max_num_reqs: int,
    max_model_len: int,
    active_min: int,
    active_max: int,
    seed: int,
) -> InputBatch:
    rng = random.Random(seed)
    np_rng = np.random.default_rng(seed)

    batch = InputBatch(
        max_num_reqs=max_num_reqs,
        max_model_len=max_model_len,
        max_num_batched_tokens=max_num_reqs * active_max,
        device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
        pin_memory=False,
        vocab_size=32000,
        block_sizes=[16],
        kernel_block_sizes=[16],
    )
    sampling = SamplingParams(temperature=0.0, max_tokens=16)

    for i in range(max_num_reqs):
        active_len = rng.randint(active_min, active_max)
        prompt_len = max(1, active_len - 8)
        output_len = active_len - prompt_len

        req = CachedRequestState(
            req_id=f"req-{i}",
            prompt_token_ids=np_rng.integers(
                1, 30000, size=prompt_len, dtype=np.int32
            ).tolist(),
            mm_features=[],
            sampling_params=sampling,
            generator=None,
            block_ids=([i],),
            num_computed_tokens=output_len,
            output_token_ids=np_rng.integers(
                1, 30000, size=output_len, dtype=np.int32
            ).tolist(),
        )
        idx = batch.add_request(req)

        # Keep some speculative tokens so active prefix length varies by request.
        spec_len = active_len - int(batch.num_tokens_no_spec[idx])
        if spec_len > 0:
            start = int(batch.num_tokens_no_spec[idx])
            end = start + spec_len
            spec_ids = np_rng.integers(1, 30000, size=spec_len, dtype=np.int32).tolist()
            batch.spec_token_ids[idx] = spec_ids
            batch.token_ids_cpu[idx, start:end] = spec_ids
            batch.is_token_ids[idx, start:end] = True

    return batch


def make_swaps(max_num_reqs: int, num_swaps: int, seed: int) -> list[tuple[int, int]]:
    rng = random.Random(seed + 17)
    swaps = []
    for _ in range(num_swaps):
        i1 = rng.randrange(max_num_reqs)
        i2 = rng.randrange(max_num_reqs)
        if i1 == i2:
            i2 = (i2 + 1) % max_num_reqs
        swaps.append((i1, i2))
    return swaps


def run_once(
    num_swaps: int,
    max_num_reqs: int,
    max_model_len: int,
    active_min: int,
    active_max: int,
    seed: int,
) -> float:
    batch = build_batch(
        max_num_reqs=max_num_reqs,
        max_model_len=max_model_len,
        active_min=active_min,
        active_max=active_max,
        seed=seed,
    )
    swaps = make_swaps(max_num_reqs=max_num_reqs, num_swaps=num_swaps, seed=seed)

    t0 = time.perf_counter()
    for i1, i2 in swaps:
        batch.swap_states(i1, i2)
    t1 = time.perf_counter()
    return (t1 - t0) * 1000.0


def summarize(values: list[float]) -> tuple[float, float]:
    mean = statistics.fmean(values)
    median = statistics.median(values)
    return mean, median


def main() -> None:
    parser = argparse.ArgumentParser(
        description=(
            "Microbenchmark for current InputBatch.swap_states implementation. "
            "Run this script before/after stashing code changes to compare."
        )
    )
    parser.add_argument("--branch", default="current")
    parser.add_argument("--num-swaps", type=int, default=1000)
    parser.add_argument("--max-num-reqs", type=int, default=512)
    parser.add_argument("--max-model-len", type=int, default=32768)
    parser.add_argument("--active-min", type=int, default=512)
    parser.add_argument("--active-max", type=int, default=2048)
    parser.add_argument("--repeats", type=int, default=5)
    parser.add_argument("--seed", type=int, default=7)
    args = parser.parse_args()

    # Warmup to reduce first-run noise.
    _ = run_once(
        num_swaps=min(500, args.num_swaps),
        max_num_reqs=args.max_num_reqs,
        max_model_len=args.max_model_len,
        active_min=args.active_min,
        active_max=args.active_max,
        seed=args.seed - 1,
    )

    samples_ms = []
    for r in range(args.repeats):
        gc.collect()
        ms = run_once(
            num_swaps=args.num_swaps,
            max_num_reqs=args.max_num_reqs,
            max_model_len=args.max_model_len,
            active_min=args.active_min,
            active_max=args.active_max,
            seed=args.seed + r,
        )
        samples_ms.append(ms)
        print(f"branch={args.branch} run={r+1} total_ms={ms:.2f}")

    mean_ms, median_ms = summarize(samples_ms)
    print("---- summary ----")
    print(f"branch={args.branch} mean_ms={mean_ms:.2f} median_ms={median_ms:.2f}")


if __name__ == "__main__":
    main()

Seeing ~25ms saved with this limited benchmarking script

python bench_swap_states_micro.py --branch swap-states-active-prefix --repeats 3 --num-swaps 1000 --max-num-reqs 512 --max-model-len 32768 --active-min 512 --active-max 2048
branch=main run=1 total_ms=42.73
branch=main run=2 total_ms=41.92
branch=main run=3 total_ms=40.15
---- summary ----
branch=main mean_ms=41.60 median_ms=41.92
branch=swap-states-active-prefix run=1 total_ms=17.48
branch=swap-states-active-prefix run=2 total_ms=17.18
branch=swap-states-active-prefix run=3 total_ms=16.66
---- summary ----
branch=swap-states-active-prefix mean_ms=17.10 median_ms=17.18

lm_eval

lm_eval run --model vllm --model_args pretrained=Qwen/Qwen2-1.5B-Instruct,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8,enforce_eager=True,max_model_len=4096 --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 50

main

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.58|±  |0.0705|
|     |       |strict-match    |     5|exact_match|↑  | 0.58|±  |0.0705|

This PR

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.56|±  |0.0709|
|     |       |strict-match    |     5|exact_match|↑  | 0.56|±  |0.0709|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Philip Ottesen <phiott256@gmail.com>
@github-actions
Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors.

You ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the v1 label Feb 17, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

The pull request successfully optimizes the swap_states and condense methods in InputBatch by limiting the data movement to the active token prefix of each request. This change significantly reduces the overhead of reordering requests in the batch, especially when the maximum model length is large. The introduction of the _get_active_token_count helper method centralizes the logic for determining the active range of tokens, including speculative tokens. The performance benchmarks provided in the description confirm a substantial reduction in execution time for these operations. The implementation is correct and maintains consistency with the existing metadata management.

@LucasWilkinson LucasWilkinson added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 26, 2026
Copy link
Copy Markdown
Collaborator

@LucasWilkinson LucasWilkinson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution! running CI

@pjo256 pjo256 requested a review from njhill as a code owner March 2, 2026 23:48
@pjo256
Copy link
Copy Markdown
Contributor Author

pjo256 commented Mar 2, 2026

@LucasWilkinson Looks like we had some intermittent CI failures. Pulling in main

@njhill njhill merged commit 0091017 into vllm-project:main Mar 18, 2026
50 checks passed
ikaadil pushed a commit to ikaadil/vllm that referenced this pull request Mar 19, 2026
…llm-project#34733)

Signed-off-by: Philip Ottesen <phiott256@gmail.com>
Signed-off-by: Ifta Khairul Alam Adil <ikaadil007@gmail.com>
ikaadil pushed a commit to ikaadil/vllm that referenced this pull request Mar 19, 2026
…llm-project#34733)

Signed-off-by: Philip Ottesen <phiott256@gmail.com>
Signed-off-by: Ifta Khairul Alam Adil <ikaadil007@gmail.com>
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
SouthWest7 pushed a commit to SouthWest7/vllm that referenced this pull request Mar 27, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
…llm-project#34733)

Signed-off-by: Philip Ottesen <phiott256@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Mar 30, 2026
…llm-project#34733)

Signed-off-by: Philip Ottesen <phiott256@gmail.com>
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Performance]: Improve swap_states by swapping active token prefixes

3 participants