172 changes: 172 additions & 0 deletions tests/v1/cudagraph/test_encoder_cudagraph.py
@@ -32,6 +32,68 @@
# ---------------------------------------------------------------------------


class _MockCompilationConfig:
"""Minimal mock for VllmConfig.compilation_config."""

def __init__(
self,
token_budgets: list[int] | None = None,
max_mm_items: int = 0,
):
self.encoder_cudagraph_token_budgets = token_budgets or []
self.encoder_cudagraph_max_vision_items_per_batch = max_mm_items
self.encoder_cudagraph_max_frames_per_batch = None


class _MockMultimodalConfig:
mm_encoder_tp_mode = "replicate"

def get_limit_per_prompt(self, modality: str) -> int:
# Image-only mocks — return 0 for "video" to short-circuit the
# max_frames_per_batch branch, so tests don't need a video-frame mock.
return 0


class _MockModelConfig:
multimodal_config = _MockMultimodalConfig()


class _MockParallelConfig:
tensor_parallel_size = 1


class _MockVllmConfig:
"""Minimal mock for VllmConfig used in __init__ tests."""

def __init__(
self,
token_budgets: list[int] | None = None,
max_mm_items: int = 0,
):
self.compilation_config = _MockCompilationConfig(token_budgets, max_mm_items)
self.model_config = _MockModelConfig()
self.parallel_config = _MockParallelConfig()


class _MockModel:
"""Minimal mock implementing SupportsEncoderCudaGraph for __init__."""

def __init__(self, min_budget: int = 4, max_budget: int = 128):
self._min_budget = min_budget
self._max_budget = max_budget

def get_encoder_cudagraph_config(self) -> EncoderCudaGraphConfig:
return EncoderCudaGraphConfig(
modalities=["image"],
input_key_by_modality={"image": "pixel_values"},
buffer_keys=["dummy_buf"],
out_hidden_size=32,
)

def get_encoder_cudagraph_budget_range(self, vllm_config):
return (self._min_budget, self._max_budget)


def _make_manager_with_budgets(budgets: list[int]) -> EncoderCudaGraphManager:
"""Create a minimal EncoderCudaGraphManager with only token_budgets set.

@@ -760,3 +822,113 @@ def test_image_and_video_share_manager(self):
assert len(vid_result) == 1
assert img_result[0].shape == (4, _HIDDEN)
assert vid_result[0].shape == (8, _HIDDEN)


# ---------------------------------------------------------------------------
# __init__ invariant validation tests (no GPU required)
# ---------------------------------------------------------------------------


class TestInitInvariantValidation:
"""Ensure max_batch_size <= min(token_budgets) for all config paths."""

def _make_mgr(
self,
token_budgets=None,
max_mm_items=0,
min_budget=4,
max_budget=128,
):
vllm_config = _MockVllmConfig(token_budgets, max_mm_items)
model = _MockModel(min_budget, max_budget)
return EncoderCudaGraphManager(
vllm_config=vllm_config,
device=torch.device("cpu"),
dtype=torch.float32,
model=model,
)

# --- Finding 1: fully auto-inferred ---

def test_auto_inferred_invariant_holds(self):
mgr = self._make_mgr(min_budget=64, max_budget=16384)
assert mgr.max_batch_size <= min(mgr.token_budgets)

def test_auto_inferred_small_range(self):
mgr = self._make_mgr(min_budget=4, max_budget=128)
assert mgr.max_batch_size <= min(mgr.token_budgets)

# --- Finding 2: fully user-specified, bad combo ---

def test_user_specified_bad_combo_raises(self):
with pytest.raises(ValueError, match="must be <= smallest token budget"):
self._make_mgr(token_budgets=[64], max_mm_items=256)

def test_user_specified_valid_combo(self):
mgr = self._make_mgr(token_budgets=[64, 128], max_mm_items=32)
assert mgr.max_batch_size == 32
assert mgr.token_budgets == [64, 128]

def test_user_specified_exact_boundary(self):
# max_mm_items == min(budgets) is OK (per_image_output = 1)
mgr = self._make_mgr(token_budgets=[64, 128], max_mm_items=64)
assert mgr.max_batch_size == 64

# --- Finding 3: user provides only max_mm_items ---

def test_user_max_mm_items_only_adjusts_budgets(self):
# model min_budget=64, user max_mm_items=128 → budgets start at 128
mgr = self._make_mgr(max_mm_items=128, min_budget=64, max_budget=16384)
assert mgr.max_batch_size == 128
assert min(mgr.token_budgets) >= 128

def test_user_max_mm_items_smaller_than_min_budget(self):
# max_mm_items=2, model min=4 → budgets start at 4 (>= 2), OK
mgr = self._make_mgr(max_mm_items=2, min_budget=4, max_budget=128)
assert mgr.max_batch_size == 2
assert min(mgr.token_budgets) >= 2

# --- Finding 4: user provides only budgets ---

def test_user_budgets_only_caps_max_batch_size(self):
# model min=64, max=16384 → max_budget // min_budget = 256
# user budgets=[32, 64] → min(budgets) = 32
# without the cap: max_batch_size = min(256, 64) = 64 > 32 → BUG
# with the cap: max_batch_size = min(256, 32) = 32 → OK
mgr = self._make_mgr(token_budgets=[32, 64], min_budget=64, max_budget=16384)
assert mgr.max_batch_size <= min(mgr.token_budgets)
assert mgr.max_batch_size == 32

# --- Finding 5/6: bad model budget range ---

def test_zero_min_budget_raises(self):
with pytest.raises(ValueError, match="Both must be positive"):
self._make_mgr(min_budget=0, max_budget=128)

def test_negative_max_budget_raises(self):
with pytest.raises(ValueError, match="Both must be positive"):
self._make_mgr(min_budget=4, max_budget=-1)

def test_min_greater_than_max_raises(self):
with pytest.raises(ValueError, match="min_budget=200 > max_budget=100"):
self._make_mgr(min_budget=200, max_budget=100)

# --- Finding 7: user-provided budgets with non-positive values ---

def test_user_budgets_zero_raises(self):
"""Non-positive budgets should be caught at config validation."""
from vllm.config.compilation import CompilationConfig

with pytest.raises(ValueError, match="must be positive"):
CompilationConfig(encoder_cudagraph_token_budgets=[0, 128])

def test_user_budgets_negative_raises(self):
from vllm.config.compilation import CompilationConfig

with pytest.raises(ValueError, match="must be positive"):
CompilationConfig(encoder_cudagraph_token_budgets=[-1, 64])
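
These __init__ invariant tests run entirely on CPU (the manager is constructed with device=torch.device("cpu")), so no GPU is required. Assuming a vLLM development checkout with pytest installed, just this class can be run with:

    pytest tests/v1/cudagraph/test_encoder_cudagraph.py::TestInitInvariantValidation -q
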
8 changes: 8 additions & 0 deletions vllm/config/compilation.py
@@ -1005,6 +1005,14 @@ def __post_init__(self) -> None:
"non-negative (None = auto-infer)"
)

if self.encoder_cudagraph_token_budgets and any(
b <= 0 for b in self.encoder_cudagraph_token_budgets
):
raise ValueError(
f"All encoder_cudagraph_token_budgets must be positive, "
f"got {self.encoder_cudagraph_token_budgets}"
)

if self.backend == "":
self.backend = current_platform.get_compile_backend()

19 changes: 10 additions & 9 deletions vllm/model_executor/models/qwen3_vl.py
@@ -1768,14 +1768,12 @@ def get_encoder_cudagraph_config(self):
EncoderCudaGraphConfig,
)

modalities = ["image"]
# NOTE: When EVS (Efficient Video Sampling) pruning is enabled, the number
# of tokens becomes data-dependent (i.e., the retained tokens are
# dynamically selected based on inter-frame differences) and therefore
# cannot be captured by CUDA Graphs. As a result, video CUDA Graphs are
# only enabled when EVS is disabled.
if not self.is_multimodal_pruning_enabled:
modalities.append("video")
# When EVS pruning is enabled, embed_multimodal post-processes both
# image and video embeddings (mrope positions are appended for image,
# prune+append for video). The encoder CUDA graph path bypasses that
# post-process, producing inconsistent embedding formats vs eager. So
# disable CUDA graph for all modalities when pruning is on.
modalities = [] if self.is_multimodal_pruning_enabled else ["image", "video"]

return EncoderCudaGraphConfig(
modalities=modalities,
@@ -1923,7 +1921,10 @@ def prepare_encoder_cudagraph_capture_inputs(
)

spatial_merge_size = self.visual.spatial_merge_size
per_mm_item_output = token_budget // max_batch_size
# Ceil so the buffer fits the worst case of one item using the full
# budget. Floor under-allocates when budget is not a multiple of
# max_batch_size.
per_mm_item_output = (token_budget + max_batch_size - 1) // max_batch_size

frames_per_item = max_frames_per_batch // max_batch_size
if frames_per_item > 1:
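
A quick numeric check of the ceiling division above, in plain Python with illustrative values (not tied to any real model configuration):

    token_budget, max_batch_size = 100, 8

    floor_per_item = token_budget // max_batch_size                        # 12
    ceil_per_item = (token_budget + max_batch_size - 1) // max_batch_size  # 13

    # With floor, the per-item slices total 12 * 8 = 96 < 100, so a batch
    # filling the whole budget would not fit; ceiling gives 13 * 8 = 104 >= 100.
    assert ceil_per_item * max_batch_size >= token_budget
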
64 changes: 52 additions & 12 deletions vllm/v1/worker/encoder_cudagraph.py
@@ -72,25 +72,65 @@ def __init__(

multimodal_config = vllm_config.model_config.multimodal_config

# Invariant: max_batch_size <= min_token_budget.
# This ensures per_image_output = budget // max_batch_size >= 1
# for every captured budget, preventing reshape crashes on empty
# tensors during CUDA graph capture. Validated/enforced below for
# each configuration path.
if user_budgets and user_max_vision_items > 0:
# Fully user-specified
# Fully user-specified: validate the invariant.
self.token_budgets = sorted(user_budgets)
self.max_batch_size = user_max_vision_items
min_tok = min(self.token_budgets)
Comment on lines 82 to +84
Contributor


critical

User-provided encoder_cudagraph_token_budgets are not validated for positivity. If a user provides a non-positive budget (e.g., [0, 128]), min(self.token_budgets) could be zero or negative. This can lead to self.max_batch_size being set to zero, which will cause a ZeroDivisionError later during CUDA graph capture preparation.

You should add validation to ensure all user-provided budgets are positive. A similar check is needed in the elif user_budgets: block.

            self.token_budgets = sorted(user_budgets)
            if self.token_budgets[0] <= 0:
                raise ValueError(
                    f"Invalid encoder_cudagraph_token_budgets: {user_budgets}. "
                    "All budget values must be positive."
                )
            self.max_batch_size = user_max_images
            min_tok = self.token_budgets[0]

Collaborator


This is a fair point, but I feel that the better way to handle this is through pydantic (since ultimately this is an input validation problem), e.g., in the definition of CompilationConfig:

from pydantic.types import PositiveInt

@config
class CompilationConfig:
    ...
    encoder_cudagraph_token_budgets: list[PositiveInt]

Contributor Author


I have added the check in vllm/config/compilation.py.
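
For reference, a minimal standalone sketch of the pydantic route (using a hypothetical _BudgetConfig model, not vLLM's actual CompilationConfig) showing the declarative check firing:

    from pydantic import BaseModel, ValidationError
    from pydantic.types import PositiveInt

    class _BudgetConfig(BaseModel):
        # Every entry must be > 0; pydantic enforces this at construction time.
        token_budgets: list[PositiveInt] = []

    try:
        _BudgetConfig(token_budgets=[0, 128])
    except ValidationError as exc:
        print(exc)  # reports that token_budgets[0] must be greater than 0
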

if self.max_batch_size > min_tok:
raise ValueError(
f"encoder_cudagraph_max_vision_items_per_batch "
f"({self.max_batch_size}) must be <= smallest token "
f"budget ({min_tok}). With budgets="
f"{self.token_budgets}, per_image_output = "
f"{min_tok} // {self.max_batch_size} = "
f"{min_tok // self.max_batch_size}, which would cause "
f"a capture failure. Either increase the smallest "
f"budget or decrease max_vision_items_per_batch."
)
else:
# Auto-infer missing values from model
# Auto-infer missing values from model.
min_budget, max_budget = model.get_encoder_cudagraph_budget_range(
vllm_config
)
self.token_budgets = (
sorted(user_budgets)
if user_budgets
else self._generate_budgets(min_budget, max_budget)
)
self.max_batch_size = (
user_max_vision_items
if user_max_vision_items > 0
else max_budget // min_budget
)
if min_budget <= 0 or max_budget <= 0:
raise ValueError(
f"Invalid encoder cudagraph budget range: "
f"min_budget={min_budget}, max_budget={max_budget}. "
f"Both must be positive."
)
if min_budget > max_budget:
raise ValueError(
f"Invalid encoder cudagraph budget range: "
f"min_budget={min_budget} > max_budget={max_budget}."
)

if user_max_vision_items > 0:
# User provided max_vision_items only; adjust auto-inferred
# budgets so min(budgets) >= max_batch_size.
self.max_batch_size = user_max_vision_items
effective_min = max(min_budget, user_max_vision_items)
self.token_budgets = self._generate_budgets(effective_min, max_budget)
elif user_budgets:
# User provided budgets only; cap auto-inferred
# max_batch_size to min(user_budgets).
self.token_budgets = sorted(user_budgets)
self.max_batch_size = min(
max_budget // min_budget,
min(self.token_budgets),
)
Comment on lines +122 to +126
Contributor


critical

Similar to the other block, user-provided encoder_cudagraph_token_budgets are not validated for positivity here. This can lead to a ZeroDivisionError if a non-positive budget is provided.

                self.token_budgets = sorted(user_budgets)
                if self.token_budgets[0] <= 0:
                    raise ValueError(
                        f"Invalid encoder_cudagraph_token_budgets: {user_budgets}. "
                        "All budget values must be positive."
                    )
                self.max_batch_size = min(
                    max_budget // min_budget,
                    self.token_budgets[0],
                )

Contributor Author


This has been addressed by the check in vllm/config/compilation.py.
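
Worked numbers for this branch, matching test_user_budgets_only_caps_max_batch_size in the test file above (a plain-Python sketch of the arithmetic, not the vLLM API):

    min_budget, max_budget = 64, 16384   # model-reported budget range
    user_budgets = [32, 64]              # user-provided token budgets

    max_batch_size = min(max_budget // min_budget, min(user_budgets))
    assert max_batch_size == 32          # capped by min(user_budgets), not 256
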

else:
# Fully auto-inferred.
self.token_budgets = self._generate_budgets(min_budget, max_budget)
self.max_batch_size = min(
max_budget // min_budget,
min(self.token_budgets),
)

assert multimodal_config is not None
if multimodal_config.get_limit_per_prompt("video") == 0: