Commit 1928556

fix IMA
Signed-off-by: fhl2000 <[email protected]>
1 parent 9db6e4d commit 1928556

File tree

3 files changed: +19 additions, -15 deletions

vllm/v1/attention/backends/utils.py

Lines changed: 2 additions & 0 deletions

@@ -63,6 +63,7 @@ class CommonAttentionMetadata:
 
 M = TypeVar("M")
 
+
 class AttentionCGSupport(enum.Enum):
     """ Constants for the cudagraph support of the attention backend
     Here we do not consider the cascade attention, as currently
@@ -76,6 +77,7 @@ class AttentionCGSupport(enum.Enum):
     ALWAYS = 2
     """Cudagraph always supported"""
 
+
 class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention.
     attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
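For context, the enum and ClassVar touched by this hunk are how a backend advertises its CUDA graph support level. Below is a minimal, self-contained sketch (not taken from this commit) of how a metadata builder might declare PURE_DECODE_ONLY support; the subclass name and the NEVER default are illustrative assumptions.

import abc
import enum
from typing import ClassVar, Generic, TypeVar

M = TypeVar("M")


class AttentionCGSupport(enum.Enum):
    """Constants for the cudagraph support of the attention backend."""
    NEVER = 0
    PURE_DECODE_ONLY = 1
    ALWAYS = 2


class AttentionMetadataBuilder(abc.ABC, Generic[M]):
    # Assumed default: a backend opts in by overriding this ClassVar.
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.NEVER


class HypotheticalDecodeOnlyBuilder(AttentionMetadataBuilder[dict]):
    # Illustrative backend that only supports CUDA graphs for pure decode.
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.PURE_DECODE_ONLY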

vllm/v1/worker/gpu_model_runner.py

Lines changed: 14 additions & 15 deletions

@@ -2369,14 +2369,6 @@ def capture_model(self) -> None:
         # can reuse the memory pool allocated for the large shapes.
         with graph_capture(device=self.device):
             full_cg = self.full_cuda_graph
-            # for full cg on pure decode only, do not capture size lager than
-            # max_num_seqs
-            if full_cg and self.attn_metadata_builders[0].attn_cudagraph_support\
-                == AttentionCGSupport.PURE_DECODE_ONLY:
-                max_num_seqs = self.scheduler_config.max_num_seqs
-                self.cudagraph_batch_sizes = [
-                    size for size in self.cudagraph_batch_sizes
-                    if size <= max_num_seqs]
 
             # Only rank 0 should print progress bar during capture
             compilation_cases = reversed(self.cudagraph_batch_sizes)
@@ -2446,13 +2438,20 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
                 self.device,
             )
 
-            if (self.full_cuda_graph
-                    and attn_metadata_builder_i.attn_cudagraph_support == \
-                    AttentionCGSupport.NEVER):
-                raise ValueError(
-                    f"Full CUDAGraph not supported for "
-                    f"{attn_backend_i.__name__}. Turn off CompilationConfig."
-                    f"full_cuda_graph or use a different attention backend.")
+            if self.full_cuda_graph:
+                if attn_metadata_builder_i.attn_cudagraph_support == \
+                        AttentionCGSupport.NEVER:
+                    raise ValueError(
+                        f"Full CUDAGraph not supported for "
+                        f"{attn_backend_i.__name__}. Turn off "
+                        f"CompilationConfig.full_cuda_graph or use a "
+                        f"different attention backend.")
+                if attn_metadata_builder_i.attn_cudagraph_support == \
+                        AttentionCGSupport.PURE_DECODE_ONLY:
+                    self.cudagraph_batch_sizes = [
+                        size for size in self.cudagraph_batch_sizes
+                        if size <= self.scheduler_config.max_num_seqs
+                    ]
 
             self.attn_backends.append(attn_backend_i)
             self.attn_metadata_builders.append(attn_metadata_builder_i)
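The net effect of this hunk is that the PURE_DECODE_ONLY clamp of the capture sizes now happens in initialize_attn_backend, before any dummy run or graph capture, instead of inside capture_model. A minimal sketch of that filtering step follows; the standalone function name is an illustrative stand-in, not vLLM's API.

from typing import List


def clamp_cudagraph_batch_sizes(batch_sizes: List[int],
                                max_num_seqs: int) -> List[int]:
    # For a pure-decode-only backend, a captured batch never contains more
    # tokens than there are requests, so sizes above max_num_seqs are dropped.
    return [size for size in batch_sizes if size <= max_num_seqs]


# Example: with max_num_seqs = 128, the larger capture sizes are removed.
assert clamp_cudagraph_batch_sizes(
    [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], 128) == \
    [1, 2, 4, 8, 16, 32, 64, 128]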

vllm/v1/worker/gpu_worker.py

Lines changed: 3 additions & 0 deletions

@@ -292,9 +292,12 @@ def compile_or_warm_up_model(self) -> None:
                                self.scheduler_config.max_num_batched_tokens)
 
             # We skip EPLB here since we don't want to record dummy metrics
+            # Always activate creating attn_cudagraphs for dummy run to avoid
+            # illegal memory access for full cudagraph.
             hidden_states, last_hidden_states = \
                 self.model_runner._dummy_run(
                     num_tokens=max_num_reqs,
+                    capture_attn_cudagraph=True,
                     skip_eplb=True,
                 )
             if self.model_runner.is_pooling_model:
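A hedged sketch of the warm-up flow this hunk changes: the dummy run now also builds the attention state used for CUDA graph capture, which is what avoids the illegal memory access ("IMA") in the commit title. The ModelRunnerLike protocol and warm_up helper below are illustrative stand-ins, not vLLM's API; only the keyword arguments mirror the diff.

from typing import Any, Protocol, Tuple


class ModelRunnerLike(Protocol):
    def _dummy_run(self, num_tokens: int, capture_attn_cudagraph: bool = False,
                   skip_eplb: bool = False) -> Tuple[Any, Any]:
        ...


def warm_up(runner: ModelRunnerLike, max_num_seqs: int,
            max_num_batched_tokens: int) -> None:
    max_num_reqs = min(max_num_seqs, max_num_batched_tokens)
    # Always build the attention cudagraph state during the dummy run so a
    # later full-cudagraph capture replays buffers that were actually set up.
    runner._dummy_run(num_tokens=max_num_reqs,
                      capture_attn_cudagraph=True,
                      skip_eplb=True)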
