diff --git a/csrc/ops.h b/csrc/ops.h index 37e3aaf7499d..f52ae73066bd 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -301,6 +301,12 @@ void scaled_fp4_experts_quant( torch::Tensor const& input_offset_by_experts, torch::Tensor const& output_scale_offset_by_experts); +void silu_and_mul_scaled_fp4_experts_quant( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts); + void per_token_group_quant_fp8(const torch::Tensor& input, torch::Tensor& output_q, torch::Tensor& output_s, int64_t group_size, double eps, double fp8_min, diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu index e0438556dfe5..2ea229c47d7e 100644 --- a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu +++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -31,37 +31,6 @@ namespace vllm { -// silu in float32 -__device__ __forceinline__ float silu(float x) { - return __fdividef(x, (1.f + __expf(-x))); -} - -__device__ __forceinline__ float2 silu2(float2 x) { - return make_float2(silu(x.x), silu(x.y)); -} - -template -__inline__ __device__ PackedVec compute_silu_mul(PackedVec& vec, - PackedVec& vec2) { - PackedVec result; - using packed_type = typename TypeConverter::Type; - -#pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { - // silu_mul in float32 - if constexpr (std::is_same_v) { - float2 silu_vec = silu2(__half22float2(vec.elts[i])); - result.elts[i] = - __float22half2_rn(__fmul2_rn(silu_vec, __half22float2(vec2.elts[i]))); - } else { - float2 silu_vec = silu2(__bfloat1622float2(vec.elts[i])); - result.elts[i] = __float22bfloat162_rn( - __fmul2_rn(silu_vec, __bfloat1622float2(vec2.elts[i]))); - } - } - return result; -} - // Use UE4M3 by default. template __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu index 20191a9bc616..aa573c007b3d 100644 --- a/csrc/quantization/fp4/nvfp4_experts_quant.cu +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -31,8 +31,12 @@ namespace vllm { +// NVFP4 quantization kernel for experts (low-latency path). +// When FUSE_SILU_MUL=true, expects input with gate||up layout and fuses +// SiLU(gate)*up before quantization. // Use UE4M3 by default. -template +template __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out, uint32_t* SFout, @@ -50,6 +54,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) int tid = blockIdx.x * blockDim.x + threadIdx.x; int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; + // When fusing SiLU+Mul, input has gate || up layout (doubled width) + int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow; // Each global thread processes one element for (int globalIdx = tid; globalIdx < numRows * colsPerRow; @@ -58,13 +64,6 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) int rowIdx = globalIdx / colsPerRow; int colIdx = globalIdx % colsPerRow; - int64_t inOffset = rowIdx * colsPerRow + colIdx; - PackedVec in_vec = reinterpret_cast(in)[inOffset]; - // Get the output tensor offset. - // Same as inOffset because 8 elements are packed into one uint32_t. - int64_t outOffset = inOffset; - auto& out_pos = out[outOffset]; - // Find index within the experts using different strategies based on expert // count int rowIdx_in_expert = 0; @@ -111,6 +110,23 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) } } + // Load input and optionally apply fused SiLU+Mul + int64_t inOffset = rowIdx * inColsPerRow + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + PackedVec quant_input; + if constexpr (FUSE_SILU_MUL) { + PackedVec in_vec_up = + reinterpret_cast(in)[inOffset + colsPerRow]; + quant_input = compute_silu_mul(in_vec, in_vec_up); + } else { + quant_input = in_vec; + } + + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. + int64_t outOffset = rowIdx * colsPerRow + colIdx; + auto& out_pos = out[outOffset]; + // Get the global scaling factor, which will be applied to the SF. // Note SFScale is the same as next GEMM's alpha, which is // (448.f / (Alpha_A / 6.f)). @@ -124,12 +140,16 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) CVT_FP4_NUM_THREADS_PER_SF>( rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); - out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + out_pos = + cvt_warp_fp16_to_fp4(quant_input, SFScaleVal, sf_out); } } -// Kernel for LARGE_M_TOPK = true (large m_topk optimized version) -template +// NVFP4 quantization kernel for LARGE_M_TOPK = true (large m_topk optimized +// version). When FUSE_SILU_MUL=true, expects input with gate||up layout and +// fuses SiLU(gate)*up before quantization. +template __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out, uint32_t* SFout, @@ -167,6 +187,8 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) int tid = blockIdx.x * blockDim.x + threadIdx.x; int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; + // When fusing SiLU+Mul, input has gate || up layout (doubled width) + int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow; // Each global thread processes one element for (int globalIdx = tid; globalIdx < numRows * colsPerRow; @@ -175,11 +197,6 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) int rowIdx = globalIdx / colsPerRow; int colIdx = globalIdx % colsPerRow; - int64_t inOffset = rowIdx * colsPerRow + colIdx; - PackedVec in_vec = reinterpret_cast(in)[inOffset]; - int64_t outOffset = inOffset; - auto& out_pos = out[outOffset]; - // Find expert using binary search for better performance with large m_topk int rowIdx_in_expert = 0; int expert_idx = 0; @@ -204,6 +221,21 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) } } + // Load input and optionally apply fused SiLU+Mul + int64_t inOffset = rowIdx * inColsPerRow + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + PackedVec quant_input; + if constexpr (FUSE_SILU_MUL) { + PackedVec in_vec_up = + reinterpret_cast(in)[inOffset + colsPerRow]; + quant_input = compute_silu_mul(in_vec, in_vec_up); + } else { + quant_input = in_vec; + } + + int64_t outOffset = rowIdx * colsPerRow + colIdx; + auto& out_pos = out[outOffset]; + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; uint32_t* SFout_in_expert = @@ -214,11 +246,12 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) CVT_FP4_NUM_THREADS_PER_SF>( rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); - out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + out_pos = + cvt_warp_fp16_to_fp4(quant_input, SFScaleVal, sf_out); } } -template +template void quant_impl(void* output, void* output_scale, void* input, void* input_global_scale, void* input_offset_by_experts, void* output_scale_offset_by_experts, int m_topk, int k, @@ -246,7 +279,7 @@ void quant_impl(void* output, void* output_scale, void* input, if (blockRepeat > 1) { size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t); if (n_experts >= 4) { - cvt_fp16_to_fp4 + cvt_fp16_to_fp4 <<>>( m_topk, k, reinterpret_cast(input), reinterpret_cast(input_global_scale), @@ -256,34 +289,37 @@ void quant_impl(void* output, void* output_scale, void* input, reinterpret_cast(output_scale_offset_by_experts), n_experts); } else { - cvt_fp16_to_fp4<<>>( - m_topk, k, reinterpret_cast(input), - reinterpret_cast(input_global_scale), - reinterpret_cast(output), - reinterpret_cast(output_scale), - reinterpret_cast(input_offset_by_experts), - reinterpret_cast(output_scale_offset_by_experts), - n_experts); + cvt_fp16_to_fp4 + <<>>( + m_topk, k, reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), + n_experts); } } else { if (n_experts >= 16) { - cvt_fp16_to_fp4<<>>( - m_topk, k, reinterpret_cast(input), - reinterpret_cast(input_global_scale), - reinterpret_cast(output), - reinterpret_cast(output_scale), - reinterpret_cast(input_offset_by_experts), - reinterpret_cast(output_scale_offset_by_experts), - n_experts, /* bool low_latency */ true); + cvt_fp16_to_fp4 + <<>>( + m_topk, k, reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), + n_experts, /* bool low_latency */ true); } else { - cvt_fp16_to_fp4<<>>( - m_topk, k, reinterpret_cast(input), - reinterpret_cast(input_global_scale), - reinterpret_cast(output), - reinterpret_cast(output_scale), - reinterpret_cast(input_offset_by_experts), - reinterpret_cast(output_scale_offset_by_experts), - n_experts, /* bool low_latency */ true); + cvt_fp16_to_fp4 + <<>>( + m_topk, k, reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), + n_experts, /* bool low_latency */ true); } } } @@ -304,19 +340,19 @@ constexpr auto FLOAT = at::ScalarType::Float; constexpr auto INT = at::ScalarType::Int; constexpr auto UINT8 = at::ScalarType::Byte; -void scaled_fp4_experts_quant_sm1xxa( - torch::Tensor& output, torch::Tensor& output_scale, +// Common validation for fp4 experts quantization entry points. +static void validate_fp4_experts_quant_inputs( + torch::Tensor const& output, torch::Tensor const& output_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts) { - CHECK_INPUT(output, "output must be a CUDA tensor"); - CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor"); - CHECK_INPUT(input, "input must be a CUDA tensor"); - CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor"); - CHECK_INPUT(input_offset_by_experts, - "input_offset_by_experts must be a CUDA tensor"); - CHECK_INPUT(output_scale_offset_by_experts, - "output_scale_offset_by_experts must be a CUDA tensor"); + torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk, + int64_t k) { + CHECK_INPUT(output, "output"); + CHECK_INPUT(output_scale, "output_scale"); + CHECK_INPUT(input, "input"); + CHECK_INPUT(input_global_scale, "input_global_scale"); + CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts"); + CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts"); TORCH_CHECK(output.dim() == 2); TORCH_CHECK(output_scale.dim() == 2); @@ -335,8 +371,6 @@ void scaled_fp4_experts_quant_sm1xxa( TORCH_CHECK(output_scale.scalar_type() == INT); const int BLOCK_SIZE = 16; - auto m_topk = input.size(0); - auto k = input.size(1); TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); auto n_experts = input_global_scale.size(0); TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1); @@ -348,7 +382,21 @@ void scaled_fp4_experts_quant_sm1xxa( int padded_k = (scales_k + (4 - 1)) / 4 * 4; // 4 means 4 fp8 values are packed into one int32 TORCH_CHECK(output_scale.size(1) * 4 == padded_k); +} + +void scaled_fp4_experts_quant_sm1xxa( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts) { + auto m_topk = input.size(0); + auto k = input.size(1); + + validate_fp4_experts_quant_inputs(output, output_scale, input, + input_global_scale, input_offset_by_experts, + output_scale_offset_by_experts, m_topk, k); + auto n_experts = input_global_scale.size(0); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device()); @@ -356,7 +404,38 @@ void scaled_fp4_experts_quant_sm1xxa( VLLM_DISPATCH_HALF_TYPES( input.scalar_type(), "nvfp4_experts_quant_kernel", [&] { using cuda_type = vllm::CUDATypeConverter::Type; - vllm::quant_impl( + vllm::quant_impl( + output.data_ptr(), output_scale.data_ptr(), input.data_ptr(), + input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts, + stream); + }); +} + +void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts) { + auto m_topk = input.size(0); + // Input has gate || up layout, so k = input.size(1) / 2 + auto k_times_2 = input.size(1); + TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)"); + auto k = k_times_2 / 2; + + validate_fp4_experts_quant_inputs(output, output_scale, input, + input_global_scale, input_offset_by_experts, + output_scale_offset_by_experts, m_topk, k); + + auto n_experts = input_global_scale.size(0); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = + at::cuda::getCurrentCUDAStream(input.get_device()); + + VLLM_DISPATCH_HALF_TYPES( + input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + vllm::quant_impl( output.data_ptr(), output_scale.data_ptr(), input.data_ptr(), input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(), output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts, diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index fb6d22f035b9..25e0ba8486c7 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -41,6 +41,15 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, torch::Tensor& input_sf); #endif +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts); +#endif + void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) { #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ @@ -74,3 +83,18 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf, TORCH_CHECK_NOT_IMPLEMENTED( false, "No compiled silu_and_mul nvfp4 quantization kernel"); } + +void silu_and_mul_scaled_fp4_experts_quant( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts) { +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + return silu_and_mul_scaled_fp4_experts_quant_sm1xxa( + output, output_scale, input, input_global_scale, input_offset_by_experts, + output_scale_offset_by_experts); +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, "No compiled silu_and_mul nvfp4 experts quantization kernel"); +} diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh index 4c91af85e151..7082ad684bc3 100644 --- a/csrc/quantization/fp4/nvfp4_utils.cuh +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -239,4 +239,34 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, return e2m1Vec; } +// silu in float32 +__device__ __forceinline__ float silu(float x) { + return __fdividef(x, (1.f + __expf(-x))); +} + +__device__ __forceinline__ float2 silu2(float2 x) { + return make_float2(silu(x.x), silu(x.y)); +} + +template +__inline__ __device__ PackedVec compute_silu_mul( + const PackedVec& x_vec, const PackedVec& y_vec) { + PackedVec result; + +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { + // silu_mul in float32 + if constexpr (std::is_same_v) { + float2 silu_vec = silu2(__half22float2(x_vec.elts[i])); + result.elts[i] = __float22half2_rn( + __fmul2_rn(silu_vec, __half22float2(y_vec.elts[i]))); + } else { + float2 silu_vec = silu2(__bfloat1622float2(x_vec.elts[i])); + result.elts[i] = __float22bfloat162_rn( + __fmul2_rn(silu_vec, __bfloat1622float2(y_vec.elts[i]))); + } + } + return result; +} + } // namespace vllm diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 6f2c8e915b5c..fb29e0cd8b75 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -558,6 +558,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor output_scale_offset_by_experts) -> ()"); ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant); + // Fused SiLU+Mul+NVFP4 experts quantization. + ops.def( + "silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! " + "output_scale," + "Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts," + "Tensor output_scale_offset_by_experts) -> ()"); + ops.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA, + &silu_and_mul_scaled_fp4_experts_quant); + // Check if cutlass_scaled_mm_fp4 is supported for CUDA devices // of the given capability ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool"); diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0d6d545fed51..0d18cab2c960 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1606,15 +1606,15 @@ def scaled_fp4_experts_quant( topk: int, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Quantize input tensor to FP4 and return quantized tensor and scale, for + Quantize input tensor to NVFP4 and return quantized tensor and scale, for packed MoE Inputs. Args: - input_tensor: The input tensor to be quantized to FP4 + input_tensor: The input tensor to be quantized to NVFP4 input_global_scale: A scalar scaling factor for the entire tensor. expert_offsets: The expert offsets tensor blockscale_offsets: The blockscale offsets tensor Outputs: - output: The quantized tensor in FP4 + output: The quantized tensor in NVFP4 output_scales: The blockscale tensor in FP8-E4M3 """ assert not current_platform.is_rocm() @@ -1660,6 +1660,71 @@ def scaled_fp4_experts_quant( return output, output_scales +def silu_and_mul_scaled_fp4_experts_quant( + input_tensor: torch.Tensor, + input_global_scale: torch.Tensor, + expert_offsets: torch.Tensor, + blockscale_offsets: torch.Tensor, + topk: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Fused SiLU+Mul+NVFP4 quantization for MoE intermediate activations. + + Args: + input_tensor: The input tensor with gate || up layout [m_topk, k*2] + input_global_scale: A per-expert scaling factor [n_experts] + expert_offsets: The expert offsets tensor [n_experts+1] + blockscale_offsets: The blockscale offsets tensor [n_experts+1] + topk: Number of top-k experts selected + Outputs: + output: The quantized tensor in NVFP4 [m_topk, k/2] + output_scales: The blockscale tensor in FP8-E4M3 + """ + assert not current_platform.is_rocm() + assert input_tensor.ndim == 2, ( + f"input.ndim needs to be == 2, but got {input_tensor.ndim}." + ) + + # Control the maximum number of tokens per expert supported by the + # NVFP4 MoE Expert Quantization. This is used to prevent the kernel + # from running out of memory. This value can also be increased to support + # larger models. + MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE + m_numtopk, k_times_2 = input_tensor.shape + assert k_times_2 % 2 == 0, "input width must be even (gate || up layout)" + k = k_times_2 // 2 + + assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, ( + f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT(" + f"{MAX_TOKENS_PER_EXPERT})" + f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use" + f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value." + ) + scales_k = k // 16 + padded_k = (scales_k + (4 - 1)) // 4 + + # output is uint8 and packed fp4 values + output = torch.empty( + m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8 + ) + output_scales = torch.empty( + MAX_TOKENS_PER_EXPERT * topk, + padded_k, + dtype=torch.int32, + device=input_tensor.device, + ) + torch.ops._C.silu_and_mul_scaled_fp4_experts_quant( + output, + output_scales, + input_tensor, + input_global_scale, + expert_offsets, + blockscale_offsets, + ) + output_scales = output_scales.view(torch.float8_e4m3fn) + return output, output_scales + + # fp8 def scaled_fp8_quant( input: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 32ea040c743c..d46a568f343f 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -549,7 +549,8 @@ def run_cutlass_moe_fp4( num_topk, ) c1 = _resize_cache(workspace13, (m * topk, n * 2)) - c2 = _resize_cache(workspace2, (m * topk, n)) + # Note: c2 workspace is no longer needed since SiLU is fused with quantization. + # c3 reuses workspace13 after c1 is consumed. c3 = _resize_cache(workspace13, (m * topk, k)) ops.cutlass_fp4_moe_mm( c1, @@ -563,9 +564,9 @@ def run_cutlass_moe_fp4( blockscale_offsets[:-1], ) del rep_a_fp4, rep_a_blockscale - torch.ops._C.silu_and_mul(c2, c1) - int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( - c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk + # Fused SiLU+Mul+NVFP4 quantization + int_fp4, int_blockscale = ops.silu_and_mul_scaled_fp4_experts_quant( + c1, a2_gscale, expert_offsets, blockscale_offsets, num_topk ) ops.cutlass_fp4_moe_mm(