diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 758a77795553..303433392c32 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -11,29 +11,74 @@ namespace vllm {
 
 template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
-          bool act_first>
+          bool act_first, bool HAS_CLAMP>
 __device__ __forceinline__ scalar_t compute(const scalar_t& x,
-                                            const scalar_t& y) {
-  return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
+                                            const scalar_t& y,
+                                            const float limit) {
+  if constexpr (act_first) {
+    scalar_t gate = x;
+    scalar_t up = y;
+    if constexpr (HAS_CLAMP) {
+      gate = (scalar_t)fminf((float)gate, limit);
+      up = (scalar_t)fmaxf(fminf((float)up, limit), -limit);
+    }
+    return ACT_FN(gate) * up;
+  } else {
+    scalar_t gate = x;
+    scalar_t up = y;
+    if constexpr (HAS_CLAMP) {
+      gate = (scalar_t)fmaxf(fminf((float)gate, limit), -limit);
+      up = (scalar_t)fminf((float)up, limit);
+    }
+    return gate * ACT_FN(up);
+  }
 }
 
 template <typename scalar_t, typename packed_t,
           scalar_t (*ACT_FN)(const scalar_t&),
           packed_t (*PACKED_ACT_FN)(const packed_t&),
-          bool act_first>
+          bool act_first, bool HAS_CLAMP>
 __device__ __forceinline__ packed_t packed_compute(const packed_t& x,
-                                                   const packed_t& y) {
-  return act_first ? packed_mul(PACKED_ACT_FN(x), y)
-                   : packed_mul(x, PACKED_ACT_FN(y));
+                                                   const packed_t& y,
+                                                   const float limit) {
+  if constexpr (act_first) {
+    packed_t gate = x;
+    packed_t up = y;
+    if constexpr (HAS_CLAMP) {
+      float2 g = cast_to_float2(gate);
+      float2 u = cast_to_float2(up);
+      g.x = fminf(g.x, limit);
+      g.y = fminf(g.y, limit);
+      u.x = fmaxf(fminf(u.x, limit), -limit);
+      u.y = fmaxf(fminf(u.y, limit), -limit);
+      gate = cast_to_packed(g);
+      up = cast_to_packed(u);
+    }
+    return packed_mul(PACKED_ACT_FN(gate), up);
+  } else {
+    packed_t gate = x;
+    packed_t up = y;
+    if constexpr (HAS_CLAMP) {
+      float2 g = cast_to_float2(gate);
+      float2 u = cast_to_float2(up);
+      g.x = fmaxf(fminf(g.x, limit), -limit);
+      g.y = fmaxf(fminf(g.y, limit), -limit);
+      u.x = fminf(u.x, limit);
+      u.y = fminf(u.y, limit);
+      gate = cast_to_packed(g);
+      up = cast_to_packed(u);
+    }
+    return packed_mul(gate, PACKED_ACT_FN(up));
+  }
 }
 
 // Activation and gating kernel template.
 template <typename scalar_t, typename packed_t,
           scalar_t (*ACT_FN)(const scalar_t&),
           packed_t (*PACKED_ACT_FN)(const packed_t&), bool act_first,
-          bool use_vec, bool use_256b = false>
+          bool use_vec, bool HAS_CLAMP, bool use_256b = false>
 __global__ void act_and_mul_kernel(
     scalar_t* __restrict__ out,          // [..., d]
     const scalar_t* __restrict__ input,  // [..., 2, d]
-    const int d) {
+    const int d, const float limit) {
   const scalar_t* x_ptr = input + blockIdx.x * 2 * d;
   const scalar_t* y_ptr = x_ptr + d;
   scalar_t* out_ptr = out + blockIdx.x * d;
@@ -58,8 +103,9 @@ __global__ void act_and_mul_kernel(
       }
 #pragma unroll
       for (int j = 0; j < pvec_t::NUM_ELTS; j++) {
-        x.elts[j] = packed_compute<scalar_t, packed_t, ACT_FN, PACKED_ACT_FN,
-                                   act_first>(x.elts[j], y.elts[j]);
+        x.elts[j] =
+            packed_compute<scalar_t, packed_t, ACT_FN, PACKED_ACT_FN,
+                           act_first, HAS_CLAMP>(x.elts[j], y.elts[j], limit);
       }
       if constexpr (use_256b) {
         st256(x, &out_vec[i]);
@@ -72,7 +118,8 @@
     for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
       const scalar_t x = VLLM_LDG(&x_ptr[idx]);
       const scalar_t y = VLLM_LDG(&y_ptr[idx]);
-      out_ptr[idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
+      out_ptr[idx] =
+          compute<scalar_t, ACT_FN, act_first, HAS_CLAMP>(x, y, limit);
     }
   }
 }
@@ -151,8 +198,11 @@ packed_gelu_tanh_kernel(const packed_t& val) {
 
 // Launch activation and gating kernel.
 // Use ACT_FIRST (bool) indicating whether to apply the activation function
-// first.
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST)       \
+// first. HAS_CLAMP (bool) enables pre-activation clamping: gate input is
+// clamped (max only) and up input is clamped (both sides) before the
+// activation function is applied.
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST,       \
+                                      HAS_CLAMP, LIMIT)                       \
   auto dtype = input.scalar_type();                                           \
   int d = input.size(-1) / 2;                                                 \
   int64_t num_tokens = input.numel() / input.size(-1);                        \
@@ -177,8 +227,8 @@ packed_gelu_tanh_kernel(const packed_t& val) {
           scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type,       \
           KERNEL<scalar_t>,                                                   \
           PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>,  \
-          ACT_FIRST, true, true><<<grid, block, 0, stream>>>(                 \
-          out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);           \
+          ACT_FIRST, true, HAS_CLAMP, true><<<grid, block, 0, stream>>>(      \
+          out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT);    \
     });                                                                       \
   } else {                                                                    \
     VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] {           \
@@ -186,8 +236,8 @@ packed_gelu_tanh_kernel(const packed_t& val) {
           scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type,       \
           KERNEL<scalar_t>,                                                   \
           PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>,  \
-          ACT_FIRST, true, false><<<grid, block, 0, stream>>>(                \
-          out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);           \
+          ACT_FIRST, true, HAS_CLAMP, false><<<grid, block, 0, stream>>>(     \
+          out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT);    \
     });                                                                       \
   }                                                                           \
   } else {                                                                    \
@@ -197,8 +247,8 @@ packed_gelu_tanh_kernel(const packed_t& val) {
           scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type,       \
           KERNEL<scalar_t>,                                                   \
           PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>,  \
-          ACT_FIRST, false><<<grid, block, 0, stream>>>(                      \
-          out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);           \
+          ACT_FIRST, false, HAS_CLAMP><<<grid, block, 0, stream>>>(           \
+          out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT);    \
     });                                                                       \
   }
@@ -206,7 +256,14 @@ void silu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
-                                true);
+                                true, false, 0.0f);
+}
+
+void silu_and_mul_clamp(torch::Tensor& out,    // [..., d]
+                        torch::Tensor& input,  // [..., 2 * d]
+                        double limit) {
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
+                                true, true, (float)limit);
 }
 
 void mul_and_silu(torch::Tensor& out,    // [..., d]
@@ -215,21 +272,21 @@ void mul_and_silu(torch::Tensor& out,    // [..., d]
   // The difference between mul_and_silu and silu_and_mul is that mul_and_silu
   // applies the silu to the latter half of the input.
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
-                                false);
+                                false, false, 0.0f);
 }
 
 void gelu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, vllm::packed_gelu_kernel,
-                                true);
+                                true, false, 0.0f);
 }
 
 void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
                        torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel,
-                                vllm::packed_gelu_tanh_kernel, true);
+  LAUNCH_ACTIVATION_GATE_KERNEL(
+      vllm::gelu_tanh_kernel, vllm::packed_gelu_tanh_kernel, true, false, 0.0f);
 }
 
 namespace vllm {
diff --git a/csrc/ops.h b/csrc/ops.h
index 821c505b3a02..16a78f570cf6 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -163,6 +163,8 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
 
 void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
+void silu_and_mul_clamp(torch::Tensor& out, torch::Tensor& input, double limit);
+
 void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
                         torch::Tensor& scale);
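For readers who do not want to trace the launcher macro, the act_first=true HAS_CLAMP path above reduces to the following PyTorch-level sketch. It is illustrative only: silu_and_mul_clamp_ref is a hypothetical name, and the in-tree reference implementation is SiluAndMulWithClamp.forward_native added later in this diff.

    import torch
    import torch.nn.functional as F

    def silu_and_mul_clamp_ref(x: torch.Tensor, limit: float) -> torch.Tensor:
        # The gate half (the operand silu is applied to) is clamped from above
        # only; the up half is clamped on both sides. The kernel does the same
        # clamps with fminf/fmaxf in float32 before casting back to the
        # storage dtype.
        d = x.shape[-1] // 2
        gate = x[..., :d].clamp(max=limit)
        up = x[..., d:].clamp(min=-limit, max=limit)
        return F.silu(gate) * up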
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index ab53f56d142b..8d8f7bed0441 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -106,6 +106,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
   ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
 
+  // SwiGLU activation with input clamping.
+  ops.def(
+      "silu_and_mul_with_clamp(Tensor! result, Tensor input, float limit) "
+      "-> ()");
+  ops.impl("silu_and_mul_with_clamp", torch::kCUDA, &silu_and_mul_clamp);
+
   ops.def(
       "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
   ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
diff --git a/tests/kernels/core/test_activation.py b/tests/kernels/core/test_activation.py
index e7de7731286f..3f1d45ba8e9e 100644
--- a/tests/kernels/core/test_activation.py
+++ b/tests/kernels/core/test_activation.py
@@ -16,6 +16,7 @@
     NewGELU,
     QuickGELU,
     SiluAndMul,
+    SiluAndMulWithClamp,
     SwigluOAIAndMul,
     SwigluStepAndMul,
     swiglustep_and_mul_triton,
@@ -116,6 +117,85 @@ def _get_rtol(output) -> float:
     opcheck(fn, (out, x))
 
 
+SWIGLU_LIMITS = [3.0, 7.0, 15.0]
+
+
+@pytest.mark.parametrize("swiglu_limit", SWIGLU_LIMITS)
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_silu_and_mul_with_clamp(
+    default_vllm_config,
+    swiglu_limit: float,
+    num_tokens: int,
+    d: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+) -> None:
+    """SiluAndMulWithClamp: the CUDA kernel must match the native reference."""
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    # Use large values to ensure clamping is exercised.
+    x = torch.randn(num_tokens, 2 * d, dtype=dtype) * swiglu_limit * 2
+
+    layer = SiluAndMulWithClamp(swiglu_limit, compile_native=False)
+    out = layer(x)
+    ref_out = layer.forward_native(x)
+
+    rtol = {
+        torch.float16: 2e-3,
+        torch.bfloat16: 2e-2,
+        torch.float: 1.3e-6,
+    }
+    torch.testing.assert_close(
+        out, ref_out, atol=get_default_atol(out), rtol=rtol[out.dtype]
+    )
+
+    # Verify clamping is actually being applied: the clamped output should
+    # differ from the unclamped SiluAndMul output when inputs are large.
+    unclamped_out = SiluAndMul(compile_native=False).forward_native(x)
+    assert not torch.equal(ref_out.float(), unclamped_out.float()), (
+        "Input was not large enough to exercise the clamp; increase the scale"
+    )
+
+    # Verify gate clamping semantics with a controlled scalar case:
+    # gate = large_val is clamped to limit first, then silu(limit) * 1.0.
+    x_gate = torch.tensor(
+        [[swiglu_limit * 20.0, 1.0]], dtype=torch.float32, device=device
+    )
+    out_gate = SiluAndMulWithClamp(swiglu_limit, compile_native=False)(x_gate)
+    expected_gate = torch.nn.functional.silu(
+        torch.tensor(swiglu_limit, dtype=torch.float32)
+    ).item()
+    torch.testing.assert_close(
+        out_gate,
+        torch.tensor([[expected_gate]], dtype=torch.float32, device=device),
+        atol=1e-3,
+        rtol=1e-3,
+    )
+
+    # Verify up clamping semantics: up >> limit gets clamped to limit.
+    x_up = torch.tensor(
+        [[1.0, swiglu_limit * 20.0]], dtype=torch.float32, device=device
+    )
+    out_up = SiluAndMulWithClamp(swiglu_limit, compile_native=False)(x_up)
+    silu_1 = torch.nn.functional.silu(torch.tensor(1.0)).item()
+    torch.testing.assert_close(
+        out_up,
+        torch.tensor([[silu_1 * swiglu_limit]], dtype=torch.float32, device=device),
+        atol=1e-3,
+        rtol=1e-3,
+    )
+
+    # opcheck
+    out_buf = torch.empty(x.shape[:-1] + (d,), dtype=dtype, device=device)
+    opcheck(torch.ops._C.silu_and_mul_with_clamp, (out_buf, x, swiglu_limit))
+
+
 @pytest.mark.parametrize(
     "activation",
     [
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index e2b70b771a12..59cc95f18c58 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -151,6 +151,46 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         return self.forward_cuda(x)
 
 
+@CustomOp.register("silu_and_mul_with_clamp")
+class SiluAndMulWithClamp(CustomOp):
+    """SwiGLU activation with input clamping (used by some MoE shared experts).
+
+    Computes:
+        gate = clamp(x[..., :d], max=swiglu_limit)
+        up = clamp(x[..., d:], min=-swiglu_limit, max=swiglu_limit)
+        out = silu(gate) * up
+    where d = x.shape[-1] // 2.
+
+    Shapes:
+        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
+    """
+
+    def __init__(self, swiglu_limit: float, *, compile_native: bool = True):
+        super().__init__(compile_native=compile_native)
+        self.swiglu_limit = float(swiglu_limit)
+        if current_platform.is_cuda_alike() or current_platform.is_xpu():
+            self.op = torch.ops._C.silu_and_mul_with_clamp
+        elif current_platform.is_cpu():
+            self._forward_method = self.forward_native
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        gate = torch.clamp(x[..., :d], max=self.swiglu_limit)
+        up = torch.clamp(x[..., d:], min=-self.swiglu_limit, max=self.swiglu_limit)
+        return F.silu(gate) * up
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        self.op(out, x, self.swiglu_limit)
+        return out
+
+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        return self.forward_cuda(x)
+
+
 # --8<-- [start:mul_and_silu]
 @CustomOp.register("mul_and_silu")
 class MulAndSilu(CustomOp):
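A quick usage sketch for the new layer next to plain SiluAndMul follows; the values are made up for illustration, and constructing the layers assumes an active vLLM config context (as provided by the test's default_vllm_config fixture).

    import torch
    from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClamp

    x = torch.tensor([[30.0, 2.0, 30.0, 50.0]])  # [gate0, gate1, up0, up1], d = 2

    clamped = SiluAndMulWithClamp(swiglu_limit=7.0, compile_native=False)
    plain = SiluAndMul(compile_native=False)

    # Clamped: gate is capped at 7.0 before silu, up is clamped to [-7, 7], so
    # out ~= [silu(7) * 7, silu(2) * 7] ~= [48.9, 12.3].
    print(clamped.forward_native(x))
    # Unclamped: out ~= [silu(30) * 30, silu(2) * 50] ~= [900, 88].
    print(plain.forward_native(x))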
diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
index e1bedd6f45be..985f33e10098 100644
--- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -45,7 +45,7 @@ def _gelu_and_mul(
 
 # Uses static methods or standalone functions to avoid instantiating CustomOp
 # classes, which would call get_current_vllm_config() before config is set.
 _CPU_MOE_ACT_FN: dict[MoEActivation, Callable[[torch.Tensor], torch.Tensor]] = {
-    MoEActivation.SILU: SiluAndMul.forward_native,
+    MoEActivation.SILU: lambda x: SiluAndMul(compile_native=False).forward_native(x),
     MoEActivation.SWIGLUOAI: _swigluoai_forward_native,
     MoEActivation.GELU: _gelu_and_mul,
 }
diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py
index 4aba5771b194..7733252804b7 100644
--- a/vllm/model_executor/models/deepseek_v4.py
+++ b/vllm/model_executor/models/deepseek_v4.py
@@ -17,6 +17,7 @@
     get_tensor_model_parallel_world_size,
 )
 from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClamp
 from vllm.model_executor.layers.deepseek_v4_attention import (
     DeepseekV4Indexer,
     DeepseekV4MLAModules,
@@ -34,7 +35,10 @@
     RowParallelLinear,
 )
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization import QuantizationMethods
+from vllm.model_executor.layers.quantization import (
+    QuantizationConfig,
+    QuantizationMethods,
+)
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4MoEMethod
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -46,7 +50,6 @@
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.deepseek_v2 import DeepseekV2MLP
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
@@ -63,6 +66,57 @@
 )
 
 
+class DeepseekV4MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        swiglu_limit: float | None = None,
+        quant_config: QuantizationConfig | None = None,
+        reduce_results: bool = True,
+        is_sequence_parallel: bool = False,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        # If is_sequence_parallel, the input and output tensors are sharded
+        # across the ranks within the tp_group. In this case the weights are
+        # replicated and no collective ops are needed.
+        # Otherwise we use standard TP with an allreduce at the end.
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            disable_tp=is_sequence_parallel,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            reduce_results=reduce_results,
+            disable_tp=is_sequence_parallel,
+            prefix=f"{prefix}.down_proj",
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
+            )
+        if swiglu_limit is not None:
+            self.act_fn = SiluAndMulWithClamp(swiglu_limit)
+        else:
+            self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
 class DeepseekV4FP8Config(Fp8Config):
     """FP8 config that routes MoE layers to MXFP4 quantization.
@@ -672,10 +726,11 @@ def __init__(
         else:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
 
-        self.shared_experts = DeepseekV2MLP(
+        self.shared_experts = DeepseekV4MLP(
             hidden_size=config.hidden_size,
             intermediate_size=intermediate_size,
             hidden_act=config.hidden_act,
+            swiglu_limit=self.swiglu_limit,
             quant_config=quant_config,
             reduce_results=self.use_mega_moe,
             prefix=f"{prefix}.shared_experts",
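For completeness, the registered op can also be exercised directly through the torch op library, mirroring SiluAndMulWithClamp.forward_cuda and the opcheck call in the test. This is a sketch that assumes a CUDA device and that vLLM's compiled extension is importable as vllm._C.

    import torch
    import vllm._C  # noqa: F401  (loads the _C op library; assumed import path)

    x = torch.randn(8, 2 * 128, device="cuda", dtype=torch.float16)
    d = x.shape[-1] // 2
    out = torch.empty(x.shape[:-1] + (d,), dtype=x.dtype, device=x.device)

    # Schema: silu_and_mul_with_clamp(Tensor! result, Tensor input, float limit) -> ()
    torch.ops._C.silu_and_mul_with_clamp(out, x, 7.0)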