From d80a82f961e2b3570f470171b9683d5e5ad06834 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sun, 27 Jul 2025 21:06:56 +0000 Subject: [PATCH 01/11] fix dp plus full cuda-graph Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/flashmla.py | 54 +++++++++++++--------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index d3e5300dbbd6..51e77bc6bfa4 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -67,6 +67,20 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, self.cg_buf_tile_scheduler_metadata = None self.cg_buf_num_splits = None + device_properties = torch.cuda.get_device_properties(self.device) + num_sms = device_properties.multi_processor_count + + if self.compilation_config.full_cuda_graph: + self.cg_buf_tile_scheduler_metadata = torch.empty( + (num_sms, 8), # TileSchedulerMetaDataSize == 8 + device=self.device, + dtype=torch.int32, + ) + self.cg_buf_num_splits = torch.empty( + (vllm_config.scheduler_config.max_num_seqs + 1), + device=self.device, + dtype=torch.int32) + def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ @@ -77,28 +91,24 @@ def _build_decode(self, block_table_tensor: torch.Tensor, ) if self.compilation_config.full_cuda_graph: - # First time around (CUDAGraph capture), allocate the static buffer - if self.cg_buf_tile_scheduler_metadata is None: - self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata - self.cg_buf_num_splits = num_splits - else: - assert self.cg_buf_num_splits is not None - - # Metadata per-SM, fixed size (#SMs, TileMetadataSize) - assert (self.cg_buf_tile_scheduler_metadata.size() == - tile_scheduler_metadata.size()) - self.cg_buf_tile_scheduler_metadata.\ - copy_(tile_scheduler_metadata) - tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata - - # Num splits is per-batch, varying size (batch_size,) - n = num_splits.size(0) - # make sure static buffer is large enough - assert n <= self.cg_buf_num_splits.size(0) - num_splits_view = self.cg_buf_num_splits[:n] - num_splits_view.copy_(num_splits) - self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s - num_splits = num_splits_view + assert self.cg_buf_tile_scheduler_metadata is not None + assert self.cg_buf_num_splits is not None + + # Metadata per-SM, fixed size (#SMs, TileMetadataSize) + assert (self.cg_buf_tile_scheduler_metadata.size() == + tile_scheduler_metadata.size()) + self.cg_buf_tile_scheduler_metadata.\ + copy_(tile_scheduler_metadata) + tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata + + # Num splits is per-batch, varying size (batch_size,) + n = num_splits.size(0) + # make sure static buffer is large enough + assert n <= self.cg_buf_num_splits.size(0) + num_splits_view = self.cg_buf_num_splits[:n] + num_splits_view.copy_(num_splits) + self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s + num_splits = num_splits_view return FlashMLADecodeMetadata( block_table=block_table_tensor, From 073c4c267af483d2d79db5aac5b123a5271fbbb8 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 28 Jul 2025 01:10:03 +0000 Subject: [PATCH 02/11] fix Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/flashmla.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 51e77bc6bfa4..dfca16169cc3 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -72,7 +72,9 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, if self.compilation_config.full_cuda_graph: self.cg_buf_tile_scheduler_metadata = torch.empty( - (num_sms, 8), # TileSchedulerMetaDataSize == 8 + # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize) + # TileSchedulerMetaDataSize = 8 + (num_sms, 8), device=self.device, dtype=torch.int32, ) @@ -94,10 +96,11 @@ def _build_decode(self, block_table_tensor: torch.Tensor, assert self.cg_buf_tile_scheduler_metadata is not None assert self.cg_buf_num_splits is not None - # Metadata per-SM, fixed size (#SMs, TileMetadataSize) - assert (self.cg_buf_tile_scheduler_metadata.size() == - tile_scheduler_metadata.size()) - self.cg_buf_tile_scheduler_metadata.\ + # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize) + assert (self.cg_buf_tile_scheduler_metadata.size(0) + >= tile_scheduler_metadata.size(0)) + sm_parts = tile_scheduler_metadata.size(0) + self.cg_buf_tile_scheduler_metadata[:sm_parts].\ copy_(tile_scheduler_metadata) tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata From 1b2085be5cf621148483c974aedcc3b959cabc4f Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 28 Jul 2025 23:51:39 +0000 Subject: [PATCH 03/11] fix sizing Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/flashmla.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index dfca16169cc3..b02e22c95316 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -102,7 +102,8 @@ def _build_decode(self, block_table_tensor: torch.Tensor, sm_parts = tile_scheduler_metadata.size(0) self.cg_buf_tile_scheduler_metadata[:sm_parts].\ copy_(tile_scheduler_metadata) - tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata + tile_scheduler_metadata = \ + self.cg_buf_tile_scheduler_metadata[:sm_parts] # Num splits is per-batch, varying size (batch_size,) n = num_splits.size(0) From dbe6833eb3c3db71e2c834064f2c3b7a01bbc5b3 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 30 Jul 2025 00:32:32 +0000 Subject: [PATCH 04/11] fix ima Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/flashmla.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index b02e22c95316..9a82660138c0 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -111,7 +111,8 @@ def _build_decode(self, block_table_tensor: torch.Tensor, assert n <= self.cg_buf_num_splits.size(0) num_splits_view = self.cg_buf_num_splits[:n] num_splits_view.copy_(num_splits) - self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s + # Num splits needs to monotonically increasing + self.cg_buf_num_splits[n:].fill_(num_splits[-1]) num_splits = num_splits_view return FlashMLADecodeMetadata( From 3f27c704411c251a01b2fe5a8c4ffd58df3bf7f4 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 30 Jul 2025 01:20:35 +0000 Subject: [PATCH 05/11] another attempt Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/flashmla.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 9a82660138c0..62280479aae6 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -71,7 +71,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, num_sms = device_properties.multi_processor_count if self.compilation_config.full_cuda_graph: - self.cg_buf_tile_scheduler_metadata = torch.empty( + self.cg_buf_tile_scheduler_metadata = torch.zeros( # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize) # TileSchedulerMetaDataSize = 8 (num_sms, 8), @@ -96,14 +96,13 @@ def _build_decode(self, block_table_tensor: torch.Tensor, assert self.cg_buf_tile_scheduler_metadata is not None assert self.cg_buf_num_splits is not None - # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize) - assert (self.cg_buf_tile_scheduler_metadata.size(0) - >= tile_scheduler_metadata.size(0)) sm_parts = tile_scheduler_metadata.size(0) - self.cg_buf_tile_scheduler_metadata[:sm_parts].\ - copy_(tile_scheduler_metadata) - tile_scheduler_metadata = \ + # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize) + assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0) + tile_scheduler_metadata_view = \ self.cg_buf_tile_scheduler_metadata[:sm_parts] + tile_scheduler_metadata_view.copy_(tile_scheduler_metadata) + tile_scheduler_metadata = tile_scheduler_metadata_view # Num splits is per-batch, varying size (batch_size,) n = num_splits.size(0) From 61196154a3e6287de201fe7f54db934b79290f6d Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 31 Jul 2025 21:18:30 +0000 Subject: [PATCH 06/11] another attempt Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/flashmla.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 62280479aae6..66e7e275bf1b 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -83,6 +83,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, device=self.device, dtype=torch.int32) + self.arange_req_cache = torch.arange( + vllm_config.scheduler_config.max_num_seqs, + device=self.device, + dtype=torch.int32, + ) + def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ @@ -110,8 +116,13 @@ def _build_decode(self, block_table_tensor: torch.Tensor, assert n <= self.cg_buf_num_splits.size(0) num_splits_view = self.cg_buf_num_splits[:n] num_splits_view.copy_(num_splits) - # Num splits needs to monotonically increasing - self.cg_buf_num_splits[n:].fill_(num_splits[-1]) + # Num splits needs to monotonically increasing by 1 for the + # padded elements of the batch so that combine kernel returns + # out immediately (which happens when num_splits == 1 + # not num_splits <= 1 unforunately) + buffer_padding_size = self.cg_buf_num_splits.size(0) - n + self.cg_buf_num_splits[n:] = \ + num_splits[-1] + 1 + self.arange_req_cache[:buffer_padding_size] num_splits = num_splits_view return FlashMLADecodeMetadata( From 1035a64b8ff37213b09e4f2399a93ddd552a3988 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 4 Aug 2025 19:57:59 +0000 Subject: [PATCH 07/11] another attempt Signed-off-by: Lucas Wilkinson --- cmake/external_projects/flashmla.cmake | 9 +++++---- vllm/attention/ops/flashmla.py | 1 - 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index 6291475164ba..4f645dd3626d 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -19,7 +19,7 @@ else() FetchContent_Declare( flashmla GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git - GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845 + GIT_TAG 75cb7a754bce725ed38b7e6b74e98dcd8e952e20 GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" @@ -27,6 +27,7 @@ else() endif() +list(APPEND VLLM_GPU_FLAGS "-lineinfo") FetchContent_MakeAvailable(flashmla) message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}") @@ -37,9 +38,9 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) set(FlashMLA_SOURCES ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp - ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu - ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu - ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu) + ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu + ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu) set(FlashMLA_INCLUDES ${flashmla_SOURCE_DIR}/csrc/cutlass/include diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index b85f27ac417c..1af26dfc3daa 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -91,7 +91,6 @@ def flash_mla_with_kvcache( out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( q, k_cache, - None, head_dim_v, cache_seqlens, block_table, From 0f845f9a4f40e7649d0bbb11a3739208e4ab0155 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 5 Aug 2025 02:00:30 +0000 Subject: [PATCH 08/11] cleaner Signed-off-by: Lucas Wilkinson --- cmake/external_projects/flashmla.cmake | 2 +- vllm/v1/attention/backends/mla/flashmla.py | 17 ++++------------- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index 4f645dd3626d..afe96ba548cf 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -19,7 +19,7 @@ else() FetchContent_Declare( flashmla GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git - GIT_TAG 75cb7a754bce725ed38b7e6b74e98dcd8e952e20 + GIT_TAG 3edc935ad6368dabdcd9b4223c2dcbcfa983e3f8 GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 66e7e275bf1b..03941e1dd20e 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -83,12 +83,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, device=self.device, dtype=torch.int32) - self.arange_req_cache = torch.arange( - vllm_config.scheduler_config.max_num_seqs, - device=self.device, - dtype=torch.int32, - ) - def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ @@ -116,13 +110,10 @@ def _build_decode(self, block_table_tensor: torch.Tensor, assert n <= self.cg_buf_num_splits.size(0) num_splits_view = self.cg_buf_num_splits[:n] num_splits_view.copy_(num_splits) - # Num splits needs to monotonically increasing by 1 for the - # padded elements of the batch so that combine kernel returns - # out immediately (which happens when num_splits == 1 - # not num_splits <= 1 unforunately) - buffer_padding_size = self.cg_buf_num_splits.size(0) - n - self.cg_buf_num_splits[n:] = \ - num_splits[-1] + 1 + self.arange_req_cache[:buffer_padding_size] + # Num splits needs to monotonically increasing + # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise + # it needs to montonicaly increasing by 1) + self.cg_buf_num_splits[n:].fill_(num_splits[-1]) num_splits = num_splits_view return FlashMLADecodeMetadata( From d8844c66c5a64798452a461448a61d2715e9c782 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 5 Aug 2025 03:50:12 +0000 Subject: [PATCH 09/11] cleanup Signed-off-by: Lucas Wilkinson --- cmake/external_projects/flashmla.cmake | 1 - 1 file changed, 1 deletion(-) diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index afe96ba548cf..157216793056 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -27,7 +27,6 @@ else() endif() -list(APPEND VLLM_GPU_FLAGS "-lineinfo") FetchContent_MakeAvailable(flashmla) message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}") From abebc37216ad28f6867aabcee66a35692d954c1c Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 5 Aug 2025 13:52:43 +0000 Subject: [PATCH 10/11] update SHA Signed-off-by: Lucas Wilkinson --- cmake/external_projects/flashmla.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index 157216793056..ee6768bce26c 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -19,7 +19,7 @@ else() FetchContent_Declare( flashmla GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git - GIT_TAG 3edc935ad6368dabdcd9b4223c2dcbcfa983e3f8 + GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1 GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" From b1b448283b166395be175d3d5bceda121da95808 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 5 Aug 2025 13:53:26 +0000 Subject: [PATCH 11/11] spelling fix Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/flashmla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 03941e1dd20e..cd4e95de0071 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -112,7 +112,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, num_splits_view.copy_(num_splits) # Num splits needs to monotonically increasing # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise - # it needs to montonicaly increasing by 1) + # it needs to monotonically increasing by 1) self.cg_buf_num_splits[n:].fill_(num_splits[-1]) num_splits = num_splits_view