diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 24341d63fb1f..57ff5059f35b 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -8,7 +8,7 @@ #include "../cuda_compat.h" #include "../dispatch_utils.h" -#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) +#define CEILDIV(x, y) (((x) + (y)-1) / (y)) namespace vllm { namespace moe { @@ -221,7 +221,8 @@ __global__ void moe_sum_kernel( void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad) { + torch::Tensor num_tokens_post_pad, + bool use_global_memory) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // If we have very large number of experts, we can no longer use shared @@ -229,21 +230,20 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, // TODO(simon): the right solution should be calculating the exact right // amount of shared memory and use that. The num_experts >= 256 is just a // temporary solution to unblock Deepseek V3. - if (num_experts >= 256) { + if (use_global_memory) { VLLM_DISPATCH_INTEGRAL_TYPES( topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { // calc needed amount of shared mem for `tokens_cnts` and `cumsum` // tensors const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); - const int32_t mem_tokens_cnts = - ((num_experts + 1) * num_experts) * sizeof(int32_t); - const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t); - // allocate global memory - int32_t* tokens_cnts; - int32_t* cumsum; - cudaMalloc(&tokens_cnts, mem_tokens_cnts); - cudaMalloc(&cumsum, mem_cumsum); + auto options_int = torch::TensorOptions() + .dtype(torch::kInt) + .device(topk_ids.device()); + torch::Tensor token_cnts_buffer = + torch::empty({(num_experts + 1) * num_experts}, options_int); + torch::Tensor cumsum_buffer = + torch::empty({num_experts + 1}, options_int); auto kernel = vllm::moe::moe_align_block_size_global_mem_kernel; @@ -252,9 +252,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, sorted_token_ids.data_ptr(), experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel(), tokens_cnts, cumsum); - cudaFree(tokens_cnts); - cudaFree(cumsum); + topk_ids.numel(), token_cnts_buffer.data_ptr(), + cumsum_buffer.data_ptr()); }); } else { VLLM_DISPATCH_INTEGRAL_TYPES( diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 596cc0aa6c85..055eae551e8f 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -11,4 +11,5 @@ void moe_sum(torch::Tensor& input, torch::Tensor& output); void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad); + torch::Tensor num_tokens_post_pad, + bool use_global_memory); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index f3a558c14ab9..98ac32b1e99b 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -19,7 +19,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "moe_align_block_size(Tensor topk_ids, int num_experts," " int block_size, Tensor! sorted_token_ids," " Tensor! experts_ids," - " Tensor! num_tokens_post_pad) -> ()"); + " Tensor! num_tokens_post_pad," + " bool use_global_memory) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); #ifndef USE_ROCM diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d04cbbc0a9ee..3a3740a321e9 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -914,13 +914,17 @@ def moe_sum(input: torch.Tensor, output: torch.Tensor): torch.ops._moe_C.moe_sum(input, output) -def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, - block_size: int, sorted_token_ids: torch.Tensor, +def moe_align_block_size(topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor, - num_tokens_post_pad: torch.Tensor) -> None: + num_tokens_post_pad: torch.Tensor, + use_global_memory: bool = False) -> None: torch.ops._moe_C.moe_align_block_size(topk_ids, num_experts, block_size, sorted_token_ids, experts_ids, - num_tokens_post_pad) + num_tokens_post_pad, + use_global_memory) def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, diff --git a/vllm/config.py b/vllm/config.py index 59b509d5a961..f35306c3f516 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -604,7 +604,7 @@ def _verify_cuda_graph(self) -> None: self.max_model_len) if (self.hf_config.model_type == 'deepseek_v3' - and not self.enforce_eager): + and self.quantization == "fp8" and not self.enforce_eager): logger.warning("CUDA graph is not supported for Deepseek V3 yet, " "fallback to the eager mode.") self.enforce_eager = True diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 308c1d6ac6db..4db321ee986e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -256,8 +256,10 @@ def moe_align_block_size( num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) + use_global_memory = num_experts >= 256 # for deepseek-v3 ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, - expert_ids, num_tokens_post_pad) + expert_ids, num_tokens_post_pad, + use_global_memory) return sorted_ids, expert_ids, num_tokens_post_pad