12 changes: 1 addition & 11 deletions csrc/cache.h
@@ -1,7 +1,6 @@
#pragma once

#include <torch/all.h>
#include <c10/util/Optional.h>

#include <map>
#include <vector>
@@ -59,15 +58,6 @@ void cp_gather_cache(
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);

// Gather and upconvert FP8 KV cache to BF16 workspace
void cp_gather_and_upconvert_fp8_kv_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
torch::Tensor const& dst, // [TOT_TOKENS, 576]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& seq_lens, // [BATCH]
torch::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size);

// Indexer K quantization and cache function
void indexer_k_quant_and_cache(
torch::Tensor& k, // [num_tokens, head_dim]
@@ -82,4 +72,4 @@ void cp_gather_indexer_k_quant_cache(
torch::Tensor& dst_k, // [num_tokens, head_dim]
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
const torch::Tensor& block_table, // [batch_size, num_blocks]
const torch::Tensor& cu_seq_lens); // [batch_size + 1]
const torch::Tensor& cu_seq_lens); // [batch_size + 1]
131 changes: 1 addition & 130 deletions csrc/cache_kernels.cu
@@ -2,7 +2,6 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/Optional.h>

#include "cuda_utils.h"
#include "cuda_compat.h"
@@ -515,8 +514,7 @@ __global__ void indexer_k_quant_and_cache_kernel(
const int quant_block_size, // quantization block size
const int cache_block_size, // cache block size
const int cache_stride, // stride for each token in kv_cache

const bool use_ue8m0 // use ue8m0 scale format
const bool use_ue8m0 // use ue8m0 scale format
) {
constexpr int VEC_SIZE = 4;
const int64_t token_idx = blockIdx.x;
@@ -1063,82 +1061,6 @@ void gather_and_maybe_dequant_cache(
}

namespace vllm {

// Gather and upconvert FP8 KV cache tokens to BF16 workspace
// Similar to cp_gather_cache but specifically for FP8->BF16 conversion
__global__ void cp_gather_and_upconvert_fp8_kv_cache(
const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
__nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
const int32_t* __restrict__ seq_lens, // [BATCH]
const int32_t* __restrict__ workspace_starts, // [BATCH]
const int32_t block_size, const int32_t head_dim,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride) {
const int64_t bid = blockIdx.x; // Batch ID
const int32_t num_splits = gridDim.y;
const int32_t split = blockIdx.y;
const int32_t seq_start = workspace_starts[bid];
const int32_t seq_len = seq_lens[bid];
const int32_t tot_slots = seq_len;
const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);

const int32_t split_start = split * split_slots;
const int32_t split_end = min((split + 1) * split_slots, tot_slots);

const bool is_active_split = (split_start < tot_slots);

if (!is_active_split) return;

// Adjust the pointer for the block_table for this batch
const int32_t batch_offset = bid * block_table_stride;
int32_t offset = split_start;
int32_t offset_div = offset / block_size;
offset = offset % block_size;
const int32_t* batch_block_table = block_table + batch_offset;

// Adjust dst pointer based on the cumulative sequence lengths
dst += seq_start * dst_entry_stride;

const int tid = threadIdx.x;

// Process each token in this split
for (int pid = split_start; pid < split_end; ++pid) {
auto block_id = batch_block_table[offset_div];
const uint8_t* token_ptr =
src_cache + block_id * cache_block_stride + offset * cache_entry_stride;
__nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride;

// FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16)
const uint8_t* no_pe_ptr = token_ptr;
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
const __nv_bfloat16* rope_ptr =
reinterpret_cast<const __nv_bfloat16*>(token_ptr + 512 + 16);

// Parallelize fp8 dequant (512 elements) and rope copy (64 elements)
if (tid < 512) {
// FP8 dequantization
const int tile = tid >> 7; // each tile is 128 elements
const float scale = scales_ptr[tile];
const uint8_t val = no_pe_ptr[tid];
dst_ptr[tid] =
fp8::scaled_convert<__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale);
} else if (tid < 576) {
// Rope copy (64 bf16 elements)
const int rope_idx = tid - 512;
dst_ptr[512 + rope_idx] = rope_ptr[rope_idx];
}

// Move to next token
offset += 1;
if (offset == block_size) {
offset_div += 1;
offset = 0;
}
}
}

template <typename scalar_t>
// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by
// block_size.
@@ -1280,57 +1202,6 @@ void cp_gather_cache(
}
}

void cp_gather_and_upconvert_fp8_kv_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
torch::Tensor const& dst, // [TOT_TOKENS, 576]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& seq_lens, // [BATCH]
torch::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

int32_t block_size = src_cache.size(1);
int32_t head_dim = dst.size(1);

TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32");
TORCH_CHECK(workspace_starts.dtype() == torch::kInt32,
"workspace_starts must be int32");

TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
TORCH_CHECK(src_cache.device() == block_table.device(),
"src_cache and block_table must be on the same device");
TORCH_CHECK(src_cache.device() == seq_lens.device(),
"src_cache and seq_lens must be on the same device");
TORCH_CHECK(src_cache.device() == workspace_starts.device(),
"src_cache and workspace_starts must be on the same device");

TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8");
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");

int64_t block_table_stride = block_table.stride(0);
int64_t cache_block_stride = src_cache.stride(0);
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);

// Decide on the number of splits based on the batch size
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
dim3 grid(batch_size, num_splits);
dim3 block(576);

vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>(
src_cache.data_ptr<uint8_t>(),
reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(),
workspace_starts.data_ptr<int32_t>(), block_size, head_dim,
block_table_stride, cache_block_stride, cache_entry_stride,
dst_entry_stride);
}

// Macro to dispatch the kernel based on the data type.
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
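For reference while reviewing the removal above, the following is a minimal pure-PyTorch sketch of the semantics of the deleted cp_gather_and_upconvert_fp8_kv_cache kernel. It assumes the 656-byte token layout documented in the kernel's comments (512 FP8-E4M3 bytes, four float32 scales covering 128 elements each, then 64 bf16 rope values) and a contiguous src_cache; it is illustrative only, far slower than the CUDA kernel, and not part of this PR.

import torch

def gather_and_upconvert_fp8_ref(
    src_cache: torch.Tensor,         # uint8 [NUM_BLOCKS, BLOCK_SIZE, 656], contiguous
    dst: torch.Tensor,               # bfloat16 [TOT_TOKENS, 576], written in place
    block_table: torch.Tensor,       # int32 [BATCH, BLOCK_INDICES]
    seq_lens: torch.Tensor,          # int32 [BATCH]
    workspace_starts: torch.Tensor,  # int32 [BATCH], start row of each sequence in dst
) -> None:
    block_size = src_cache.size(1)
    for b in range(seq_lens.numel()):
        start = int(workspace_starts[b])
        for pos in range(int(seq_lens[b])):
            block_id = int(block_table[b, pos // block_size])
            entry = src_cache[block_id, pos % block_size]  # one 656-byte token entry
            # Bytes 0..511: 512 FP8-E4M3 values; 512..527: four float32 scales,
            # one per 128-element tile; 528..655: 64 bf16 rope values.
            fp8_vals = entry[:512].view(torch.float8_e4m3fn).to(torch.float32)
            scales = entry[512:528].view(torch.float32)
            rope = entry[528:].view(torch.bfloat16)
            dequant = fp8_vals * scales.repeat_interleave(128)
            dst[start + pos, :512] = dequant.to(torch.bfloat16)
            dst[start + pos, 512:] = rope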
7 changes: 0 additions & 7 deletions csrc/torch_bindings.cpp
@@ -754,13 +754,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);

cache_ops.def(
"cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
"Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
"batch_size) -> ()");
cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA,
&cp_gather_and_upconvert_fp8_kv_cache);

cache_ops.def(
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, "
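The schema removed above takes five tensors plus an int and mutates dst in place (the Tensor! annotation). As a hedged illustration of how that schema surfaced in Python, assuming the extension name resolves to _C_cache_ops as for vLLM's other cache ops (the operator no longer exists after this change):

import torch

# Hypothetical call site; this operator is deleted by the PR.
torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache(
    src_cache,         # uint8 [NUM_BLOCKS, BLOCK_SIZE, 656]
    dst,               # bfloat16 [TOT_TOKENS, 576], mutated in place (Tensor!)
    block_table,       # int32 [BATCH, BLOCK_INDICES]
    seq_lens,          # int32 [BATCH]
    workspace_starts,  # int32 [BATCH]
    batch_size,        # Python int
)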
21 changes: 0 additions & 21 deletions tests/conftest.py
@@ -202,27 +202,6 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
cleanup_dist_env_and_memory()


@pytest.fixture
def workspace_init():
"""Initialize the workspace manager for tests that need it.

This fixture initializes the workspace manager with a CUDA device
if available, and resets it after the test completes. Tests that
create a full vLLM engine should NOT use this fixture as the engine
will initialize the workspace manager itself.
"""
from vllm.v1.worker.workspace import (
init_workspace_manager,
reset_workspace_manager,
)

if torch.cuda.is_available():
device = torch.device("cuda:0")
init_workspace_manager(device)
yield
reset_workspace_manager()


@pytest.fixture(autouse=True)
def dynamo_reset():
yield
2 changes: 1 addition & 1 deletion tests/kernels/moe/test_batched_deepgemm.py
@@ -27,7 +27,7 @@
@pytest.mark.parametrize("N", [512, 1024]) # intermediate dim per expert
@pytest.mark.parametrize("topk", [2, 4])
def test_batched_deepgemm_vs_triton(
E: int, T: int, K: int, N: int, topk: int, monkeypatch, workspace_init
E: int, T: int, K: int, N: int, topk: int, monkeypatch
):
"""Compare BatchedDeepGemmExperts to BatchedTritonExperts."""

1 change: 0 additions & 1 deletion tests/kernels/moe/test_batched_moe.py
@@ -248,7 +248,6 @@ def test_fused_moe_batched_experts(
per_act_token_quant: bool,
block_shape: list[int] | None,
input_scales: bool,
workspace_init,
):
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
and those tests will be skipped on unsupported hardware."""
2 changes: 1 addition & 1 deletion tests/kernels/moe/test_block_fp8.py
@@ -137,7 +137,7 @@ def setup_cuda():
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(
M, N, K, E, topk, block_size, dtype, seed, monkeypatch, workspace_init
M, N, K, E, topk, block_size, dtype, seed, monkeypatch
):
if topk > E:
pytest.skip(f"Skipping test; topk={topk} > E={E}")
27 changes: 2 additions & 25 deletions tests/kernels/moe/test_cutlass_moe.py
@@ -274,7 +274,6 @@ def test_cutlass_moe_8_bit_no_graph(
per_act_token: bool,
per_out_ch: bool,
monkeypatch,
workspace_init,
ep_size: int | None = None,
):
current_platform.seed_everything(7)
@@ -330,7 +329,6 @@ def test_cutlass_moe_8_bit_cuda_graph(
per_act_token: bool,
per_out_ch: bool,
monkeypatch,
workspace_init,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
@@ -387,19 +385,9 @@ def test_cutlass_moe_8_bit_EP(
per_out_channel: bool,
ep_size: int,
monkeypatch,
workspace_init,
):
test_cutlass_moe_8_bit_no_graph(
m,
n,
k,
e,
topk,
per_act_token,
per_out_channel,
monkeypatch,
workspace_init,
ep_size,
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
)


@@ -431,19 +419,9 @@ def test_cutlass_moe_8_bit_EP_large(
per_out_channel: bool,
ep_size: int,
monkeypatch,
workspace_init,
):
test_cutlass_moe_8_bit_no_graph(
m,
n,
k,
e,
topk,
per_act_token,
per_out_channel,
monkeypatch,
workspace_init,
ep_size,
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
)


@@ -467,7 +445,6 @@ def test_run_cutlass_moe_fp8(
per_act_token: bool,
per_out_channel: bool,
ep_size: int,
workspace_init,
):
current_platform.seed_everything(7)
with set_current_vllm_config(vllm_config):
6 changes: 0 additions & 6 deletions tests/kernels/moe/test_deepep_deepgemm_moe.py
@@ -29,7 +29,6 @@
is_deep_gemm_supported,
)
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
from vllm.v1.worker.workspace import init_workspace_manager

from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
@@ -364,9 +363,6 @@ def _test_deepep_deepgemm_moe(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
):
device = torch.device(f"cuda:{pgi.local_rank}")
init_workspace_manager(device)

current_platform.seed_everything(pgi.rank)

w1 = w1.to(device=torch.cuda.current_device())
@@ -449,7 +445,6 @@ def test_ht_deepep_deepgemm_moe(
topk: int,
world_dp_size: tuple[int, int],
disable_deepgemm_ue8m0,
workspace_init,
):
"""
Tests for High-Throughput DeepEP + DeepGemm integration.
@@ -523,7 +518,6 @@ def test_ll_deepep_deepgemm_moe(
block_size: list[int],
world_dp_size: tuple[int, int],
disable_deepgemm_ue8m0,
workspace_init,
):
"""
Tests for Low-Latency DeepEP + DeepGemm integration.
6 changes: 0 additions & 6 deletions tests/kernels/moe/test_deepep_moe.py
@@ -22,7 +22,6 @@
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep
from vllm.v1.worker.workspace import init_workspace_manager

from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
@@ -343,9 +342,6 @@ def _deep_ep_moe(
use_fp8_dispatch: bool,
per_act_token_quant: bool,
):
device = torch.device(f"cuda:{pgi.local_rank}")
init_workspace_manager(device)

if not low_latency_mode:
assert not use_fp8_dispatch, (
"FP8 dispatch interface is available only in low-latency mode"
@@ -441,7 +437,6 @@ def test_deep_ep_moe(
topk: int,
world_dp_size: tuple[int, int],
per_act_token_quant: bool,
workspace_init,
):
low_latency_mode = False
use_fp8_dispatch = False
@@ -497,7 +492,6 @@ def test_low_latency_deep_ep_moe(
topk: int,
world_dp_size: tuple[int, int],
use_fp8_dispatch: bool,
workspace_init,
):
low_latency_mode = True
