diff --git a/csrc/cache.h b/csrc/cache.h index cbe44c09eb62..f2a5ec0acf5c 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include @@ -59,15 +58,6 @@ void cp_gather_cache( torch::Tensor const& cu_seq_lens, // [BATCH+1] int64_t batch_size, std::optional seq_starts = std::nullopt); -// Gather and upconvert FP8 KV cache to BF16 workspace -void cp_gather_and_upconvert_fp8_kv_cache( - torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] - torch::Tensor const& dst, // [TOT_TOKENS, 576] - torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] - torch::Tensor const& seq_lens, // [BATCH] - torch::Tensor const& workspace_starts, // [BATCH] - int64_t batch_size); - // Indexer K quantization and cache function void indexer_k_quant_and_cache( torch::Tensor& k, // [num_tokens, head_dim] @@ -82,4 +72,4 @@ void cp_gather_indexer_k_quant_cache( torch::Tensor& dst_k, // [num_tokens, head_dim] torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4] const torch::Tensor& block_table, // [batch_size, num_blocks] - const torch::Tensor& cu_seq_lens); // [batch_size + 1] + const torch::Tensor& cu_seq_lens); // [batch_size + 1] \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index f11c5f24c12e..8a5457206c70 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -2,7 +2,6 @@ #include #include #include -#include #include "cuda_utils.h" #include "cuda_compat.h" @@ -515,8 +514,7 @@ __global__ void indexer_k_quant_and_cache_kernel( const int quant_block_size, // quantization block size const int cache_block_size, // cache block size const int cache_stride, // stride for each token in kv_cache - - const bool use_ue8m0 // use ue8m0 scale format + const bool use_ue8m0 // use ue8m0 scale format ) { constexpr int VEC_SIZE = 4; const int64_t token_idx = blockIdx.x; @@ -1063,82 +1061,6 @@ void gather_and_maybe_dequant_cache( } namespace vllm { - -// Gather and upconvert FP8 KV cache tokens to BF16 workspace -// Similar to cp_gather_cache but specifically for FP8->BF16 conversion -__global__ void cp_gather_and_upconvert_fp8_kv_cache( - const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] - __nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576] - const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] - const int32_t* __restrict__ seq_lens, // [BATCH] - const int32_t* __restrict__ workspace_starts, // [BATCH] - const int32_t block_size, const int32_t head_dim, - const int64_t block_table_stride, const int64_t cache_block_stride, - const int64_t cache_entry_stride, const int64_t dst_entry_stride) { - const int64_t bid = blockIdx.x; // Batch ID - const int32_t num_splits = gridDim.y; - const int32_t split = blockIdx.y; - const int32_t seq_start = workspace_starts[bid]; - const int32_t seq_len = seq_lens[bid]; - const int32_t tot_slots = seq_len; - const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits); - - const int32_t split_start = split * split_slots; - const int32_t split_end = min((split + 1) * split_slots, tot_slots); - - const bool is_active_split = (split_start < tot_slots); - - if (!is_active_split) return; - - // Adjust the pointer for the block_table for this batch - const int32_t batch_offset = bid * block_table_stride; - int32_t offset = split_start; - int32_t offset_div = offset / block_size; - offset = offset % block_size; - const int32_t* batch_block_table = block_table + batch_offset; - - // Adjust dst pointer based on the cumulative sequence lengths - dst += seq_start * dst_entry_stride; - - const int tid = threadIdx.x; - - // Process each token in this split - for (int pid = split_start; pid < split_end; ++pid) { - auto block_id = batch_block_table[offset_div]; - const uint8_t* token_ptr = - src_cache + block_id * cache_block_stride + offset * cache_entry_stride; - __nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride; - - // FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16) - const uint8_t* no_pe_ptr = token_ptr; - const float* scales_ptr = reinterpret_cast(token_ptr + 512); - const __nv_bfloat16* rope_ptr = - reinterpret_cast(token_ptr + 512 + 16); - - // Parallelize fp8 dequant (512 elements) and rope copy (64 elements) - if (tid < 512) { - // FP8 dequantization - const int tile = tid >> 7; // each tile is 128 elements - const float scale = scales_ptr[tile]; - const uint8_t val = no_pe_ptr[tid]; - dst_ptr[tid] = - fp8::scaled_convert<__nv_bfloat16, uint8_t, - vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale); - } else if (tid < 576) { - // Rope copy (64 bf16 elements) - const int rope_idx = tid - 512; - dst_ptr[512 + rope_idx] = rope_ptr[rope_idx]; - } - - // Move to next token - offset += 1; - if (offset == block_size) { - offset_div += 1; - offset = 0; - } - } -} - template // Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by // block_size. @@ -1280,57 +1202,6 @@ void cp_gather_cache( } } -void cp_gather_and_upconvert_fp8_kv_cache( - torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] - torch::Tensor const& dst, // [TOT_TOKENS, 576] - torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] - torch::Tensor const& seq_lens, // [BATCH] - torch::Tensor const& workspace_starts, // [BATCH] - int64_t batch_size) { - at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - int32_t block_size = src_cache.size(1); - int32_t head_dim = dst.size(1); - - TORCH_CHECK(block_table.dtype() == torch::kInt32, - "block_table must be int32"); - TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32"); - TORCH_CHECK(workspace_starts.dtype() == torch::kInt32, - "workspace_starts must be int32"); - - TORCH_CHECK(src_cache.device() == dst.device(), - "src_cache and dst must be on the same device"); - TORCH_CHECK(src_cache.device() == block_table.device(), - "src_cache and block_table must be on the same device"); - TORCH_CHECK(src_cache.device() == seq_lens.device(), - "src_cache and seq_lens must be on the same device"); - TORCH_CHECK(src_cache.device() == workspace_starts.device(), - "src_cache and workspace_starts must be on the same device"); - - TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8"); - TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16"); - TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA"); - - int64_t block_table_stride = block_table.stride(0); - int64_t cache_block_stride = src_cache.stride(0); - int64_t cache_entry_stride = src_cache.stride(1); - int64_t dst_entry_stride = dst.stride(0); - - // Decide on the number of splits based on the batch size - int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16; - dim3 grid(batch_size, num_splits); - dim3 block(576); - - vllm::cp_gather_and_upconvert_fp8_kv_cache<<>>( - src_cache.data_ptr(), - reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()), - block_table.data_ptr(), seq_lens.data_ptr(), - workspace_starts.data_ptr(), block_size, head_dim, - block_table_stride, cache_block_stride, cache_entry_stride, - dst_entry_stride); -} - // Macro to dispatch the kernel based on the data type. #define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ vllm::indexer_k_quant_and_cache_kernel \ diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 83d4943d6277..d4c6f8c67c51 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -754,13 +754,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache); - cache_ops.def( - "cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, " - "Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int " - "batch_size) -> ()"); - cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA, - &cp_gather_and_upconvert_fp8_kv_cache); - cache_ops.def( "indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor " "slot_mapping, " diff --git a/tests/conftest.py b/tests/conftest.py index a03f40a9a72a..82cd0ada65d8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -202,27 +202,6 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): cleanup_dist_env_and_memory() -@pytest.fixture -def workspace_init(): - """Initialize the workspace manager for tests that need it. - - This fixture initializes the workspace manager with a CUDA device - if available, and resets it after the test completes. Tests that - create a full vLLM engine should NOT use this fixture as the engine - will initialize the workspace manager itself. - """ - from vllm.v1.worker.workspace import ( - init_workspace_manager, - reset_workspace_manager, - ) - - if torch.cuda.is_available(): - device = torch.device("cuda:0") - init_workspace_manager(device) - yield - reset_workspace_manager() - - @pytest.fixture(autouse=True) def dynamo_reset(): yield diff --git a/tests/kernels/moe/test_batched_deepgemm.py b/tests/kernels/moe/test_batched_deepgemm.py index 0ba3d8d4c958..59cecd60d3d6 100644 --- a/tests/kernels/moe/test_batched_deepgemm.py +++ b/tests/kernels/moe/test_batched_deepgemm.py @@ -27,7 +27,7 @@ @pytest.mark.parametrize("N", [512, 1024]) # intermediate dim per expert @pytest.mark.parametrize("topk", [2, 4]) def test_batched_deepgemm_vs_triton( - E: int, T: int, K: int, N: int, topk: int, monkeypatch, workspace_init + E: int, T: int, K: int, N: int, topk: int, monkeypatch ): """Compare BatchedDeepGemmExperts to BatchedTritonExperts.""" diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 2ef170f1ab30..dab1207d7803 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -248,7 +248,6 @@ def test_fused_moe_batched_experts( per_act_token_quant: bool, block_shape: list[int] | None, input_scales: bool, - workspace_init, ): """Note: float8_e4m3fn is not supported on CUDA architecture < 89, and those tests will be skipped on unsupported hardware.""" diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index 53a03f48e24e..b0ff1e64e321 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -137,7 +137,7 @@ def setup_cuda(): @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() def test_w8a8_block_fp8_fused_moe( - M, N, K, E, topk, block_size, dtype, seed, monkeypatch, workspace_init + M, N, K, E, topk, block_size, dtype, seed, monkeypatch ): if topk > E: pytest.skip(f"Skipping test; topk={topk} > E={E}") diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 0160694d7bb5..c15837f14570 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -274,7 +274,6 @@ def test_cutlass_moe_8_bit_no_graph( per_act_token: bool, per_out_ch: bool, monkeypatch, - workspace_init, ep_size: int | None = None, ): current_platform.seed_everything(7) @@ -330,7 +329,6 @@ def test_cutlass_moe_8_bit_cuda_graph( per_act_token: bool, per_out_ch: bool, monkeypatch, - workspace_init, ): current_platform.seed_everything(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") @@ -387,19 +385,9 @@ def test_cutlass_moe_8_bit_EP( per_out_channel: bool, ep_size: int, monkeypatch, - workspace_init, ): test_cutlass_moe_8_bit_no_graph( - m, - n, - k, - e, - topk, - per_act_token, - per_out_channel, - monkeypatch, - workspace_init, - ep_size, + m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size ) @@ -431,19 +419,9 @@ def test_cutlass_moe_8_bit_EP_large( per_out_channel: bool, ep_size: int, monkeypatch, - workspace_init, ): test_cutlass_moe_8_bit_no_graph( - m, - n, - k, - e, - topk, - per_act_token, - per_out_channel, - monkeypatch, - workspace_init, - ep_size, + m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size ) @@ -467,7 +445,6 @@ def test_run_cutlass_moe_fp8( per_act_token: bool, per_out_channel: bool, ep_size: int, - workspace_init, ): current_platform.seed_everything(7) with set_current_vllm_config(vllm_config): diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index f427734ef09e..455ecacef5ec 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -29,7 +29,6 @@ is_deep_gemm_supported, ) from vllm.utils.import_utils import has_deep_ep, has_deep_gemm -from vllm.v1.worker.workspace import init_workspace_manager from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch @@ -364,9 +363,6 @@ def _test_deepep_deepgemm_moe( w1_scale: torch.Tensor, w2_scale: torch.Tensor, ): - device = torch.device(f"cuda:{pgi.local_rank}") - init_workspace_manager(device) - current_platform.seed_everything(pgi.rank) w1 = w1.to(device=torch.cuda.current_device()) @@ -449,7 +445,6 @@ def test_ht_deepep_deepgemm_moe( topk: int, world_dp_size: tuple[int, int], disable_deepgemm_ue8m0, - workspace_init, ): """ Tests for High-Throughput DeepEP + DeepGemm integration. @@ -523,7 +518,6 @@ def test_ll_deepep_deepgemm_moe( block_size: list[int], world_dp_size: tuple[int, int], disable_deepgemm_ue8m0, - workspace_init, ): """ Tests for Low-Latency DeepEP + DeepGemm integration. diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index e698ca92a151..d78b8250463a 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -22,7 +22,6 @@ ) from vllm.platforms import current_platform from vllm.utils.import_utils import has_deep_ep -from vllm.v1.worker.workspace import init_workspace_manager from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch @@ -343,9 +342,6 @@ def _deep_ep_moe( use_fp8_dispatch: bool, per_act_token_quant: bool, ): - device = torch.device(f"cuda:{pgi.local_rank}") - init_workspace_manager(device) - if not low_latency_mode: assert not use_fp8_dispatch, ( "FP8 dispatch interface is available only in low-latency mode" @@ -441,7 +437,6 @@ def test_deep_ep_moe( topk: int, world_dp_size: tuple[int, int], per_act_token_quant: bool, - workspace_init, ): low_latency_mode = False use_fp8_dispatch = False @@ -497,7 +492,6 @@ def test_low_latency_deep_ep_moe( topk: int, world_dp_size: tuple[int, int], use_fp8_dispatch: bool, - workspace_init, ): low_latency_mode = True diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py index 442b561f8f31..9b1054f7d0ab 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -143,7 +143,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size): @pytest.mark.parametrize("topk", TOPKS) @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels") -def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_init): +def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch): with monkeypatch.context() as mp: mp.setenv("VLLM_USE_DEEP_GEMM", "1") diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index bf4ef2d30466..33040b9ad072 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -220,7 +220,6 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( topk: int, activation: str, monkeypatch, - workspace_init, ): current_platform.seed_everything(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index 133a8a4a30a6..b2be03ecee2f 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -51,14 +51,7 @@ @pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"]) @torch.inference_mode() def test_flashinfer_fp4_moe_no_graph( - m: int, - n: int, - k: int, - e: int, - topk: int, - dtype: torch.dtype, - activation: str, - workspace_init, + m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, activation: str ): current_platform.seed_everything(7) with set_current_vllm_config( diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py index 384f43db479b..98e80ec02977 100644 --- a/tests/kernels/moe/test_gpt_oss_triton_kernels.py +++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py @@ -269,7 +269,7 @@ class Case: ) @pytest.mark.parametrize("num_token", [2]) @pytest.mark.parametrize("tp", [1, 2, 4, 8]) -def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init): +def test_equiv(num_token, a_dtype, w_dtype, tp): from triton_kernels.tensor_details import layout if not hasattr(layout, "make_default_matmul_mxfp4_w_layout"): diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index 6ebf1016c166..2a30ef235552 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -16,7 +16,6 @@ from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.torch_utils import cuda_device_count_stateless -from vllm.v1.worker.workspace import init_workspace_manager from .modular_kernel_tools.common import ( Config, @@ -78,10 +77,6 @@ def rank_worker( weights: WeightTensors, verbose: bool, ): - # Initialize workspace manager in child process - device = torch.device(f"cuda:{pgi.local_rank}") - init_workspace_manager(device) - current_platform.seed_everything(pgi.rank) # sanity check @@ -305,7 +300,6 @@ def test_modular_kernel_combinations_singlegpu( chunk_size: int | None, world_size: int, pytestconfig, - workspace_init, ): """Note: float8_e4m3fn is not supported on CUDA architecture < 89, and those tests will be skipped on unsupported hardware.""" diff --git a/tests/kernels/moe/test_modular_oai_triton_moe.py b/tests/kernels/moe/test_modular_oai_triton_moe.py index 1abb08f878b2..c8616f13bbf8 100644 --- a/tests/kernels/moe/test_modular_oai_triton_moe.py +++ b/tests/kernels/moe/test_modular_oai_triton_moe.py @@ -209,7 +209,6 @@ def test_oai_triton_moe( num_experts: int, topk: int, unfused: bool, - workspace_init, ): current_platform.seed_everything(0) ( diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index ce99d9691fdc..82659276af37 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -231,7 +231,6 @@ def test_fused_moe( padding: bool, chunk_size: int, monkeypatch, - workspace_init, ): current_platform.seed_everything(7) diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index e67bd76a1618..aa544fe0e0f6 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -40,7 +40,7 @@ @pytest.mark.parametrize("dtype", [torch.bfloat16]) @torch.inference_mode() def test_cutlass_fp4_moe_no_graph( - m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, workspace_init + m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype ): current_platform.seed_everything(7) with set_current_vllm_config( diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 35e554e16cb3..f671b23d300c 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -46,7 +46,6 @@ ) from vllm.platforms import current_platform from vllm.utils.math_utils import round_up -from vllm.v1.worker.workspace import init_workspace_manager from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch @@ -182,7 +181,6 @@ def test_fused_moe_batched_experts( e: int, topk: int, dtype: torch.dtype, - workspace_init, ): current_platform.seed_everything(7) @@ -865,9 +863,6 @@ def _pplx_test_loop( make_weights: bool, test_fn: Callable, ): - device = torch.device(f"cuda:{pgi.local_rank}") - init_workspace_manager(device) - def format_result(msg, ex=None): if ex is not None: x = str(ex) diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py index 8049347280c5..b34d587eb362 100644 --- a/tests/v1/attention/test_sparse_mla_backends.py +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -22,14 +22,10 @@ ) from vllm import _custom_ops as ops from vllm.attention.ops import flashmla -from vllm.config import set_current_vllm_config from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.mla.flashmla_sparse import ( - FlashMLASparseBackend, - triton_convert_req_index_to_global_index, -) -from vllm.v1.attention.backends.utils import split_prefill_chunks +from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseBackend +from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks SPARSE_BACKEND_BATCH_SPECS = { name: BATCH_SPECS[name] @@ -118,12 +114,8 @@ def _quantize_dequantize_fp8_ds_mla( @pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys())) @pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"]) @pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) -@pytest.mark.skipif( - torch.cuda.get_device_capability() < (9, 0), - reason="FlashMLASparseBackend requires CUDA 9.0 or higher", -) def test_sparse_backend_decode_correctness( - dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init + dist_init, batch_name, kv_cache_dtype, tensor_parallel_size ): if not torch.cuda.is_available(): pytest.skip("CUDA is required for sparse MLA decode test") @@ -328,29 +320,28 @@ def test_sparse_backend_decode_correctness( mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous()) impl_cls = FlashMLASparseBackend.get_impl_cls() - with set_current_vllm_config(vllm_config): - impl = impl_cls( - num_heads=num_heads, - head_size=head_size, - scale=scale, - num_kv_heads=1, - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype=vllm_config.cache_config.cache_dtype, - logits_soft_cap=None, - attn_type="decoder", - kv_sharing_target_layer_name=None, - q_lora_rank=None, - kv_lora_rank=kv_lora_rank, - qk_nope_head_dim=qk_nope_head_dim, - qk_rope_head_dim=qk_rope_head_dim, - qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, - v_head_dim=v_head_dim, - kv_b_proj=mock_kv_b_proj, - indexer=mock_indexer, - ) + impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=1, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype=vllm_config.cache_config.cache_dtype, + logits_soft_cap=None, + attn_type="decoder", + kv_sharing_target_layer_name=None, + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, + v_head_dim=v_head_dim, + kv_b_proj=mock_kv_b_proj, + indexer=mock_indexer, + ) - impl.process_weights_after_loading(dtype) + impl.process_weights_after_loading(dtype) layer = MockAttentionLayer(device) out_buffer = torch.empty( @@ -375,192 +366,22 @@ def test_sparse_backend_decode_correctness( torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5) -def _triton_convert_reference_impl( - req_ids: torch.Tensor, - block_table: torch.Tensor, - token_indices: torch.Tensor, - block_size: int, - num_topk_tokens: int, - HAS_PREFILL_WORKSPACE: bool = False, - prefill_workspace_request_ids: torch.Tensor | None = None, - prefill_workspace_starts: torch.Tensor | None = None, -) -> torch.Tensor: - """Reference implementation for triton_convert_req_index_to_global_index.""" - num_tokens = req_ids.shape[0] - max_blocks_per_req = block_table.shape[1] - result = torch.empty( - num_tokens, num_topk_tokens, dtype=torch.int32, device=req_ids.device - ) - - for token_id in range(num_tokens): - req_id = req_ids[token_id].item() - - # Determine if this token uses workspace or paged cache - use_prefill_workspace = False - workspace_start = 0 - if HAS_PREFILL_WORKSPACE and prefill_workspace_request_ids is not None: - assert prefill_workspace_starts is not None - prefill_req_id = prefill_workspace_request_ids[token_id].item() - if prefill_req_id >= 0: - use_prefill_workspace = True - workspace_start = prefill_workspace_starts[prefill_req_id].item() - - for idx_id in range(num_topk_tokens): - token_idx = token_indices[token_id, idx_id].item() - - if token_idx == -1: - result[token_id, idx_id] = -1 - elif use_prefill_workspace: - # Prefill + using prefill workspace: map to workspace offset - result[token_id, idx_id] = workspace_start + token_idx - else: - # Decode: map to paged cache - block_id = token_idx // block_size - if block_id >= max_blocks_per_req: - result[token_id, idx_id] = -1 - else: - block_num = block_table[req_id, block_id].item() - offset = token_idx % block_size - result[token_id, idx_id] = block_num * block_size + offset - - return result - - -@pytest.mark.parametrize("block_size", [16, 64, 128]) -@pytest.mark.parametrize("num_topk_tokens", [128, 256, 512]) -@pytest.mark.skipif( - torch.cuda.get_device_capability() < (9, 0), - reason="FlashMLASparseBackend requires CUDA 9.0 or higher", -) -def test_triton_convert_req_index_to_global_index_decode_only( - block_size, num_topk_tokens -): - device = torch.device("cuda") - num_tokens = 8 - num_requests = 4 - max_blocks_per_req = 10 - - req_id = torch.randint( - 0, num_requests, (num_tokens,), dtype=torch.int32, device=device - ) - block_table = torch.randint( - 0, 100, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device - ) - - token_indices = torch.randint( - 0, - block_size * max_blocks_per_req, - (num_tokens, num_topk_tokens), - dtype=torch.int32, - device=device, - ) - - # Set some to -1 to test masking - token_indices[0, :10] = -1 - token_indices[3, 50:60] = -1 - - # Set some to out of bounds - token_indices[2, 100:110] = max_blocks_per_req * block_size - token_indices[6, 150:160] = max_blocks_per_req * block_size - - result = triton_convert_req_index_to_global_index( - req_id, - block_table, - token_indices, - BLOCK_SIZE=block_size, - NUM_TOPK_TOKENS=num_topk_tokens, - ) - - reference_result = _triton_convert_reference_impl( - req_id, - block_table, - token_indices, - block_size, - num_topk_tokens, - ) - - torch.testing.assert_close(result, reference_result, rtol=0, atol=0) - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.skipif( - torch.cuda.get_device_capability() < (9, 0), - reason="FlashMLASparseBackend requires CUDA 9.0 or higher", -) -def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_size): - device = torch.device("cuda") - num_requests = 4 - max_blocks_per_req = 8 - num_topk_tokens = 128 - - # First 6 tokens are decode (reqs 0, 1), last 6 are prefill (reqs 2, 3) - req_id = torch.tensor( - [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], dtype=torch.int32, device=device - ) - prefill_workspace_request_ids = torch.tensor( - [-1, -1, -1, -1, -1, -1, 0, 0, 0, 1, 1, 1], dtype=torch.int32, device=device - ) - - # Workspace starts for the 2 prefill reqs: req 2 starts at 0, req 3 starts at 100 - prefill_workspace_starts = torch.tensor([0, 100], dtype=torch.int32, device=device) - - block_table = torch.randint( - 0, 50, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device - ) - token_indices = torch.randint( - 0, - block_size * max_blocks_per_req, - (req_id.shape[0], num_topk_tokens), - dtype=torch.int32, - device=device, - ) - - # Set some to -1 to test masking - token_indices[0, :10] = -1 - token_indices[3, 50:60] = -1 - - # Set some to out of bounds - token_indices[2, 100:110] = max_blocks_per_req * block_size - token_indices[6, 150:160] = max_blocks_per_req * block_size - - result = triton_convert_req_index_to_global_index( - req_id, - block_table, - token_indices, - BLOCK_SIZE=block_size, - NUM_TOPK_TOKENS=num_topk_tokens, - HAS_PREFILL_WORKSPACE=True, - prefill_workspace_request_ids=prefill_workspace_request_ids, - prefill_workspace_starts=prefill_workspace_starts, - ) - - reference_result = _triton_convert_reference_impl( - req_id, - block_table, - token_indices, - block_size, - num_topk_tokens, - HAS_PREFILL_WORKSPACE=True, - prefill_workspace_request_ids=prefill_workspace_request_ids, - prefill_workspace_starts=prefill_workspace_starts, - ) - - torch.testing.assert_close(result, reference_result, rtol=0, atol=0) - - @pytest.mark.parametrize( - "seq_lens,max_buf,expected", + "seq_lens,max_buf,start,expected", [ # Basic split: totals per chunk ≤ max_buf - (torch.tensor([2, 3, 4, 2]), 5, [(0, 2), (2, 3), (3, 4)]), - # Exact fits should split between items when adding the next would overflow - (torch.tensor([5, 5, 5]), 5, [(0, 1), (1, 2), (2, 3)]), + (torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]), + # Non-zero start index + (torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]), + # Exact fits should split between items when adding the next would + # overflow + (torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]), # All requests fit in a single chunk - (torch.tensor([1, 1, 1]), 10, [(0, 3)]), - # Large buffer - (torch.tensor([4, 4, 4]), 100, [(0, 3)]), + (torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]), + # Large buffer with non-zero start + (torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]), ], ) -def test_split_prefill_chunks(seq_lens, max_buf, expected): - out = split_prefill_chunks(seq_lens, max_buf) +def test_split_prefill_chunks(seq_lens, max_buf, start, expected): + out = split_prefill_chunks(seq_lens, max_buf, start) assert out == expected diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 2319655008c5..6c94b9e131e1 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2403,29 +2403,6 @@ def cp_gather_cache( ) -def cp_gather_and_upconvert_fp8_kv_cache( - src_cache: torch.Tensor, - dst: torch.Tensor, - block_table: torch.Tensor, - seq_lens: torch.Tensor, - workspace_starts: torch.Tensor, - batch_size: int, -) -> None: - """Gather and upconvert FP8 KV cache to BF16 workspace. - - Args: - src_cache: FP8 KV cache [num_blocks, block_size, 656] - dst: BF16 output workspace [total_tokens, 576] - block_table: Block indices [num_reqs, max_blocks] - seq_lens: Sequence lengths [num_reqs] - workspace_starts: Workspace start offsets [num_reqs] - batch_size: Number of requests - """ - torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache( - src_cache, dst, block_table, seq_lens, workspace_starts, batch_size - ) - - def indexer_k_quant_and_cache( k: torch.Tensor, kv_cache: torch.Tensor, diff --git a/vllm/envs.py b/vllm/envs.py index d0f279809626..cb75ba1a62de 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -239,7 +239,6 @@ VLLM_NCCL_INCLUDE_PATH: str | None = None VLLM_USE_FBGEMM: bool = False VLLM_GC_DEBUG: str = "" - VLLM_DEBUG_WORKSPACE: bool = False VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" @@ -1538,9 +1537,6 @@ def get_vllm_port() -> int | None: # - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with # top 5 collected objects "VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""), - # Debug workspace allocations. - # logging of workspace resize operations. - "VLLM_DEBUG_WORKSPACE": lambda: bool(int(os.getenv("VLLM_DEBUG_WORKSPACE", "0"))), # Disables parallel execution of shared_experts via separate cuda stream "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool( int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0")) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 484314091cb1..12b2edf0f6fa 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -24,12 +24,12 @@ from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv from vllm.v1.worker.ubatching import ( + dbo_current_ubatch_id, dbo_enabled, dbo_maybe_run_recv_hook, dbo_register_recv_hook, dbo_yield, ) -from vllm.v1.worker.workspace import current_workspace_manager logger = init_logger(__name__) @@ -663,6 +663,25 @@ def _slice_scales( return None +class SharedResizableBuffer: + def __init__(self): + self.buffer = None + + def get( + self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype + ) -> torch.Tensor: + assert shape != () + shape_numel = prod(shape) + if ( + self.buffer is None + or self.buffer.numel() < shape_numel + or self.buffer.device != device + or self.buffer.dtype != dtype + ): + self.buffer = torch.empty(shape_numel, device=device, dtype=dtype) + return self.buffer[:shape_numel].view(*shape) + + @final class FusedMoEModularKernel(torch.nn.Module): """ @@ -677,6 +696,22 @@ class FusedMoEModularKernel(torch.nn.Module): objects. """ + class SharedBuffers: + def __init__(self) -> None: + self.fused_out = SharedResizableBuffer() + self.workspace13 = SharedResizableBuffer() + self.workspace2 = SharedResizableBuffer() + + # Persistent buffers that are shared across `FusedMoEModularKernel` + # instances (layers), to save memory and allocattions. + # + # We have two sets of buffers to support dual batch overlap (DBO) where each + # microbatch (ubatch) should use its own set of buffers to avoid + # cross-ubatch contimination. + # NOTE that memory is lazily allocated for these buffers, meaning that if + # DBO isn't being used, the second SharedBuffers will be empty. + shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()] + def __init__( self, prepare_finalize: FusedMoEPrepareAndFinalize, @@ -776,6 +811,10 @@ def _allocate_buffers( assert M_full > 0 and M_chunk > 0 num_chunks, _ = self._chunk_info(M_full) + + # select per-ubatch buffers to avoid cross-ubatch reuse under DBO + ubatch_idx = dbo_current_ubatch_id() + buffers = self.shared_buffers[ubatch_idx] workspace_dtype = self.fused_experts.workspace_dtype(out_dtype) # Force worst-case allocation in profiling run for @@ -798,11 +837,14 @@ def _allocate_buffers( expert_tokens_meta, ) ) - - current_workspace_manager().get_simultaneous( - (max_workspace_13, workspace_dtype), - (max_workspace_2, workspace_dtype), - (max_fused_out_shape, out_dtype), + buffers.workspace13.get( + max_workspace_13, device=device, dtype=workspace_dtype + ) + buffers.workspace2.get( + max_workspace_2, device=device, dtype=workspace_dtype + ) + buffers.fused_out.get( + max_fused_out_shape, device=device, dtype=workspace_dtype ) # Get intermediate workspace shapes based off the chunked M size. @@ -829,23 +871,22 @@ def _allocate_buffers( # We can reuse the memory between cache1 and cache3 because by the # time we need cache3, we're done with cache1. + workspace13 = buffers.workspace13.get( + workspace13_shape, device=device, dtype=workspace_dtype + ) + workspace2 = buffers.workspace2.get( + workspace2_shape, device=device, dtype=workspace_dtype + ) + # Construct the entire output that can then be processed in chunks. # Reuse workspace13 for the output in the non-chunked case as long # as it is large enough. This will not always be the case for standard # format experts and with experts that have empty workspaces. if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape): - workspace13, workspace2 = current_workspace_manager().get_simultaneous( - (workspace13_shape, workspace_dtype), - (workspace2_shape, workspace_dtype), - ) fused_out = _resize_cache(workspace13, fused_out_shape) else: - workspace13, workspace2, fused_out = ( - current_workspace_manager().get_simultaneous( - (workspace13_shape, workspace_dtype), - (workspace2_shape, workspace_dtype), - (fused_out_shape, out_dtype), - ) + fused_out = buffers.fused_out.get( + fused_out_shape, device=device, dtype=out_dtype ) return workspace13, workspace2, fused_out diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 146124153c79..a9fa76deecbd 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -83,7 +83,6 @@ DeepseekV32IndexerMetadata, ) from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec -from vllm.v1.worker.workspace import current_workspace_manager from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP from .utils import ( @@ -617,15 +616,8 @@ def sparse_attn_indexer( # careful! this will be None in dummy run attn_metadata = get_forward_context().attn_metadata fp8_dtype = current_platform.fp8_dtype() - # assert isinstance(attn_metadata, dict) if not isinstance(attn_metadata, dict): - # Reserve workspace for indexer during profiling run - current_workspace_manager().get_simultaneous( - ((total_seq_lens, head_dim), torch.float8_e4m3fn), - ((total_seq_lens, 4), torch.uint8), - ) - return sparse_attn_indexer_fake( hidden_states, k_cache_prefix, @@ -659,17 +651,17 @@ def sparse_attn_indexer( topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: prefill_metadata = attn_metadata.prefill - - # Get the full shared workspace buffers once (will allocate on first use) - workspace_manager = current_workspace_manager() - k_fp8_full, k_scale_full = workspace_manager.get_simultaneous( - ((total_seq_lens, head_dim), fp8_dtype), - ((total_seq_lens, 4), torch.uint8), - ) - for chunk in prefill_metadata.chunks: - k_fp8 = k_fp8_full[: chunk.total_seq_lens] - k_scale = k_scale_full[: chunk.total_seq_lens] + k_fp8 = torch.empty( + [chunk.total_seq_lens, head_dim], + device=k.device, + dtype=fp8_dtype, + ) + k_scale = torch.empty( + [chunk.total_seq_lens, 4], + device=k.device, + dtype=torch.uint8, + ) ops.cp_gather_indexer_k_quant_cache( kv_cache, k_fp8, @@ -785,6 +777,15 @@ def sparse_attn_indexer_fake( total_seq_lens: int, topk_indices_buffer: torch.Tensor | None, ) -> torch.Tensor: + # profile run + # NOTE(Chen): create the max possible flattened_kv. So that + # profile_run can get correct memory usage. + _flattened_kv = torch.empty( + [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8 + ) + fp8_dtype = current_platform.fp8_dtype() + _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous() + _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() return topk_indices_buffer diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 0818078da036..9bf6b5ef0389 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -18,7 +18,7 @@ flash_mla_with_kvcache, get_mla_metadata, ) -from vllm.config import VllmConfig, get_current_vllm_config +from vllm.config import VllmConfig from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.platforms import current_platform @@ -30,31 +30,13 @@ AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, - reshape_attn_output_for_spec_decode, - reshape_query_for_spec_decode, - split_decodes_and_prefills, - split_prefill_chunks, ) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.workspace import current_workspace_manager if TYPE_CHECKING: from vllm.model_executor.models.deepseek_v2 import Indexer logger = init_logger(__name__) - -# For FP8 sparse attention we have two impelementations: -# 1. Mixed batch mode: use the FP8 decode kernel for both prefill and decode this is -# done by treating all tokens as single batch. -# 2. Separate prefill and decode mode: use the BF16 prefill kernel for prefill -# (upconverting the FP8 cache to BF16 then calling the prefill kernel) and using -# the FP8 decode kernel for decode. -# Currently we use #1 when the number of heads per rank is low (i.e. TP) since the BF16 -# prefill kernel requires padding the numer of heads to 128 while the decode does not -# so when the per ranke head count is below MIN_HEADS_FOR_BF16_PREFILL we use the mixed -# batch mode (#2). -MIN_HEADS_FOR_BF16_PREFILL = 32 - """ NOTE: FlashMLA Sparse uses an fp8 cache with the following format @@ -145,72 +127,19 @@ class FP8KernelMetadata: dummy_block_table: torch.Tensor cache_lens: torch.Tensor - @dataclass - class FP8SeperatePrefillDecode: - @dataclass - class Decode: - kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata" - decode_query_len: int # needed for reshape in spec decode - - @dataclass - class Prefill: - # Sequence lengths (context + query) for prefill requests - # Shape: [num_prefill_reqs] - seq_lens: torch.Tensor - - # Request ID for each token: -1 for decode tokens, request index - # (0, 1, 2, ...) for prefill tokens. - # Shape: [num_actual_tokens] - request_ids: torch.Tensor - - # Workspace start offsets for all prefill requests - # Shape: [num_prefill_reqs], adjusted in-place per chunk to be - # 0-indexed within each chunk. Used to map prefill tokens to workspace - # offsets in convert_logical_index_to_physical_index - workspace_starts: torch.Tensor - - @dataclass - class Chunk: - """Metadata for a chunk of prefill requests. - - Prefill requests may be chunked to fit within the fixed workspace size. - """ - - seq_lens: torch.Tensor - tokens_slice: slice - block_table: torch.Tensor - req_start_idx: int - workspace_starts: torch.Tensor - chunk_tot_seqlen: int - - chunks: list[Chunk] - - num_prefills: int = 0 - num_decodes: int = 0 - num_prefill_tokens: int = 0 - num_decode_tokens: int = 0 - - decode: Decode | None = None - prefill: Prefill | None = None - - fp8_extra_metadata: FP8SeperatePrefillDecode | FP8KernelMetadata | None = None - fp8_use_mixed_batch: bool = False - - -# Kernel with prefill workspace support + fp8_extra_metadata: FP8KernelMetadata | None = None + + @triton.jit def _convert_req_index_to_global_index_kernel( req_id_ptr, # int32 [num_tokens] block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] - prefill_request_id_ptr, # int32 [num_tokens], -1 for decode, >=0 for prefill - workspace_starts_ptr, # int32 [num_prefill_reqs+1] or nullptr # shapes (compile-time where possible) max_num_blocks_per_req: tl.constexpr, BLOCK_SIZE: tl.constexpr, BLOCK_N: tl.constexpr, # tile width along columns - HAS_PREFILL: tl.constexpr, # strides (in elements) bt_stride0, bt_stride1, @@ -236,10 +165,7 @@ def _convert_req_index_to_global_index_kernel( # Only token == -1 should propagate as -1 is_invalid_tok = tok < 0 - is_prefill = False - if HAS_PREFILL: - prefill_req_id = tl.load(prefill_request_id_ptr + token_id) - is_prefill = prefill_req_id >= 0 + # Compute block id and in-block offset block_id = tok // BLOCK_SIZE inblock_off = tok % BLOCK_SIZE @@ -247,18 +173,12 @@ def _convert_req_index_to_global_index_kernel( # Guard block_table access valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0) bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 - is_invalid_tok |= ~valid_block - base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0) - out_val = base * BLOCK_SIZE + inblock_off - - # Override with prefill output if prefill is enabled - if HAS_PREFILL: - workspace_start = tl.load( - workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0 - ) - prefill_out = workspace_start + tok - out_val = tl.where(is_prefill, prefill_out, out_val) - out_val = tl.where(is_invalid_tok, -1, out_val) + base = tl.load(bt_ptr, mask=valid_block, other=0) + + # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset + out_val = tl.where( + is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off + ) # Store results out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 @@ -272,9 +192,6 @@ def triton_convert_req_index_to_global_index( BLOCK_SIZE: int = 64, NUM_TOPK_TOKENS: int = 2048, BLOCK_N: int = 128, # tile width along columns - HAS_PREFILL_WORKSPACE: bool = False, - prefill_workspace_request_ids: torch.Tensor | None = None, - prefill_workspace_starts: torch.Tensor | None = None, ): """ out[token_id, indice_id] = @@ -285,32 +202,17 @@ def triton_convert_req_index_to_global_index( Only when token_indices[token_id, indice_id] == -1 do we output -1. For safety, we also output -1 if the derived block_id would be out-of-bounds. - - When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets - instead of global cache slots. prefill_workspace_request_ids and - prefill_workspace_starts must be provided. - - prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else - prefill request index (maps to prefill_workspace_starts) - prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace - starts for each prefill request """ assert req_id.dtype == torch.int32 assert block_table.dtype == torch.int32 assert token_indices.dtype == torch.int32 assert token_indices.shape[1] == NUM_TOPK_TOKENS assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( - f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})" + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible byBLOCK_N ({BLOCK_N})" ) - if HAS_PREFILL_WORKSPACE: - assert prefill_workspace_request_ids is not None - assert prefill_workspace_starts is not None - assert prefill_workspace_request_ids.dtype == torch.int32 - assert prefill_workspace_starts.dtype == torch.int32 - num_tokens = req_id.shape[0] - max_num_blocks_per_req = block_table.shape[1] + num_requests, max_num_blocks_per_req = block_table.shape tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N # Ensure contiguous tensors on the same device @@ -324,13 +226,6 @@ def triton_convert_req_index_to_global_index( ti_stride0, ti_stride1 = token_indices_c.stride() out_stride0, out_stride1 = out.stride() - # Prepare prefill pointers - if HAS_PREFILL_WORKSPACE: - assert prefill_workspace_request_ids is not None # for mypy - assert prefill_workspace_starts is not None # for mypy - assert prefill_workspace_request_ids.is_contiguous() - assert prefill_workspace_starts.is_contiguous() - # Exact 2D grid: tokens × column tiles grid = (num_tokens, tiles_per_row) @@ -339,13 +234,10 @@ def triton_convert_req_index_to_global_index( block_table_c, token_indices_c, out, - prefill_workspace_request_ids, - prefill_workspace_starts, # shapes / constexprs max_num_blocks_per_req, BLOCK_SIZE, BLOCK_N, - HAS_PREFILL_WORKSPACE, # strides bt_stride0, bt_stride1, @@ -357,16 +249,7 @@ def triton_convert_req_index_to_global_index( return out -def get_prefill_workspace_size(max_model_len: int): - # NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size. - # May be tuned later. - # Memory usage: 5 * max_model_len * 576 * 2 bytes - # Example: DeepSeek-V3.2 with max_model_len=163840 -> - # 5 * 163840 * 576 * 2 = ~900 MB - # This fits nicely below the typical MoE workspace size of >2GB so this is "free" - return max_model_len * 5 - - +@dataclass class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]): _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH @@ -376,42 +259,29 @@ def __init__( layer_names: list[str], vllm_config: VllmConfig, device: torch.device, - ) -> None: - self.vllm_config = vllm_config - self.layer_names = layer_names + ): cache_config = vllm_config.cache_config self.kv_cache_spec = kv_cache_spec self.model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config self.device = device - # Treat requests with query length <= 1 as decodes to match the - # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2) - self._init_reorder_batch_threshold(1, supports_spec_as_decode=True) - props = torch.cuda.get_device_properties(device) sm_count = props.multi_processor_count self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) - self.topk_tokens = vllm_config.model_config.hf_config.index_topk self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla" - max_num_seqs = vllm_config.scheduler_config.max_num_seqs - # Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG) - self.topk_tokens_tensor = torch.full( - (max_num_seqs,), self.topk_tokens, device=device, dtype=torch.int32 + self.topk_tokens_tensor = torch.tensor( + [self.topk_tokens], device=device, dtype=torch.int32 ) - # Shape: [max_num_seqs], all elements = max_model_len - self.max_model_len_tensor = torch.full( - (max_num_seqs,), - self.model_config.max_model_len, - device=device, - dtype=torch.int32, + self.max_model_len_tensor = torch.tensor( + [self.model_config.max_model_len], device=device, dtype=torch.int32 ) # this is ignored by `flash_mla_with_kvcache` if indices not None self.dummy_block_table = torch.empty( - (max_num_seqs, 1), dtype=torch.int32, device=self.device + (1, 1), dtype=torch.int32, device=self.device ) # Equation taken from FlashMLA/csrc/pybind.cpp @@ -429,9 +299,10 @@ def __init__( dtype=torch.int32, device=device, ) - # Sized for per-request batching (num_decodes + 1) self.num_splits_buffer = torch.empty( - (max_num_seqs + 1,), + # We pack all the tokens into one batch for sparse attention. + # Otherwise, we can exceed the sm of `get_mla_metadata`. + (2,), dtype=torch.int32, device=device, ) @@ -441,171 +312,30 @@ def __init__( device=device, ) - def _build_fp8_mixed_decode_prefill( - self, - common_attn_metadata: CommonAttentionMetadata, - ) -> "FlashMLASparseMetadata.FP8KernelMetadata": - """Build FP8 metadata treating all tokens as one mixed batch. - - This matches main branch's approach and avoids the BF16 prefill kernel - which has head padding overhead when num_heads is small (high TP case). - """ - num_tokens = common_attn_metadata.num_actual_tokens - - # Build metadata for all tokens as a single batch - tile_scheduler_metadata, num_splits = get_mla_metadata( - cache_seqlens=self.topk_tokens_tensor[:1], # Single batch - num_q_tokens_per_head_k=num_tokens * self.num_heads, - topk=self.topk_tokens, - num_heads_q=self.num_heads, - num_heads_k=1, - is_fp8_kvcache=True, - ) - - num_sm_parts = tile_scheduler_metadata.size(0) - tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[ - :num_sm_parts - ] - tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata) - num_splits_view = self.num_splits_buffer[:2] - num_splits_view.copy_(num_splits) - - fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata( - scheduler_metadata=tile_scheduler_metadata_buffer, - num_splits=num_splits_view, - cache_lens=self.max_model_len_tensor[:1], - dummy_block_table=self.dummy_block_table[:1], - ) - - return fp8_metadata - - def _build_fp8_separate_prefill_decode( + def build( self, + common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, - ) -> "FlashMLASparseMetadata.FP8SeperatePrefillDecode": + fast_build: bool = False, + ) -> FlashMLASparseMetadata: num_tokens = common_attn_metadata.num_actual_tokens - - (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = ( - split_decodes_and_prefills( - common_attn_metadata, - decode_threshold=self.reorder_batch_threshold or 1, - require_uniform=True, - ) + starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32) + seg_lengths = np.diff(starts) + req_id_per_token = np.repeat( + np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths ) - - FP8Meta = FlashMLASparseMetadata.FP8SeperatePrefillDecode - fp8_metadata = FP8Meta( - num_decodes=num_decodes, - num_prefills=num_prefills, - num_decode_tokens=num_decode_tokens, - num_prefill_tokens=num_prefill_tokens, + # Zero-fill for cudagraphs + self.req_id_per_token_buffer.fill_(0) + self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( + torch.from_numpy(req_id_per_token), non_blocking=True ) + req_id_per_token = self.req_id_per_token_buffer[:num_tokens] - # Extract prefill sequence lengths (context + query, not just query) - # Decode requests come first in the batch, prefill requests follow - prefill_seq_lens = None - prefill_request_id = None - prefill_workspace_starts = None - prefill_chunks = None - - # For pure decode batches, prefill_request_id will be None - # For mixed batches, it will have -1 for decode and request_id for prefill - if num_prefills > 0: - seq_lens_cpu = common_attn_metadata.seq_lens_cpu - seq_lens = common_attn_metadata.seq_lens - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - - prefill_seq_lens_cpu = seq_lens_cpu[num_decodes:] - prefill_seq_lens = seq_lens[num_decodes:] - - # Build prefill_request_id: -1 for decode, request index for - # prefill. This enables a single - # convert_logical_index_to_physical_index call for all tokens - prefill_request_id = torch.full( - (num_tokens,), -1, dtype=torch.int32, device=self.device - ) - # Map prefill tokens to their request IDs (0, 1, 2, ...) - for req_idx in range(num_prefills): - # Get query token range for this prefill request - global_req_idx = num_decodes + req_idx - req_query_start = query_start_loc_cpu[global_req_idx] - req_query_end = query_start_loc_cpu[global_req_idx + 1] - prefill_request_id[req_query_start:req_query_end] = req_idx - - # will be adjusted by chunk loop - prefill_workspace_starts_cpu = torch.zeros( - num_prefills, dtype=torch.int32, pin_memory=True - ) - prefill_workspace_starts_cpu[1:] = torch.cumsum( - prefill_seq_lens_cpu[:-1], dim=0 - ) - # populated by non-blocking copy after prefill_workspace_starts_cpu is - # updated by each chunk - prefill_workspace_starts = torch.empty( - num_prefills, dtype=torch.int32, device=self.device - ) - - # Chunk prefill requests to fit within workspace size - max_prefill_buffer_size = get_prefill_workspace_size( - self.vllm_config.model_config.max_model_len - ) - chunk_bounds = split_prefill_chunks( - prefill_seq_lens_cpu, max_prefill_buffer_size - ) - - prefill_chunks = [] - for chunk_start, chunk_end in chunk_bounds: - # Adjust workspace_starts in-place per chunk to be - # 0-indexed within each chunk - # Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]] - # Initial: workspace_starts=[0,10,25,45] - # After: workspace_starts=[0,10,0,20] - # (chunk 0 starts at 0, chunk 1 starts at 0) - offset = prefill_workspace_starts_cpu[chunk_start].item() - prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset - - chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end] - chunk_tot_seqlen = prefill_seq_lens_cpu[chunk_start:chunk_end].sum() - token_start = query_start_loc_cpu[num_decodes + chunk_start].item() - token_end = query_start_loc_cpu[num_decodes + chunk_end].item() - tokens_slice = slice(token_start, token_end) - - # Create chunk view of gpu tensor - chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end] - chunk_block_table = common_attn_metadata.block_table_tensor[ - num_decodes + chunk_start : num_decodes + chunk_end - ] - - prefill_chunks.append( - FP8Meta.Prefill.Chunk( - seq_lens=chunk_seq_lens, - tokens_slice=tokens_slice, - block_table=chunk_block_table, - req_start_idx=chunk_start, - workspace_starts=chunk_workspace_starts, - chunk_tot_seqlen=chunk_tot_seqlen, - ) - ) - - prefill_workspace_starts.copy_( - prefill_workspace_starts_cpu, non_blocking=True - ) - - fp8_metadata.prefill = FP8Meta.Prefill( - seq_lens=prefill_seq_lens, - request_ids=prefill_request_id, - workspace_starts=prefill_workspace_starts, - chunks=prefill_chunks, - ) - - if num_decodes > 0: - # Compute decode_query_len for spec decode (uniform due to require_uniform) - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item() - + fp8_extra_metadata = None + if self.use_fp8_kv_cache: tile_scheduler_metadata, num_splits = get_mla_metadata( - cache_seqlens=self.topk_tokens_tensor[:num_decodes], - num_q_tokens_per_head_k=decode_query_len * self.num_heads, + cache_seqlens=self.topk_tokens_tensor, + num_q_tokens_per_head_k=num_tokens * self.num_heads, topk=self.topk_tokens, num_heads_q=self.num_heads, num_heads_k=1, @@ -618,70 +348,33 @@ def _build_fp8_separate_prefill_decode( :num_sm_parts ] tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata) - # num_splits has size [num_decodes + 1] - num_splits_view = self.num_splits_buffer[: num_decodes + 1] - num_splits_view.copy_(num_splits) + self.num_splits_buffer.copy_(num_splits) - kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata( + fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata( scheduler_metadata=tile_scheduler_metadata_buffer, - num_splits=num_splits_view, - dummy_block_table=self.dummy_block_table[:num_decodes], - cache_lens=self.max_model_len_tensor[:num_decodes], - ) - fp8_metadata.decode = FP8Meta.Decode( - kernel_metadata=kernel_meta, - decode_query_len=decode_query_len, + num_splits=self.num_splits_buffer, + # cache_lens and block_table are basically unused in sparse case + # but the decode kernel will treat -1 and indices >= cache_lens + # as invalid so we make sure cache_lens is large enough to not + # accidentally mark indices invalid, we will use -1 exclusively + # to mark invalid indices + cache_lens=self.max_model_len_tensor, + dummy_block_table=self.dummy_block_table, ) - return fp8_metadata - - def build( - self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False, - ) -> FlashMLASparseMetadata: - cm = common_attn_metadata - num_tokens = cm.num_actual_tokens - starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32) - seg_lengths = np.diff(starts) - req_id_per_token = np.repeat( - np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths - ) - # Zero-fill for cudagraphs - self.req_id_per_token_buffer.fill_(0) - self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( - torch.from_numpy(req_id_per_token), non_blocking=True - ) - req_id_per_token = self.req_id_per_token_buffer[:num_tokens] - - fp8_extra_metadata: ( - FlashMLASparseMetadata.FP8SeperatePrefillDecode - | FlashMLASparseMetadata.FP8KernelMetadata - | None - ) = None - fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL - if self.use_fp8_kv_cache: - if fp8_use_mixed_batch: - fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm) - else: - fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm) - metadata = FlashMLASparseMetadata( - num_reqs=cm.num_reqs, - max_query_len=cm.max_query_len, - max_seq_len=cm.max_seq_len, - num_actual_tokens=cm.num_actual_tokens, - query_start_loc=cm.query_start_loc, - slot_mapping=cm.slot_mapping, - block_table=cm.block_table_tensor, + num_reqs=common_attn_metadata.num_reqs, + max_query_len=common_attn_metadata.max_query_len, + max_seq_len=common_attn_metadata.max_seq_len, + num_actual_tokens=common_attn_metadata.num_actual_tokens, + query_start_loc=common_attn_metadata.query_start_loc, + slot_mapping=common_attn_metadata.slot_mapping, + block_table=common_attn_metadata.block_table_tensor, req_id_per_token=req_id_per_token, block_size=self.kv_cache_spec.block_size, topk_tokens=self.topk_tokens, fp8_extra_metadata=fp8_extra_metadata, - fp8_use_mixed_batch=fp8_use_mixed_batch, ) - return metadata @@ -721,204 +414,12 @@ def __init__( self.topk_indices_buffer = indexer.topk_indices_buffer self.padding = 128 if current_platform.is_device_capability_family(100) else 64 - if kv_cache_dtype == "fp8_ds_mla": - # Reserve workspace during initialization - vllm_config = get_current_vllm_config() - assert vllm_config is not None and vllm_config.model_config is not None - prefill_workspace_size = get_prefill_workspace_size( - vllm_config.model_config.max_model_len - ) - self.prefill_workspace_shape = (prefill_workspace_size, head_size) - (self.prefill_bf16_workspace,) = ( - current_workspace_manager().get_simultaneous( - (self.prefill_workspace_shape, torch.bfloat16) - ) - ) - def _forward_bf16_kv( self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, topk_indices: torch.Tensor, attn_metadata: FlashMLASparseMetadata, - ) -> torch.Tensor: - # Convert per-request indices to global slots (decode) or workspace - # offsets (prefill). - topk_indices = triton_convert_req_index_to_global_index( - attn_metadata.req_id_per_token, - attn_metadata.block_table, - topk_indices, - BLOCK_SIZE=attn_metadata.block_size, - NUM_TOPK_TOKENS=topk_indices.shape[1], - ) - - return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices) - - def _forward_fp8_kv_separate_prefill_decode( - self, - q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - topk_indices: torch.Tensor, - attn_metadata: FlashMLASparseMetadata, - ) -> torch.Tensor: - fp8_metadata = attn_metadata.fp8_extra_metadata - assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode) - num_decodes = fp8_metadata.num_decodes - - prefill_request_ids = None - prefill_workspace_starts = None - has_prefill_workspace = False - if fp8_metadata.prefill is not None: - prefill_request_ids = fp8_metadata.prefill.request_ids - prefill_workspace_starts = fp8_metadata.prefill.workspace_starts - has_prefill_workspace = True - - # Convert per-request indices to global slots (decode) or workspace - # offsets (prefill). - # For FP8 cache: prefill uses workspace mapping (upconverted to BF16) - # For BF16 cache: always use global cache slots (no workspace) - # prefill_workspace_starts has been adjusted in-place per chunk so - # prefill indices automatically come out chunk-local - topk_indices = triton_convert_req_index_to_global_index( - attn_metadata.req_id_per_token, - attn_metadata.block_table, - topk_indices, - BLOCK_SIZE=attn_metadata.block_size, - NUM_TOPK_TOKENS=topk_indices.shape[1], - HAS_PREFILL_WORKSPACE=has_prefill_workspace, - prefill_workspace_request_ids=prefill_request_ids, - prefill_workspace_starts=prefill_workspace_starts, - ) - - fp8_metadata = attn_metadata.fp8_extra_metadata - assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode) - - def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor: - # Reshape q: (num_decode_tokens, num_heads, head_dim) - # -> (num_decodes, seq_len, num_heads, head_dim) - q = reshape_query_for_spec_decode(q, num_decodes) - seq_len = q.shape[1] - # Reshape topk_indices: (num_decode_tokens, topk) - # -> (num_decodes, seq_len, topk) - topk_indices = topk_indices.view(num_decodes, seq_len, -1) - assert fp8_metadata.decode is not None - attn_out, _ = self._fp8_flash_mla_kernel( - q=q, - kv_c_and_k_pe_cache=kv_c_and_k_pe_cache, - topk_indices=topk_indices, - kernel_metadata=fp8_metadata.decode.kernel_metadata, - ) - # Reshape output: (num_decodes, seq_len, num_heads, head_dim_v) - # -> (num_decode_tokens, num_heads, head_dim_v) - return reshape_attn_output_for_spec_decode(attn_out) - - num_decode_tokens = fp8_metadata.num_decode_tokens - num_prefill_tokens = fp8_metadata.num_prefill_tokens - - # Pure decode: direct call without allocation - if num_decode_tokens > 0 and num_prefill_tokens == 0: - assert fp8_metadata.decode is not None - attn_out = _fp8_decode(q, topk_indices) - else: - # Mixed or pure prefill: allocate output tensor - attn_out = q.new_empty( - (attn_metadata.num_actual_tokens, self.num_heads, self.kv_lora_rank), - dtype=q.dtype, - device=q.device, - ) - - if num_decode_tokens > 0: - attn_out[:num_decode_tokens] = _fp8_decode( - q[:num_decode_tokens], topk_indices[:num_decode_tokens] - ) - - assert fp8_metadata.prefill is not None - for chunk in fp8_metadata.prefill.chunks: - chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen] - ops.cp_gather_and_upconvert_fp8_kv_cache( - kv_c_and_k_pe_cache, - chunk_workspace, - chunk.block_table, - chunk.seq_lens, - chunk.workspace_starts, - len(chunk.block_table), - ) - - chunk_q = q[chunk.tokens_slice] - chunk_topk_indices_workspace = topk_indices[chunk.tokens_slice] - - attn_out[chunk.tokens_slice] = self._bf16_flash_mla_kernel( - chunk_q, - chunk_workspace, - chunk_topk_indices_workspace, - ) - - return attn_out - - def _forward_fp8_kv_mixed_batch( - self, - q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - topk_indices: torch.Tensor, - attn_metadata: FlashMLASparseMetadata, - ) -> torch.Tensor: - """Mixed batch FP8 forward path that treats all tokens as one batch. - - This is equivalent to main branch's approach and avoids the BF16 - prefill kernel which has head padding overhead when num_heads is small. - Used when use_mixed_batch is True. - """ - # Convert per-request indices to global slots (decode) or workspace - # offsets (prefill). - topk_indices = triton_convert_req_index_to_global_index( - attn_metadata.req_id_per_token, - attn_metadata.block_table, - topk_indices, - BLOCK_SIZE=attn_metadata.block_size, - NUM_TOPK_TOKENS=topk_indices.shape[1], - ) - - assert attn_metadata.fp8_extra_metadata is not None - assert isinstance( - attn_metadata.fp8_extra_metadata, FlashMLASparseMetadata.FP8KernelMetadata - ) - fp8_metadata = attn_metadata.fp8_extra_metadata - - _attn_out, _ = self._fp8_flash_mla_kernel( - q=q.unsqueeze(0), # unsqueeze to add batch_dim: (T, H, D) -> (1, T, H, D) - kv_c_and_k_pe_cache=kv_c_and_k_pe_cache, - topk_indices=topk_indices.unsqueeze(0), # (T, topk) -> (1, T, topk) - kernel_metadata=fp8_metadata, - ) - - # Output is (1, T, H, D_v), squeeze back to (T, H, D_v) - return _attn_out.squeeze(0) - - def _fp8_flash_mla_kernel( - self, - q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - topk_indices: torch.Tensor, - kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata, - ) -> torch.Tensor: - return flash_mla_with_kvcache( - q=q, - k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2), - block_table=kernel_metadata.dummy_block_table, - head_dim_v=512, - cache_seqlens=kernel_metadata.cache_lens, - tile_scheduler_metadata=kernel_metadata.scheduler_metadata, - num_splits=kernel_metadata.num_splits, - is_fp8_kvcache=True, - indices=topk_indices, - softmax_scale=self.softmax_scale, - ) - - def _bf16_flash_mla_kernel( - self, - q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - topk_indices: torch.Tensor, ) -> torch.Tensor: num_tokens = q.shape[0] kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( @@ -944,6 +445,31 @@ def _bf16_flash_mla_kernel( output = output[:, : self.num_heads, :] return output + def _forward_fp8_kv( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: + assert attn_metadata.fp8_extra_metadata is not None + extra_metadata = attn_metadata.fp8_extra_metadata + + _attn_out, _ = flash_mla_with_kvcache( + q=q.unsqueeze(0), # unsqueeze to add batch_dim + k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2), + block_table=extra_metadata.dummy_block_table, + head_dim_v=512, + cache_seqlens=extra_metadata.cache_lens, + tile_scheduler_metadata=extra_metadata.scheduler_metadata, + num_splits=extra_metadata.num_splits, + is_fp8_kvcache=True, + indices=topk_indices.unsqueeze(0), # unsqueeze to add batch_dim + softmax_scale=self.softmax_scale, + ) + + return _attn_out + def forward( self, layer: AttentionLayer, @@ -951,7 +477,7 @@ def forward( k_c_normed: torch.Tensor, # key in unified attn k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, - attn_metadata: FlashMLASparseMetadata | None, + attn_metadata: FlashMLASparseMetadata, output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, @@ -967,7 +493,6 @@ def forward( ) if attn_metadata is None: - # Dummy run - no need to allocate buffers # The zero fill is required when used with DP + EP # to ensure all ranks within a DP group compute the # same expert outputs. @@ -980,7 +505,6 @@ def forward( q = q[:num_actual_toks, ...] k_c_normed = k_c_normed[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...] - topk_indices = self.topk_indices_buffer[:num_actual_toks] q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) # Convert from (B, N, P) to (N, B, P) @@ -990,7 +514,16 @@ def forward( # Convert from (N, B, L) to (B, N, L) ql_nope = ql_nope.transpose(0, 1) - use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla" + topk_indices = self.topk_indices_buffer[:num_actual_toks] + + # TODO: handle index / kv_cache correctly + topk_indices_global = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=attn_metadata.topk_tokens, + ) q = torch.cat([ql_nope, q_pe], dim=-1) @@ -1005,15 +538,13 @@ def forward( scale=layer._k_scale, ) - if not use_fp8_cache: - attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices, attn_metadata) - elif attn_metadata.fp8_use_mixed_batch: - attn_out = self._forward_fp8_kv_mixed_batch( - q, kv_cache, topk_indices, attn_metadata + if self.kv_cache_dtype != "fp8_ds_mla": + attn_out = self._forward_bf16_kv( + q, kv_cache, topk_indices_global, attn_metadata ) else: - attn_out = self._forward_fp8_kv_separate_prefill_decode( - q, kv_cache, topk_indices, attn_metadata + attn_out = self._forward_fp8_kv( + q, kv_cache, topk_indices_global, attn_metadata ) self._v_up_proj(attn_out, out=output[:num_actual_toks]) diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index d0696f60a08c..77f1ba00d5b0 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -18,7 +18,6 @@ AttentionMetadataBuilder, CommonAttentionMetadata, split_decodes_and_prefills, - split_prefill_chunks, ) logger = init_logger(__name__) @@ -177,15 +176,40 @@ def kv_spans_from_batches( def get_max_prefill_buffer_size(vllm_config: VllmConfig): max_model_len = vllm_config.model_config.max_model_len - # NOTE(Chen): 40 is a magic number for controlling the prefill buffer size. - # Each entry is 128 fp8 bytes and 4 scale bytes for a total of 132 bytes. - # The flashmla_sparse backend uses a workspace size of 5 * max_model_len. - # The memory usage of the workspace there is 576 * 2 bytes; so we size this as - # (576 * 2 // 132) * 5 = 40 to maximize this workspace size while still fitting - # within the flashmla_sparse workspace. - # For DeepSeek-V3.2, the max_model_len is 163840. - # 40 * 163840 * 132 = 865075200 bytes = 825 MB - return max_model_len * 40 + # NOTE(Chen): 2 is a magic number for controlling the prefill buffer size. + # May be tuned later. + return max_model_len * 2 + + +def split_prefill_chunks( + seq_lens_cpu: torch.Tensor, max_prefill_buffer_size: int, reqs_start: int +) -> list[tuple[int, int]]: + """ + Split the prefill chunks into a list of tuples of (reqs_start, reqs_end) + such that the total sequence length of each chunk is less than the + maximum prefill buffer size. + + Args: + seq_lens_cpu: The sequence lengths of the prefill requests. + max_prefill_buffer_size: The maximum prefill buffer size. + reqs_start: The start index of the prefill requests. + + Returns: + A list of tuples of (reqs_start, reqs_end). + """ + chunk_seq_ids = [] + total_seq_lens = 0 + for i in range(reqs_start, len(seq_lens_cpu)): + cur_seq_len = seq_lens_cpu[i].item() + assert cur_seq_len <= max_prefill_buffer_size + total_seq_lens += cur_seq_len + if total_seq_lens > max_prefill_buffer_size: + chunk_seq_ids.append((reqs_start, i)) + reqs_start = i + total_seq_lens = cur_seq_len + if total_seq_lens > 0: + chunk_seq_ids.append((reqs_start, len(seq_lens_cpu))) + return chunk_seq_ids class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): @@ -278,9 +302,9 @@ def build( prefill_metadata = None if num_prefills > 0: chunk_seq_ids = split_prefill_chunks( - common_attn_metadata.seq_lens_cpu[num_decodes:], + common_attn_metadata.seq_lens_cpu, self.max_prefill_buffer_size, - request_offset=num_decodes, + num_decodes, ) chunks = [ self.build_one_prefill_chunk( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index da43d8703823..79a1f7d4757d 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -937,33 +937,6 @@ def split_decodes_and_prefills( return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) -def split_prefill_chunks( - seq_lens_cpu: torch.Tensor, workspace_size: int, request_offset: int = 0 -) -> list[tuple[int, int]]: - """ - Split the prefill requests into chunks such that the total sequence length - of each chunk is less than or equal to the workspace size. - - Args: - seq_lens_cpu: The sequence lengths of the prefill requests on CPU. - workspace_size: The maximum workspace size (in tokens) per chunk. - request_offset: The offset to add to the request indices. - Returns: - A list of tuples of (reqs_start, reqs_end) representing chunk boundaries. - """ - chunk_bounds = [] - i, n = 0, len(seq_lens_cpu) - assert torch.all(seq_lens_cpu <= workspace_size).item() - - while i < n: - start, chunk_total = i, 0 - while i < n and (chunk_total + (s := seq_lens_cpu[i].item())) <= workspace_size: - chunk_total += s - i += 1 - chunk_bounds.append((start + request_offset, i + request_offset)) - return chunk_bounds - - def reorder_batch_to_split_decodes_and_prefills( input_batch: "InputBatch", scheduler_output: "SchedulerOutput", diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 978224faae65..3f20296c27ba 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -162,7 +162,6 @@ maybe_create_ubatch_slices, ) from vllm.v1.worker.utils import is_residual_scattered_for_sp -from vllm.v1.worker.workspace import lock_workspace from .utils import ( AttentionGroup, @@ -298,7 +297,6 @@ def __init__( self.device = device self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype - self.kv_cache_dtype = kv_cache_dtype_str_to_dtype( cache_config.cache_dtype, self.model_config ) @@ -4599,10 +4597,6 @@ def freeze_gc(): # after here. set_cudagraph_capturing_enabled(False) - # Lock workspace to prevent resizing during execution. - # Max workspace sizes should have been captured during warmup/profiling. - lock_workspace() - end_time = time.perf_counter() elapsed_time = end_time - start_time cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 21a8564f83c4..25ac5aaf9981 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -54,7 +54,6 @@ from vllm.v1.utils import report_usage_stats from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.worker_base import WorkerBase -from vllm.v1.worker.workspace import init_workspace_manager logger = init_logger(__name__) @@ -256,10 +255,6 @@ def init_device(self): else: raise RuntimeError(f"Not support device type: {self.device_config.device}") - # Initialize workspace manager - num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1 - init_workspace_manager(self.device, num_ubatches) - # Construct the model runner if self.use_v2_model_runner: from vllm.v1.worker.gpu.model_runner import ( diff --git a/vllm/v1/worker/workspace.py b/vllm/v1/worker/workspace.py deleted file mode 100644 index a16dde1f6780..000000000000 --- a/vllm/v1/worker/workspace.py +++ /dev/null @@ -1,245 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import inspect -import os -from itertools import accumulate -from math import prod -from typing import Optional - -import torch - -import vllm.envs as envs -from vllm.logger import init_logger -from vllm.utils.math_utils import round_up -from vllm.v1.worker.ubatching import dbo_current_ubatch_id - -logger = init_logger(__name__) - - -def _compute_bytes(shape: tuple[int, ...], dtype: torch.dtype) -> int: - return prod(shape) * dtype.itemsize - - -# Constants -_MB = 1024**2 -_GiB = 1024**3 - -# Global workspace manager instance -_manager: Optional["WorkspaceManager"] = None - - -class WorkspaceManager: - """Manager for workspace allocation. - - Manages workspace buffers for DBO (Dual Batch Overlap) execution. - Can be locked to prevent further growth during execution. - """ - - def __init__(self, device: torch.device, num_ubatches: int | None = None): - self._device = device - # Cache num ubatches at init based on configuration (default to 1) - self._num_ubatches = num_ubatches if num_ubatches is not None else 1 - self._current_workspaces: list[torch.Tensor | None] = [None, None] - self._locked: bool = False - - @staticmethod - def _workspace_size_bytes(workspace: torch.Tensor | None) -> int: - """Get size of workspace in bytes.""" - if workspace is None: - return 0 - return workspace.numel() * workspace.element_size() - - def lock(self) -> None: - """Lock the workspace to prevent further growth. - - After locking, any attempt to allocate a larger workspace will raise - an assertion error. This ensures workspace size is fixed during execution. - """ - self._locked = True - if envs.VLLM_DEBUG_WORKSPACE: - logger.info( - "[WORKSPACE DEBUG] Workspace locked. Current sizes: %s", - [ - self._workspace_size_bytes(ws) / _MB - for ws in self._current_workspaces - if ws is not None - ], - ) - - def is_locked(self) -> bool: - """Check if workspace is locked.""" - return self._locked - - def get_simultaneous( - self, *shapes_and_dtypes: tuple[tuple[int, ...], torch.dtype] - ) -> list[torch.Tensor]: - """Get multiple workspace tensors simultaneously from a single allocation. - - Args: - *shapes_and_dtypes: One or more (shape, dtype) tuples. - - Returns: - List of tensor views into the workspace buffer, one per shape/dtype pair. - """ - actual_bytes = [_compute_bytes(s, d) for s, d in shapes_and_dtypes] - aligned_bytes = [round_up(actual, 256) for actual in actual_bytes] - total_bytes = sum(aligned_bytes) - - # Calculate cumulative offsets using itertools.accumulate - offsets = list(accumulate([0] + aligned_bytes[:-1])) - - current_workspace = self._ensure_workspace_size(total_bytes) - - return [ - current_workspace[offsets[i] : offsets[i] + actual_bytes[i]] - .view(shapes_and_dtypes[i][1]) - .reshape(shapes_and_dtypes[i][0]) - for i in range(len(shapes_and_dtypes)) - ] - - def _ensure_workspace_size(self, required_bytes: int) -> torch.Tensor: - """Ensure workspace is allocated and large enough, return current workspace. - - Args: - required_bytes: The number of bytes required. - - Returns: - The current workspace tensor. - """ - ubatch_id = dbo_current_ubatch_id() - current_workspace = self._current_workspaces[ubatch_id] - current_size = self._workspace_size_bytes(current_workspace) - - if current_size < required_bytes: - - def get_caller_info() -> str: - """Find first frame outside WorkspaceManager.""" - curr_frame = inspect.currentframe() - if curr_frame is None: - return "unknown" - # Walk up the stack skipping WorkspaceManager frames - curr_frame = curr_frame.f_back - while curr_frame is not None: - # TODO: This only catches instance methods (self), missing - # classmethods and staticmethods. Once Python 3.11+ is the - # minimum supported version, use co_qualname instead: - # qualname = curr_frame.f_code.co_qualname - # if qualname.startswith("WorkspaceManager."): - if isinstance(curr_frame.f_locals.get("self"), WorkspaceManager): - curr_frame = curr_frame.f_back - continue - filename = os.path.basename(curr_frame.f_code.co_filename) - return ( - f"{filename}:{curr_frame.f_lineno}:{curr_frame.f_code.co_name}" - ) - return "unknown" - - if self._locked: - raise AssertionError( - f"Workspace is locked but allocation from '{get_caller_info()}' " - f"requires {required_bytes / _MB:.2f} MB, current size is " - f"{current_size / _MB:.2f} MB. " - "Workspace growth is not allowed after locking." - ) - - for ubatch_id in range(self._num_ubatches): - current_workspace = self._current_workspaces[ubatch_id] - if current_workspace is None: - self._current_workspaces[ubatch_id] = torch.empty( - (required_bytes,), dtype=torch.uint8, device=self._device - ) - elif self._workspace_size_bytes(current_workspace) < required_bytes: - current_workspace.resize_(required_bytes) - - if envs.VLLM_DEBUG_WORKSPACE: - logger.info( - "[WORKSPACE DEBUG] Resized workspace from '%s': %.2f MB -> " - "%.2f MB (%d ubatches, total memory %.2f MB)", - get_caller_info(), - current_size / _MB, - required_bytes / _MB, - self._num_ubatches, - required_bytes * self._num_ubatches / _MB, - ) - - current_workspace = self._current_workspaces[dbo_current_ubatch_id()] - - return current_workspace - - -def is_workspace_manager_initialized() -> bool: - """Check if workspace manager has been initialized. - - Returns: - True if workspace manager is initialized, False otherwise. - """ - return _manager is not None - - -def current_workspace_manager() -> "WorkspaceManager": - """Get the current workspace manager instance. - - Raises: - AssertionError: If workspace manager has not been initialized. - """ - assert _manager is not None, ( - "WorkspaceManager not initialized. Call init_workspace_manager() " - "with a device before using workspace functions." - ) - return _manager - - -def init_workspace_manager( - device: torch.device, num_ubatches: int | None = None -) -> None: - """Initialize the workspace manager with a device. - - Must be called before using any workspace functions. Typically called - from GPUModelRunner.__init__. - - Args: - device: The device to allocate workspace on. - num_ubatches: Number of micro-batches. Defaults to 1. - """ - global _manager - if _manager is not None: - logger.warning( - "WorkspaceManager already initialized on device %s, " - "reinitializing on device %s", - _manager._device, - device, - ) - _manager = WorkspaceManager(device, num_ubatches) - - -def lock_workspace() -> None: - """Lock the workspace to prevent further growth. - - After calling this function, any attempt to allocate a workspace larger - than the current size will raise an AssertionError. This ensures that - workspace size is fixed during execution and prevents unexpected memory - allocations in the hot path. - - Example: - # During initialization - init_workspace_manager(device) - reserve_workspace(shape1, dtype1) - reserve_workspace(shape2, dtype2) - - # Lock after warmup/profiling - lock_workspace() - - # Now all get_workspace calls must fit in pre-allocated size - """ - current_workspace_manager().lock() - - -def reset_workspace_manager() -> None: - """Reset the workspace manager to uninitialized state. - - This is primarily intended for testing purposes to allow tests - to reinitialize the workspace manager cleanly. - """ - global _manager - _manager = None