 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.layer import Attention
-from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -555,7 +554,15 @@ def _get_cumsum_and_arange( |
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata]]:
+    ) -> tuple[dict[str, Any], bool, torch.Tensor,
+               Optional[SpecDecodeMetadata]]:
+        """
+        :return: tuple[
+            attn_metadata: layer-to-attention_metadata mapping,
+            attention_cuda_graphs: whether attention can run in captured cudagraph,
+            logits_indices, spec_decode_metadata
+        ]
+        """
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
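Reviewer note: the hunk above widens `_prepare_inputs` from a 3-tuple to a 4-tuple, inserting the new `bool` in the second slot, so every call site has to change. A minimal sketch of the new unpacking order is below; the stub function and its dummy values are mine, only the four names and their order come from the diff.

```python
from typing import Any, Optional


def _fake_prepare_inputs() -> tuple[dict[str, Any], bool, Any, Optional[Any]]:
    """Stand-in with the same shape as the new _prepare_inputs return value."""
    attn_metadata = {"model.layers.0.self_attn": object()}  # layer -> metadata
    attention_cuda_graphs = True  # every KV cache group can use the full graph
    logits_indices = [0]          # placeholder for the torch.Tensor
    spec_decode_metadata = None
    return attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata


# The new flag sits second; unpacking with the old three names would now fail.
(attn_metadata, attention_cuda_graphs, logits_indices,
 spec_decode_metadata) = _fake_prepare_inputs()
assert isinstance(attention_cuda_graphs, bool)
```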
@@ -677,27 +684,31 @@ def _prepare_inputs( |
             )
 
         attn_metadata: dict[str, Any] = {}
+        attention_cuda_graphs = []
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
 
             # Prepare for cascade attention if enabled & beneficial.
             common_prefix_len = 0
+            builder = self.attn_metadata_builders[kv_cache_group_id]
             if self.cascade_attn_enabled:
                 common_prefix_len = self._compute_cascade_attn_prefix_len(
                     num_scheduled_tokens,
                     scheduler_output.
                     num_common_prefix_blocks[kv_cache_group_id],
                     kv_cache_group_spec.kv_cache_spec,
-                    self.attn_metadata_builders[kv_cache_group_id],
+                    builder,
                 )
 
-            attn_metadata_i = (
-                self.attn_metadata_builders[kv_cache_group_id].build(
-                    common_prefix_len=common_prefix_len,
-                    common_attn_metadata=common_attn_metadata,
-                ))
+            attn_metadata_i = (builder.build(
+                common_prefix_len=common_prefix_len,
+                common_attn_metadata=common_attn_metadata,
+            ))
+            attention_cuda_graphs.append(
+                builder.can_run_in_cudagraph(common_attn_metadata))
+
             for layer_name in kv_cache_group_spec.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
 
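The loop above now asks each KV cache group's metadata builder whether the current batch can run inside a captured CUDA graph. The builders themselves are not part of this diff, so the following is only a sketch of the likely shape, under the assumption that "pure decode" means one scheduled token per request; `CommonMeta` and `ToyDecodeOnlyBuilder` are hypothetical names, while `full_cudagraph_supported` and `can_run_in_cudagraph` mirror the attributes used in the patch.

```python
from dataclasses import dataclass


@dataclass
class CommonMeta:
    """Hypothetical stand-in for the common attention metadata."""
    num_reqs: int
    num_actual_tokens: int


class ToyDecodeOnlyBuilder:
    """Illustrative builder: supports full CUDA graphs, but only for
    pure-decode batches (exactly one new token per request)."""
    full_cudagraph_supported: bool = True

    def can_run_in_cudagraph(self, common_attn_metadata: CommonMeta) -> bool:
        return (common_attn_metadata.num_actual_tokens ==
                common_attn_metadata.num_reqs)


# Mirrors the aggregation in _prepare_inputs: every group must agree.
builders = [ToyDecodeOnlyBuilder()]
decode_batch = CommonMeta(num_reqs=8, num_actual_tokens=8)
mixed_batch = CommonMeta(num_reqs=8, num_actual_tokens=50)  # contains prefills
assert all(b.can_run_in_cudagraph(decode_batch) for b in builders)
assert not all(b.can_run_in_cudagraph(mixed_batch) for b in builders)
```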
@@ -729,7 +740,8 @@ def _prepare_inputs( |
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)
 
-        return attn_metadata, logits_indices, spec_decode_metadata
+        return attn_metadata, all(
+            attention_cuda_graphs), logits_indices, spec_decode_metadata
 
     def _compute_cascade_attn_prefix_len(
         self,
@@ -1189,8 +1201,8 @@ def execute_model( |
             return self.kv_connector_no_forward(scheduler_output)
 
         # Prepare the decoder inputs.
-        attn_metadata, logits_indices, spec_decode_metadata = (
-            self._prepare_inputs(scheduler_output))
+        (attn_metadata, attention_cuda_graphs, logits_indices,
+         spec_decode_metadata) = (self._prepare_inputs(scheduler_output))
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.use_cuda_graph
                 and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@@ -1255,11 +1267,9 @@ def execute_model( |
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                 num_input_tokens, intermediate_tensors, True)
 
-        # Some attention backends only support CUDA graphs in pure decode.
-        # Assume cuda_graph_supported is false if it does not exist.
-        attention_cuda_graphs = all(
-            getattr(m, "cuda_graph_supported", False)
-            for _, m in attn_metadata.items())
+        # Some attention backends only support CUDA Graphs in pure decode.
+        # If attention doesn't support CUDA Graphs for this batch, but we
+        # compiled with full CUDA graphs, we have to skip them entirely.
         skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
 
         # Run the decoder.
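Taken together with the `_prepare_inputs` change, the decision reduces to one line of boolean logic; here is a self-contained restatement (the helper name and the list-of-flags argument are mine, not vLLM's):

```python
def should_skip_cuda_graphs(full_cuda_graph: bool,
                            per_group_can_run: list[bool]) -> bool:
    """Toy restatement of: skip = full_cuda_graph and not attention_cuda_graphs."""
    attention_cuda_graphs = all(per_group_can_run)
    return full_cuda_graph and not attention_cuda_graphs


# Piecewise compilation never skips here; the flag only matters for full graphs.
assert should_skip_cuda_graphs(False, [False, True]) is False
# Full graphs + a batch some attention group can't capture -> run eagerly.
assert should_skip_cuda_graphs(True, [False, True]) is True
# Full graphs + every group happy -> replay the captured graph.
assert should_skip_cuda_graphs(True, [True, True]) is False
```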
@@ -2100,20 +2110,20 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: |
2100 | 2110 | "Non-Attention backend is not supported by V1 " |
2101 | 2111 | "GPUModelRunner.") |
2102 | 2112 |
|
2103 | | - if self.compilation_config.full_cuda_graph: |
2104 | | - attn_backend_name = attn_backend_i.__name__ |
2105 | | - flash_attn_version = get_flash_attn_version() |
2106 | | - if ((attn_backend_name != "FlashAttentionBackend" |
2107 | | - or flash_attn_version != 3) |
2108 | | - and attn_backend_name != "FlashMLABackend"): |
2109 | | - raise ValueError( |
2110 | | - f"Full CUDAGraph is only supported with FA3 or FlashMLA" |
2111 | | - f". Current attention backend is {attn_backend_name}, " |
2112 | | - f"FlashAttention version is {flash_attn_version}.") |
2113 | | - |
2114 | 2113 | block_table_i = self.input_batch.block_table[i] |
2115 | 2114 | attn_metadata_builder_i = attn_backend_i.get_builder_cls()( |
2116 | | - weakref.proxy(self), kv_cache_spec, block_table_i) |
| 2115 | + weakref.proxy(self), |
| 2116 | + kv_cache_spec, |
| 2117 | + block_table_i, |
| 2118 | + ) |
| 2119 | + |
| 2120 | + if (self.full_cuda_graph |
| 2121 | + and not attn_metadata_builder_i.full_cudagraph_supported): |
| 2122 | + raise ValueError( |
| 2123 | + f"Full CUDAGraph not supported for " |
| 2124 | + f"{attn_backend_i.__name__}. Turn off CompilationConfig." |
| 2125 | + f"full_cuda_graph or use a different attention backend.") |
| 2126 | + |
2117 | 2127 | self.attn_backends.append(attn_backend_i) |
2118 | 2128 | self.attn_metadata_builders.append(attn_metadata_builder_i) |
2119 | 2129 |
|
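With this change the init-time gate keys off a `full_cudagraph_supported` attribute on the builder instead of a hard-coded FA3/FlashMLA allow-list. Below is a minimal sketch of that gate, assuming the attribute is a plain class-level bool; the dummy builders and the helper function are hypothetical.

```python
class _FullGraphBuilder:
    full_cudagraph_supported = True


class _PiecewiseOnlyBuilder:
    full_cudagraph_supported = False


def check_full_cuda_graph(full_cuda_graph: bool, builder, backend_name: str) -> None:
    """Fail fast at initialization, mirroring the check in the hunk above."""
    if full_cuda_graph and not builder.full_cudagraph_supported:
        raise ValueError(
            f"Full CUDAGraph not supported for {backend_name}. Turn off "
            f"CompilationConfig.full_cuda_graph or use a different "
            f"attention backend.")


check_full_cuda_graph(True, _FullGraphBuilder(), "FlashAttentionBackend")  # passes
try:
    check_full_cuda_graph(True, _PiecewiseOnlyBuilder(), "SomeOtherBackend")
except ValueError as exc:
    print(exc)  # tells the user how to disable the feature or switch backends
```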