
[Fix] Misc Fixes in ViT CUDA Graph#38040

Open
b-mu wants to merge 5 commits into vllm-project:main from CentML:bmu/vit-full-cudagraph-fix-auto-infer-invariant

Conversation

Contributor

@b-mu commented Mar 24, 2026

Purpose

  1. Previously, the auto-inferred max_batch_size = max_budget // min_budget could exceed min_budget, causing prepare_encoder_cudagraph_capture_inputs to compute per_image_output = token_budget // max_batch_size = 0 for small budgets, which led to a reshape crash on empty tensors in Qwen3_VisionPatchEmbed.forward. Fixed by capping max_batch_size to min(max_budget // min_budget, min_budget) when both the budgets and the max batch size are auto-inferred. For the paths where either the budgets or the max batch size is provided by the user, we adjust the other (i.e., the auto-inferred one) to satisfy the invariant max_batch_size <= min_budget (see the sketch after this list).
  2. Use ceiling division when sizing the capture input buffer, so the buffer is not under-allocated when token_budget is not divisible by max_batch_size.
  3. Disable CUDA graph for all modalities, including images, when EVS pruning is enabled in Qwen3-VL.
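
A minimal sketch of the capping in item 1, using illustrative numbers (the real logic lives in EncoderCudaGraphManager and may differ in detail):

min_budget, max_budget = 64, 16384               # example model-provided range
max_batch_size = max_budget // min_budget        # 256: old behavior, exceeds min_budget
max_batch_size = min(max_batch_size, min_budget) # capped to 64

for token_budget in (64, 128, 16384):
    per_image_output = token_budget // max_batch_size
    assert per_image_output >= 1  # no empty-tensor reshape during graph capture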

Test Plan

4×GB200 NVLink (TP=4, ViT DP=4), Qwen3-VL-32B-Instruct, random-mm dataset (synthetic)

  • Eager (baseline):
  vllm bench mm-processor \
    --model Qwen/Qwen3-VL-32B-Instruct \
    --dataset-name random-mm \
    --random-mm-base-items-per-request 20 \
    --random-mm-num-mm-items-range-ratio 0.5 \
    --random-mm-bucket-config '{"(224,224,1)": 0.2, "(336,336,1)": 0.3, "(448,448,1)": 0.2, "(672,672,1)": 0.2, "(1008,1008,1)": 0.1}' \
    --num-prompts 1000 --num-warmups 200 \
    --max-model-len 16384 --dtype bfloat16 --seed 42 \
    --mm-encoder-attn-backend FLASH_ATTN \
    --tensor-parallel-size 4 --mm-encoder-tp-mode data
  • ViT full CUDA Graph:
  vllm bench mm-processor \
    --model Qwen/Qwen3-VL-32B-Instruct \
    --dataset-name random-mm \
    --random-mm-base-items-per-request 20 \
    --random-mm-num-mm-items-range-ratio 0.5 \
    --random-mm-bucket-config '{"(224,224,1)": 0.2, "(336,336,1)": 0.3, "(448,448,1)": 0.2, "(672,672,1)": 0.2, "(1008,1008,1)": 0.1}' \
    --num-prompts 1000 --num-warmups 200 \
    --max-model-len 16384 --dtype bfloat16 --seed 42 \
    --mm-encoder-attn-backend FLASH_ATTN \
    --tensor-parallel-size 4 --mm-encoder-tp-mode data \
    --compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_max_frames_per_batch": 1}'

Test Result

Encoder Forward Latency (mean):

Config               FA4 BF16           FLASHINFER FP8
Eager (baseline)     38.4 ms            39.5 ms
ViT full CUDA Graph  29.7 ms (+22.7%)   31.5 ms (+20.2%)

Encoder Forward Latency (median):

Config               FA4 BF16           FLASHINFER FP8
Eager (baseline)     25.4 ms            29.8 ms
ViT full CUDA Graph  23.9 ms (+5.8%)    25.3 ms (+15.1%)

Encoder Forward Latency (P99):

Config               FA4 BF16           FLASHINFER FP8
Eager (baseline)     194.0 ms           244.4 ms
ViT full CUDA Graph  155.6 ms (+19.8%)  140.3 ms (+42.6%)

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist Bot left a comment

Code Review

This pull request refactors the initialization logic for EncoderCudaGraphManager to enforce an invariant that max_batch_size must be less than or equal to the smallest token budget. This prevents potential reshape crashes during CUDA graph capture. The changes include adding validation for user-specified configurations and ensuring model-provided budget ranges are valid (positive and min <= max). New tests have been added to cover various configuration paths and validate these invariants. However, the review highlights a critical issue: user-provided encoder_cudagraph_token_budgets are not validated for positivity, which could lead to a ZeroDivisionError if a non-positive budget is supplied.

Comment on lines 78 to +80
self.token_budgets = sorted(user_budgets)
self.max_batch_size = user_max_images
min_tok = min(self.token_budgets)
Contributor

critical

User-provided encoder_cudagraph_token_budgets are not validated for positivity. If a user provides a non-positive budget (e.g., [0, 128]), min(self.token_budgets) could be zero or negative. This can lead to self.max_batch_size being set to zero, which will cause a ZeroDivisionError later during CUDA graph capture preparation.

You should add validation to ensure all user-provided budgets are positive. A similar check is needed in the elif user_budgets: block.

            self.token_budgets = sorted(user_budgets)
            if self.token_budgets[0] <= 0:
                raise ValueError(
                    f"Invalid encoder_cudagraph_token_budgets: {user_budgets}. "
                    "All budget values must be positive."
                )
            self.max_batch_size = user_max_images
            min_tok = self.token_budgets[0]

Collaborator

This is a fair point, but I feel that the better way to handle this is through pydantic (since ultimately this is an input validation problem), e.g., in the definition of CompilationConfig:

from pydantic.types import PositiveInt

@config
class CompilationConfig:
    ...
    encoder_cudagraph_token_budgets: list[PositiveInt]
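
For reference, this is how pydantic's PositiveInt behaves on a standalone model (BudgetDemo is a hypothetical stand-in, not vLLM's actual config; shown only to illustrate that non-positive values are rejected at parse time):

from pydantic import BaseModel, ValidationError
from pydantic.types import PositiveInt

class BudgetDemo(BaseModel):  # hypothetical demo class
    token_budgets: list[PositiveInt]

BudgetDemo(token_budgets=[64, 128])      # accepted
try:
    BudgetDemo(token_budgets=[0, 128])   # rejected: 0 is not positive
except ValidationError as exc:
    print(exc)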

Contributor Author

I have added the check in vllm/config/compilation.py.
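
A hedged sketch of what that early check might look like inside CompilationConfig.__post_init__ (the field name is from this PR; the exact error message is illustrative):

def __post_init__(self) -> None:
    ...
    if self.encoder_cudagraph_token_budgets is not None and any(
        budget <= 0 for budget in self.encoder_cudagraph_token_budgets
    ):
        raise ValueError(
            "encoder_cudagraph_token_budgets must contain only positive "
            f"values, got {self.encoder_cudagraph_token_budgets}"
        )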

Comment on lines +118 to +122
self.token_budgets = sorted(user_budgets)
self.max_batch_size = min(
    max_budget // min_budget,
    min(self.token_budgets),
)
Contributor

critical

Similar to the other block, user-provided encoder_cudagraph_token_budgets are not validated for positivity here. This can lead to a ZeroDivisionError if a non-positive budget is provided.

                self.token_budgets = sorted(user_budgets)
                if self.token_budgets[0] <= 0:
                    raise ValueError(
                        f"Invalid encoder_cudagraph_token_budgets: {user_budgets}. "
                        "All budget values must be positive."
                    )
                self.max_batch_size = min(
                    max_budget // min_budget,
                    self.token_budgets[0],
                )

Contributor Author

This has been addressed by checking in vllm/config/compilation.py.

@wangshangsam added the qwen (Related to Qwen models), performance (Performance-related issues), and multi-modality (Related to multi-modality, #4194) labels on Mar 25, 2026
@github-project-automation Bot moved this to Backlog in Qwen3.5 on Mar 25, 2026
@wangshangsam moved this from Backlog to In progress in Qwen3.5 on Mar 25, 2026
Contributor

mergify Bot commented Mar 26, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @b-mu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify Bot added the needs-rebase label on Mar 26, 2026
@b-mu force-pushed the bmu/vit-full-cudagraph-fix-auto-infer-invariant branch from db7c2fd to 8a2c74f on March 27, 2026 03:07
b-mu added a commit to CentML/vllm that referenced this pull request Mar 27, 2026
Move non-positive budget validation into CompilationConfig.__post_init__
so invalid values are caught early during config parsing rather than at
runtime in EncoderCudaGraphManager. Addresses PR vllm-project#38040 review feedback.

Signed-off-by: Baorun Mu <bmu@nvidia.com>
@mergify Bot removed the needs-rebase label on Mar 27, 2026
@b-mu changed the title from "[Draft] [Fix] Invariant Check for Auto-Inferred Budgets/Max Batch Size in ViT CUDA Graph Manager" to "[Fix] Invariant Check for Auto-Inferred Budgets/Max Batch Size in ViT CUDA Graph Manager" on Mar 27, 2026
@github-project-automation Bot moved this to Ready in NVIDIA on Mar 27, 2026
@Isotr0py enabled auto-merge (squash) on March 27, 2026 16:20
@github-actions Bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Mar 27, 2026
b-mu added 5 commits May 5, 2026 11:18
When auto-inferring max_batch_size as max_budget // min_budget, the
result can exceed min_budget (e.g. 16384 // 64 = 256 > 64). During
graph capture, prepare_encoder_cudagraph_capture_inputs divides
token_budget by max_batch_size to get per_image_output, which yields 0
for small budgets (64 // 256 = 0), causing a reshape crash on empty
tensors.

Fix by capping max_batch_size to min(max_budget // min_budget,
min_budget), ensuring per_image_output >= 1 for all budgets.

Signed-off-by: Baorun Mu <bmu@nvidia.com>
… graphs

The previous fix (bbb31043b) only capped auto-inferred max_batch_size,
but several configuration paths still violated the invariant
max_batch_size <= min(token_budgets), causing per_image_output = 0
and reshape crashes during CUDA graph capture:

1. Fully user-specified: no validation at all
2. User provides only max_images: value used directly without cap
3. User provides only budgets: auto-inferred cap used model's
   min_budget instead of min(user_budgets)

Fix by handling each configuration path explicitly:
- User-specified both: validate and raise informative ValueError
- User max_images only: adjust budget generation to start at
  max(min_budget, user_max_images)
- User budgets only: cap max_batch_size to min(user_budgets)
- Fully auto-inferred: cap max_batch_size to min(generated_budgets)

Also validate model-returned budget range (positive, min <= max).

Signed-off-by: Baorun Mu <bmu@nvidia.com>
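
A condensed, runnable sketch of the four configuration paths described above (illustrative: generate_budgets is a hypothetical placeholder for the manager's budget-generation step, and the real __init__ carries more context):

def resolve_capture_config(
    user_budgets: list[int] | None,
    user_max_images: int | None,
    min_budget: int,
    max_budget: int,
) -> tuple[list[int], int]:
    """Return (token_budgets, max_batch_size) with max_batch_size <= min(token_budgets)."""

    def generate_budgets(lo: int, hi: int) -> list[int]:
        # Hypothetical stand-in: powers of two from lo up to hi.
        budgets, b = [], lo
        while b < hi:
            budgets.append(b)
            b *= 2
        budgets.append(hi)
        return budgets

    if user_budgets and user_max_images:
        # User-specified both: validate and raise an informative error.
        if user_max_images > min(user_budgets):
            raise ValueError(
                f"max images per batch ({user_max_images}) must not exceed "
                f"the smallest token budget ({min(user_budgets)})"
            )
        return sorted(user_budgets), user_max_images
    if user_max_images:
        # User max_images only: start budget generation high enough.
        budgets = generate_budgets(max(min_budget, user_max_images), max_budget)
        return budgets, user_max_images
    if user_budgets:
        # User budgets only: cap to min(user_budgets).
        budgets = sorted(user_budgets)
        return budgets, min(max_budget // min_budget, budgets[0])
    # Fully auto-inferred: cap to min(generated_budgets).
    budgets = generate_budgets(min_budget, max_budget)
    return budgets, min(max_budget // min_budget, budgets[0])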
Move non-positive budget validation into CompilationConfig.__post_init__
so invalid values are caught early during config parsing rather than at
runtime in EncoderCudaGraphManager. Addresses PR vllm-project#38040 review feedback.

Signed-off-by: Baorun Mu <bmu@nvidia.com>
Post-rebase fallout from upstream commit 936e0b7
(`[MM][CG] Optimize default max_frames_per_batch auto-infer`):

- Mock fixtures: add `_MockMultimodalConfig.get_limit_per_prompt`
  returning 0 for "video" so the new max_frames_per_batch branch
  short-circuits before the model-side `get_max_frames_per_video`
  call (image-only mocks don't need video-frame plumbing). Also
  switch `encoder_cudagraph_max_frames_per_batch` default in the
  mock to None to match upstream's `int | None = None` field type.
- ruff-format reflow in EncoderCudaGraphManager.__init__ where the
  conflict resolution had wrapped a single-line call.

Signed-off-by: Baorun Mu <bmu@nvidia.com>
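
A rough sketch of the mock fixture change described above (class and method names follow the commit message; the non-video return value is an assumption for image-only tests):

class _MockMultimodalConfig:
    # Match upstream's `int | None = None` field type.
    encoder_cudagraph_max_frames_per_batch: int | None = None

    def get_limit_per_prompt(self, modality: str) -> int:
        # Report 0 videos so the max_frames_per_batch auto-infer branch
        # short-circuits before calling get_max_frames_per_video.
        return 0 if modality == "video" else 1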
Two correctness fixes on top of the existing Qwen3-VL ViT CUDA graph
implementation, surfaced during review of vllm-project#40830 (qwen2.5-vl):

1. Use ceil (not floor) for per_mm_item_output in
   prepare_encoder_cudagraph_capture_inputs. Floor sizes the captured
   pixel_values buffer at max_batch_size * (token_budget //
   max_batch_size), which is < token_budget whenever the budget is
   not divisible by max_batch_size. Replay copy for the worst-case
   single-item batch then raises a shape mismatch. The non-divisible
   case is reachable in every config path: the manager enforces
   max_batch_size <= min(budgets) but not divisibility. Mirrors the
   fix in vllm-project#40830 for qwen2.5-vl.

2. Disable ViT CUDA graph for ALL modalities (not just video) when
   EVS pruning is enabled. embed_multimodal post-processes both image
   and video embeddings (mrope positions appended for image,
   prune+append for video) when is_multimodal_pruning_enabled is
   True. The encoder CUDA graph path bypasses that hook, producing
   inconsistent embedding formats vs eager. Matches the pattern in
   qwen2.5-vl PR vllm-project#40830:
   `modalities = [] if pruning else ["image", "video"]`.

Signed-off-by: Baorun Mu <bmu@nvidia.com>
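
Illustrative numbers for fix 1 (ceil vs. floor when sizing the captured buffer; the actual code computes per_mm_item_output inside prepare_encoder_cudagraph_capture_inputs):

token_budget, max_batch_size = 100, 8

floor_alloc = max_batch_size * (token_budget // max_batch_size)     # 8 * 12 = 96 < 100
ceil_alloc = max_batch_size * (-(-token_budget // max_batch_size))  # 8 * 13 = 104 >= 100

# A worst-case single-item batch filling the whole 100-token budget fits in the
# ceil-sized buffer but overflows the floor-sized one (shape mismatch on replay).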
auto-merge was automatically disabled May 6, 2026 04:28

Head branch was pushed to by a user without write access

@b-mu force-pushed the bmu/vit-full-cudagraph-fix-auto-infer-invariant branch from 8a2c74f to b3b32b4 on May 6, 2026 04:28
@b-mu changed the title from "[Fix] Invariant Check for Auto-Inferred Budgets/Max Batch Size in ViT CUDA Graph Manager" to "[Fix] Misc Fixes in ViT CUDA Graph" on May 6, 2026
Contributor Author

b-mu commented May 6, 2026

@Isotr0py I have rebased this branch, and added two other minor fixes for bugs found when I was reviewing #40830. Could you take another look?

Member

@Isotr0py left a comment

LGTM, requested force merge.
