@@ -2369,14 +2369,6 @@ def capture_model(self) -> None:
23692369 # can reuse the memory pool allocated for the large shapes.
23702370 with graph_capture (device = self .device ):
23712371 full_cg = self .full_cuda_graph
2372- # for full cg on pure decode only, do not capture size lager than
2373- # max_num_seqs
2374- if full_cg and self .attn_metadata_builders [0 ].attn_cudagraph_support \
2375- == AttentionCGSupport .PURE_DECODE_ONLY :
2376- max_num_seqs = self .scheduler_config .max_num_seqs
2377- self .cudagraph_batch_sizes = [
2378- size for size in self .cudagraph_batch_sizes
2379- if size <= max_num_seqs ]
23802372
23812373 # Only rank 0 should print progress bar during capture
23822374 compilation_cases = reversed (self .cudagraph_batch_sizes )
@@ -2446,13 +2438,20 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
24462438 self .device ,
24472439 )
24482440
2449- if (self .full_cuda_graph
2450- and attn_metadata_builder_i .attn_cudagraph_support == \
2451- AttentionCGSupport .NEVER ):
2452- raise ValueError (
2453- f"Full CUDAGraph not supported for "
2454- f"{ attn_backend_i .__name__ } . Turn off CompilationConfig."
2455- f"full_cuda_graph or use a different attention backend." )
2441+ if self .full_cuda_graph :
2442+ if attn_metadata_builder_i .attn_cudagraph_support == \
2443+ AttentionCGSupport .NEVER :
2444+ raise ValueError (
2445+ f"Full CUDAGraph not supported for "
2446+ f"{ attn_backend_i .__name__ } . Turn off "
2447+ f"CompilationConfig.full_cuda_graph or use a "
2448+ f" different attention backend." )
2449+ if attn_metadata_builder_i .attn_cudagraph_support == \
2450+ AttentionCGSupport .PURE_DECODE_ONLY :
2451+ self .cudagraph_batch_sizes = [
2452+ size for size in self .cudagraph_batch_sizes
2453+ if size <= self .scheduler_config .max_num_seqs
2454+ ]
24562455
24572456 self .attn_backends .append (attn_backend_i )
24582457 self .attn_metadata_builders .append (attn_metadata_builder_i )
0 commit comments