3 changes: 3 additions & 0 deletions tests/compile/test_fusions_e2e.py
@@ -132,6 +132,9 @@ def test_attn_quant(
mode = CUDAGraphMode.FULL_AND_PIECEWISE
splitting_ops: list[str] | None = None
else:
# FIXME: Llama-4-Scout-17B-16E-Instruct-FP8 + FlashInfer + Blackwell ends up at
# CUDAGraphMode.NONE here because it derives an attention backend that
# does not support full cudagraphs
mode = CUDAGraphMode.FULL_DECODE_ONLY
splitting_ops = []

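To make the FIXME concrete: with splitting_ops = [] there is no piecewise fallback, so if the derived backend reports that it never supports CUDA graphs, the resolved mode drops all the way to NONE. Below is a minimal, self-contained sketch of that downgrade logic; the enum members and the resolver are simplified stand-ins for illustration, not vLLM's actual implementation.

```python
# Hedged sketch only: simplified enums and a toy resolver, not vLLM's code.
from enum import Enum


class AttentionCGSupport(Enum):
    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    ALWAYS = 2


class CUDAGraphMode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL_DECODE_ONLY = 2
    FULL_AND_PIECEWISE = 3


def resolve_mode(
    requested: CUDAGraphMode,
    backend_support: AttentionCGSupport,
    attention_is_split: bool,
) -> CUDAGraphMode:
    """Downgrade the requested mode to what the backend can actually run."""
    if backend_support is AttentionCGSupport.NEVER:
        # Without any full-graph support, the only option left is piecewise
        # capture around attention; if attention is not split out, give up.
        return CUDAGraphMode.PIECEWISE if attention_is_split else CUDAGraphMode.NONE
    return requested


# The test's else branch: FULL_DECODE_ONLY requested, the backend reports no
# cudagraph support, and splitting_ops = [] (attention not split) -> NONE.
assert (
    resolve_mode(CUDAGraphMode.FULL_DECODE_ONLY, AttentionCGSupport.NEVER, False)
    is CUDAGraphMode.NONE
)
```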
48 changes: 30 additions & 18 deletions vllm/v1/worker/gpu_model_runner.py
@@ -3751,8 +3751,6 @@ def capture_model(self) -> int:
"ensure `cudagraph_mode` was not manually set to `NONE`"
)
return 0
else:
self.initialize_cudagraph_capture()

compilation_counter.num_gpu_runner_capture_triggers += 1

@@ -3926,7 +3924,7 @@ class AttentionGroupKey(NamedTuple):

def get_attn_backends_for_group(
kv_cache_group_spec: KVCacheGroupSpec,
) -> dict[AttentionGroupKey, list[str]]:
) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]:
layers = get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names
)
@@ -3955,7 +3953,10 @@ def get_attn_backends_for_group(
attn_backend, layer_kv_cache_spec
)
attn_backend_layers[key].append(layer_name)
return {attn_backends[k]: v for k, v in attn_backend_layers.items()}
return (
{attn_backends[k]: v for k, v in attn_backend_layers.items()},
set(group_key.attn_backend for group_key in attn_backends.values()),
)

def create_attn_groups(
attn_backends_map: dict[AttentionGroupKey, list[str]],
@@ -3976,28 +3977,39 @@ def create_attn_groups(
attn_groups.append(attn_group)
return attn_groups

attention_backend_maps = []
attention_backend_set: set[type[AttentionBackend]] = set()
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
self.attn_groups.append(create_attn_groups(attn_backends))
attention_backend_maps.append(attn_backends[0])
attention_backend_set.update(attn_backends[1])

# Resolve cudagraph_mode before actually initializing metadata_builders
self._check_and_update_cudagraph_mode(attention_backend_set)

for attn_backends_map in attention_backend_maps:
self.attn_groups.append(create_attn_groups(attn_backends_map))

# Calculate reorder batch threshold (if needed)
self.calculate_reorder_batch_threshold()

def initialize_cudagraph_capture(self) -> None:
def _check_and_update_cudagraph_mode(
self, attention_backends: set[type[AttentionBackend]]
) -> None:
"""
Resolve the cudagraph_mode when there are multiple attention
backends with potentially conflicting CUDA graph support.
Then initialize the cudagraph_dispatcher based on the resolved
cudagraph_mode.
"""
min_cg_support = AttentionCGSupport.ALWAYS
min_cg_builder_name = None
min_cg_backend_name = None

for attn_group in self._attn_group_iterator():
builder = attn_group.get_metadata_builder()
if builder.cudagraph_support.value < min_cg_support.value:
min_cg_support = builder.cudagraph_support
min_cg_builder_name = builder.__class__.__name__
for attn_backend in attention_backends:
builder_cls = attn_backend.get_builder_cls()
if builder_cls.cudagraph_support.value < min_cg_support.value:
min_cg_support = builder_cls.cudagraph_support
min_cg_backend_name = attn_backend.__name__
# Flexibly resolve the cudagraph mode
cudagraph_mode = self.compilation_config.cudagraph_mode
# check that cudagraph for mixed batches is supported
@@ -4007,7 +4019,7 @@ def initialize_cudagraph_capture(self) -> None:
):
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_builder_name} backend (support: "
f"with {min_cg_backend_name} backend (support: "
f"{min_cg_support})"
)
if min_cg_support == AttentionCGSupport.NEVER:
@@ -4038,7 +4050,7 @@
):
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_builder_name} backend (support: "
f"with {min_cg_backend_name} backend (support: "
f"{min_cg_support})"
)
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE and (
@@ -4072,7 +4084,7 @@
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported"
f" with spec-decode for attention backend "
f"{min_cg_builder_name} (support: {min_cg_support})"
f"{min_cg_backend_name} (support: {min_cg_support})"
)
if self.compilation_config.splitting_ops_contain_attention():
msg += "; setting cudagraph_mode=PIECEWISE"
@@ -4094,14 +4106,14 @@
):
raise ValueError(
f"CUDAGraphMode.{cudagraph_mode.name} is not "
f"supported with {min_cg_builder_name} backend ("
f"supported with {min_cg_backend_name} backend ("
f"support:{min_cg_support}) "
"; please try cudagraph_mode=PIECEWISE, "
"and make sure compilation mode is VLLM_COMPILE"
)

# Trigger cudagraph dispatching keys initialization here (after
# initializing attn backends).
# Trigger cudagraph dispatching keys initialization after the
# cudagraph mode is resolved.
self.cudagraph_dispatcher.initialize_cudagraph_keys(
self.compilation_config.cudagraph_mode, self.uniform_decode_query_len
)
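Taken together, the hunks above change the order of initialization: backend classes are collected per KV-cache group first, the weakest AttentionCGSupport across them decides how cudagraph_mode is downgraded, and only then are the attention groups (and their metadata builders) created. A rough, self-contained sketch of the min-over-backends step follows; the class names and the class-level cudagraph_support attribute are illustrative stand-ins (in the diff the flag comes from each backend's builder class via get_builder_cls()).

```python
# Hedged sketch only: toy backend classes standing in for vLLM's real ones.
from enum import Enum


class AttentionCGSupport(Enum):
    # Simplified support levels; a lower value means weaker CUDA graph support.
    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    ALWAYS = 2


class ToyFlashBackend:
    cudagraph_support = AttentionCGSupport.ALWAYS


class ToyTreeAttnBackend:
    cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE


def weakest_cudagraph_support(backends: set[type]) -> tuple[AttentionCGSupport, str]:
    """Mirror the min_cg_support / min_cg_backend_name bookkeeping above:
    scan the backend classes and keep the weakest support plus its owner."""
    weakest = AttentionCGSupport.ALWAYS
    owner = ""
    for backend in backends:
        support = backend.cudagraph_support
        if support.value < weakest.value:
            weakest = support
            owner = backend.__name__
    return weakest, owner


if __name__ == "__main__":
    # 1) Union of backend classes over all KV-cache groups (the new second
    #    element returned by get_attn_backends_for_group in this diff).
    backends: set[type] = {ToyFlashBackend, ToyTreeAttnBackend}

    # 2) Resolve the cudagraph mode from the weakest backend *before* any
    #    metadata builders exist, as _check_and_update_cudagraph_mode now does.
    support, name = weakest_cudagraph_support(backends)
    print(f"weakest cudagraph support: {support.name} (from {name})")

    # 3) Only after this would the attention groups be created.
```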