diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index 8535186cc1ec..86c62ae7d28b 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -387,6 +387,7 @@ set(VLLM_EXT_SRC
     "csrc/cpu/layernorm.cpp"
     "csrc/cpu/mla_decode.cpp"
     "csrc/cpu/pos_encoding.cpp"
+    "csrc/cpu/mamba_cpu.cpp"
     "csrc/moe/dynamic_4bit_int_moe_cpu.cpp"
     "csrc/cpu/cpu_attn.cpp"
     "csrc/cpu/torch_bindings.cpp")
@@ -422,6 +423,7 @@ if (ENABLE_X86_ISA)
         "csrc/cpu/spec_decode_utils.cpp"
         "csrc/cpu/cpu_attn.cpp"
         "csrc/cpu/dnnl_kernels.cpp"
+        "csrc/cpu/mamba_cpu.cpp"
         "csrc/cpu/torch_bindings.cpp"
         # TODO: Remove these files
         "csrc/cpu/activation.cpp"
@@ -434,6 +436,7 @@ if (ENABLE_X86_ISA)
         "csrc/cpu/utils.cpp"
         "csrc/cpu/spec_decode_utils.cpp"
         "csrc/cpu/cpu_attn.cpp"
+        "csrc/cpu/mamba_cpu.cpp"
         "csrc/cpu/torch_bindings.cpp"
         # TODO: Remove these files
         "csrc/cpu/activation.cpp"
diff --git a/csrc/cpu/mamba_cpu.cpp b/csrc/cpu/mamba_cpu.cpp
new file mode 100644
index 000000000000..716da80c315f
--- /dev/null
+++ b/csrc/cpu/mamba_cpu.cpp
@@ -0,0 +1,171 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+//
+// CPU at::Tensor wrappers for Mamba decode-step kernels defined in
+// mamba_kernels.hpp.
+
+#include "cpu/mamba_kernels.hpp"
+
+#include <torch/all.h>
+#include <optional>
+#include <string>
+
+#include "cpu_types.hpp"
+
+// ---------------------------------------------------------------------------
+// causal_conv1d_update
+// ---------------------------------------------------------------------------
+at::Tensor causal_conv1d_update_cpu_impl(
+    at::Tensor& x,  // modified in-place (re-typed to float32)
+    at::Tensor& conv_state, const at::Tensor& weight,
+    const c10::optional<at::Tensor>& bias,
+    const c10::optional<std::string>& activation,
+    const c10::optional<at::Tensor>& conv_state_indices,
+    const c10::optional<at::Tensor>& query_start_loc, int64_t pad_slot_id) {
+  bool do_silu = false;
+  if (activation.has_value()) {
+    const std::string& act = activation.value();
+    do_silu = (act == "silu" || act == "swish");
+  }
+
+  // The causal conv still runs in float32 for now (minimal overhead compared
+  // to the SSM).
+  at::Tensor x_f32 = x.to(at::kFloat).contiguous();
+  at::Tensor state_f32 = conv_state.to(at::kFloat).contiguous();
+  at::Tensor w_f32 = weight.to(at::kFloat).contiguous();
+  at::Tensor bias_f32;
+  if (bias.has_value() && bias.value().defined()) {
+    bias_f32 = bias.value().to(at::kFloat).contiguous();
+  }
+
+  int64_t batch = x_f32.size(0);
+  int64_t dim = x_f32.size(1);
+  int64_t seqlen = (x_f32.dim() == 3) ? x_f32.size(2) : 1;
+  int64_t width = w_f32.size(1);
+  int64_t state_len = state_f32.size(2);
+
+  at::Tensor out_f32 = at::empty_like(x_f32);
+
+  const int32_t* cache_idx_ptr = nullptr;
+  at::Tensor cache_idx_int;
+  if (conv_state_indices.has_value()) {
+    cache_idx_int = conv_state_indices.value().to(at::kInt).contiguous();
+    cache_idx_ptr = cache_idx_int.data_ptr<int32_t>();
+  }
+
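+  // Fused state roll + depthwise conv: for each (sequence, channel) the
+  // kernel dots the weights with the cached window plus the incoming sample,
+  // applies bias and optional SiLU, then shifts the window left by one to
+  // append the new sample.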
+  mamba_cpu::causal_conv1d_update_kernel(
+      x_f32.data_ptr<float>(), state_f32.data_ptr<float>(),
+      w_f32.data_ptr<float>(),
+      bias_f32.defined() ? bias_f32.data_ptr<float>() : nullptr,
+      out_f32.data_ptr<float>(), cache_idx_ptr,
+      static_cast<int32_t>(pad_slot_id), batch, dim, seqlen, width, state_len,
+      do_silu);
+
+  conv_state.copy_(state_f32.to(conv_state.scalar_type()));
+  return out_f32.to(x.scalar_type());
+}
+
+// ---------------------------------------------------------------------------
+// selective_state_update
+// ---------------------------------------------------------------------------
+void selective_state_update_cpu_impl(
+    at::Tensor& state,    // (nstates, nheads, dim, dstate)
+    const at::Tensor& x,  // (N, nheads, dim)
+    const at::Tensor& dt,
+    const at::Tensor& A,
+    const at::Tensor& B,
+    const at::Tensor& C,
+    const c10::optional<at::Tensor>& D,
+    const c10::optional<at::Tensor>& z,
+    const c10::optional<at::Tensor>& dt_bias,
+    bool dt_softplus,
+    const c10::optional<at::Tensor>& state_batch_indices,
+    const c10::optional<at::Tensor>& dst_state_batch_indices,
+    int64_t null_block_id,
+    at::Tensor& out,
+    const c10::optional<at::Tensor>& num_accepted_tokens,
+    const c10::optional<at::Tensor>& cu_seqlens) {
+  // Use the state's dtype as the primary type to avoid expensive conversions.
+  // The kernel supports mixed types: state_t can be BFloat16 while input_t
+  // matches x.
+  at::ScalarType state_type = state.scalar_type();
+  at::ScalarType input_type = x.scalar_type();
+
+  // Only convert / make contiguous when needed, to minimize overhead.
+  auto ensure_type_and_contiguous =
+      [input_type](const at::Tensor& t) -> at::Tensor {
+    if (t.scalar_type() != input_type) {
+      return t.to(input_type).contiguous();
+    }
+    return t.is_contiguous() ? t : t.contiguous();
+  };
+
+  at::Tensor dt_in = ensure_type_and_contiguous(dt);
+  at::Tensor A_in = ensure_type_and_contiguous(A);
+  at::Tensor B_in = ensure_type_and_contiguous(B);
+  at::Tensor C_in = ensure_type_and_contiguous(C);
+
+  at::Tensor D_in, z_in, dt_bias_in;
+  if (D.has_value() && D.value().defined()) {
+    D_in = ensure_type_and_contiguous(D.value());
+  }
+  if (z.has_value() && z.value().defined()) {
+    z_in = ensure_type_and_contiguous(z.value());
+  }
+  if (dt_bias.has_value() && dt_bias.value().defined()) {
+    dt_bias_in = ensure_type_and_contiguous(dt_bias.value());
+  }
+
+  int64_t nheads = state.size(1);
+  int64_t dim = state.size(2);
+  int64_t dstate = state.size(3);
+  int64_t N = x.size(0);
+  int64_t ngroups = B_in.size(1);
+
+  // Strides
+  int64_t stride_state_n = state.stride(0);
+  int64_t stride_state_h = state.stride(1);
+  int64_t stride_state_d = state.stride(2);
+  int64_t stride_xdt_n = x.stride(0);
+  int64_t stride_xdt_h = x.stride(1);
+  int64_t stride_A_h = A_in.stride(0);
+  int64_t stride_BC_n = B_in.stride(0);
+  int64_t stride_BC_g = B_in.stride(1);
+  int64_t stride_out_n = out.stride(0);
+  int64_t stride_out_h = out.stride(1);
+  int64_t stride_D_h = D_in.defined() ? D_in.stride(0) : 0;
+  int64_t stride_dtbias_h = dt_bias_in.defined() ? dt_bias_in.stride(0) : 0;
+
+  // Optional pointers - extract once.
+  auto get_int32_ptr =
+      [](const c10::optional<at::Tensor>& opt) -> const int32_t* {
+    return (opt.has_value() && opt.value().defined())
+               ? opt.value().data_ptr<int32_t>()
+               : nullptr;
+  };
+
+  const int32_t* sbi_ptr = get_int32_ptr(state_batch_indices);
+  const int32_t* dsbi_ptr = get_int32_ptr(dst_state_batch_indices);
+  const int32_t* nat_ptr = get_int32_ptr(num_accepted_tokens);
+  const int32_t* csl_ptr = get_int32_ptr(cu_seqlens);
+
+  // Only allocate a float32 scratch output when `out` is not already float32;
+  // this avoids an extra copy in the common case.
+  bool need_out_conversion = (out.scalar_type() != at::kFloat);
+  at::Tensor out_f32 =
+      need_out_conversion ? at::empty_like(out, at::kFloat) : out;
+
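+  // Dispatch over both dtypes independently: state_t follows the cached SSM
+  // state (often bfloat16) while input_t follows the activations, so the
+  // mixed-precision decode path needs no tensor-wide conversions.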
+  VLLM_DISPATCH_FLOATING_TYPES(state_type, "ssu_state", [&] {
+    using state_t = scalar_t;
+    VLLM_DISPATCH_FLOATING_TYPES(input_type, "ssu_input", [&] {
+      using input_t = scalar_t;
+      mamba_cpu::selective_state_update_kernel<state_t, input_t>(
+          state.data_ptr<state_t>(), stride_state_n, stride_state_h,
+          stride_state_d, x.data_ptr<input_t>(), dt_in.data_ptr<input_t>(),
+          stride_xdt_n, stride_xdt_h, A_in.data_ptr<input_t>(), stride_A_h,
+          B_in.data_ptr<input_t>(), C_in.data_ptr<input_t>(), stride_BC_n,
+          stride_BC_g,
+          D_in.defined() ? D_in.data_ptr<input_t>() : nullptr, stride_D_h,
+          z_in.defined() ? z_in.data_ptr<input_t>() : nullptr,
+          dt_bias_in.defined() ? dt_bias_in.data_ptr<input_t>() : nullptr,
+          stride_dtbias_h, out_f32.data_ptr<float>(), stride_out_n,
+          stride_out_h, sbi_ptr, dsbi_ptr,
+          static_cast<int32_t>(null_block_id), nat_ptr, csl_ptr, N, nheads,
+          ngroups, dim, dstate, dt_softplus);
+    });
+  });
+
+  // Only copy back if we used a temporary buffer.
+  if (need_out_conversion) {
+    out.copy_(out_f32);
+  }
+}
diff --git a/csrc/cpu/mamba_kernels.hpp b/csrc/cpu/mamba_kernels.hpp
new file mode 100644
index 000000000000..f351f9acade8
--- /dev/null
+++ b/csrc/cpu/mamba_kernels.hpp
@@ -0,0 +1,206 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+//
+// Fused CPU vector kernels for Mamba decode-step hotspots:
+//  - causal_conv1d_update (depthwise 1-D conv state roll + compute)
+//  - selective_state_update (SSM recurrence, single-step)
+
+#pragma once
+
+#include "cpu_types.hpp"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <string>
+
+namespace mamba_cpu {
+
+// ---------------------------------------------------------------------------
+// causal_conv1d_update
+// ---------------------------------------------------------------------------
+inline void causal_conv1d_update_kernel(
+    const float* __restrict__ x_ptr, float* __restrict__ state_ptr,
+    const float* __restrict__ weight_ptr, const float* __restrict__ bias_ptr,
+    float* __restrict__ out_ptr, const int32_t* __restrict__ cache_idxs,
+    int32_t pad_slot_id, int64_t batch, int64_t dim, int64_t seqlen,
+    int64_t width, int64_t state_len, bool do_silu) {
+#pragma omp parallel for
+  for (int64_t b = 0; b < batch; ++b) {
+    int64_t cache_idx = (cache_idxs != nullptr) ? cache_idxs[b] : b;
+    if (cache_idx == pad_slot_id) continue;
+
+    for (int64_t t = 0; t < seqlen; ++t) {
+      const float* x_b = x_ptr + (b * dim * seqlen + t);
+      float* out_b = out_ptr + (b * dim * seqlen + t);
+      float* s = state_ptr + cache_idx * dim * state_len;
+
+      for (int64_t d = 0; d < dim; ++d) {
+        float x_val = x_b[d * seqlen];
+        float* sd = s + d * state_len;
+
+        const float* w = weight_ptr + d * width;
+        float acc = (bias_ptr != nullptr) ? bias_ptr[d] : 0.0f;
+
+        for (int64_t k = 0; k < state_len; ++k) {
+          acc += w[k] * sd[k];
+        }
+        acc += w[state_len] * x_val;
+
+        if (state_len > 1) {
+          std::memmove(sd, sd + 1, (state_len - 1) * sizeof(float));
+        }
+        if (state_len > 0) {
+          sd[state_len - 1] = x_val;
+        }
+
+        if (do_silu) {
+          // Numerically stable sigmoid for either sign of acc.
+          float sigmoid = (acc >= 0)
+                              ? 1.0f / (1.0f + std::exp(-acc))
+                              : std::exp(acc) / (1.0f + std::exp(acc));
+          acc *= sigmoid;
+        }
+        out_b[d * seqlen] = acc;
+      }
+    }
+  }
+}
+
+// ---------------------------------------------------------------------------
+// selective_state_update
+// ---------------------------------------------------------------------------
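+// Per (head h, channel d), with dt' = softplus(dt[h, d] + dt_bias[h, d]):
+//   s[n]  <- s[n] * exp(A[h, d, n] * dt') + B[g, n] * x[h, d] * dt'
+//   y      = sum_n s[n] * C[g, n] + D[h, d] * x[h, d]
+//   out    = y * silu(z[h, d])            (when z is provided)
+// where g = h / (nheads / ngroups) selects the B/C group of head h.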
+template <typename state_t, typename input_t>
+inline void selective_state_update_kernel(
+    state_t* __restrict__ state_ptr, int64_t stride_state_n,
+    int64_t stride_state_h, int64_t stride_state_d,
+    const input_t* __restrict__ x_ptr, const input_t* __restrict__ dt_ptr,
+    int64_t stride_xdt_n, int64_t stride_xdt_h,
+    const input_t* __restrict__ A_ptr, int64_t stride_A_h,
+    const input_t* __restrict__ B_ptr, const input_t* __restrict__ C_ptr,
+    int64_t stride_BC_n, int64_t stride_BC_g,
+    const input_t* __restrict__ D_ptr, int64_t stride_D_h,
+    const input_t* __restrict__ z_ptr,
+    const input_t* __restrict__ dt_bias_ptr, int64_t stride_dtbias_h,
+    float* __restrict__ out_ptr, int64_t stride_out_n, int64_t stride_out_h,
+    const int32_t* __restrict__ state_batch_indices,
+    const int32_t* __restrict__ dst_state_batch_indices,
+    int32_t null_block_id, const int32_t* __restrict__ num_accepted_tokens,
+    const int32_t* __restrict__ cu_seqlens, int64_t N, int64_t nheads,
+    int64_t ngroups, int64_t dim, int64_t dstate, bool dt_softplus) {
+  using state_vec_t = vec_op::vec_t<state_t>;
+  using input_vec_t = vec_op::vec_t<input_t>;
+  constexpr int VEC_ELEM_NUM = 8;
+
+  int64_t nheads_per_group = nheads / ngroups;
+
+  for (int64_t seq_idx = 0; seq_idx < N; ++seq_idx) {
+    int64_t bos, seq_len;
+    if (cu_seqlens != nullptr) {
+      bos = cu_seqlens[seq_idx];
+      seq_len = cu_seqlens[seq_idx + 1] - bos;
+    } else {
+      bos = seq_idx;
+      seq_len = 1;
+    }
+
+    int64_t state_read_idx = (state_batch_indices != nullptr)
+                                 ? state_batch_indices[seq_idx]
+                                 : seq_idx;
+    if (state_read_idx == null_block_id) continue;
+
+    int64_t state_write_idx =
+        (num_accepted_tokens == nullptr)
+            ? ((dst_state_batch_indices != nullptr)
+                   ? dst_state_batch_indices[seq_idx]
+                   : state_read_idx)
+            : -1;
+
+    state_t* s = state_ptr + state_read_idx * stride_state_n;
+
+    for (int64_t t = 0; t < seq_len; ++t) {
+      int64_t token_idx = bos + t;
+      const input_t* x_tok = x_ptr + token_idx * stride_xdt_n;
+      const input_t* dt_tok = dt_ptr + token_idx * stride_xdt_n;
+      const input_t* B_tok = B_ptr + token_idx * stride_BC_n;
+      const input_t* C_tok = C_ptr + token_idx * stride_BC_n;
+      float* out_tok = out_ptr + token_idx * stride_out_n;
+
+#pragma omp parallel for
+      for (int64_t h = 0; h < nheads; ++h) {
+        int64_t g = h / nheads_per_group;
+        const input_t* x_h = x_tok + h * stride_xdt_h;
+        const input_t* dt_h = dt_tok + h * stride_xdt_h;
+        const input_t* B_g = B_tok + g * stride_BC_g;
+        const input_t* C_g = C_tok + g * stride_BC_g;
+        const input_t* A_h = A_ptr + h * stride_A_h;
+        const input_t* dt_bias_h =
+            (dt_bias_ptr != nullptr) ? dt_bias_ptr + h * stride_dtbias_h
+                                     : nullptr;
+        const input_t* D_h =
+            (D_ptr != nullptr) ? D_ptr + h * stride_D_h : nullptr;
+        const input_t* z_h =
+            (z_ptr != nullptr)
+                ? z_ptr + token_idx * stride_xdt_n + h * stride_xdt_h
+                : nullptr;
+        float* out_h = out_tok + h * stride_out_h;
+        state_t* s_h = s + h * stride_state_h;
+
+        for (int64_t d = 0; d < dim; ++d) {
+          float x_val = static_cast<float>(x_h[d]);
+          float dt_val = static_cast<float>(dt_h[d]);
+          if (dt_bias_h != nullptr) {
+            dt_val += static_cast<float>(dt_bias_h[d]);
+          }
+          if (dt_softplus) {
+            // softplus(x) = log1p(exp(x)); for large x it saturates to x.
+            dt_val = (dt_val <= 20.0f) ? std::log1p(std::exp(dt_val)) : dt_val;
+          }
+
+          vec_op::FP32Vec8 out_vec(0.0f);
+          state_t* s_hd = s_h + d * stride_state_d;
+          const input_t* A_hd = A_h + d * dstate;
+
+          vec_op::FP32Vec8 x_vec(x_val);
+          vec_op::FP32Vec8 dt_vec(dt_val);
+
+          int64_t n = 0;
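+          // Main loop: 8 states per iteration. A/B/C/state are widened to
+          // fp32 vectors, updated, and the new state is stored back in
+          // state_t; the remainder is handled by the scalar tail below.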
+          for (; n <= dstate - VEC_ELEM_NUM; n += VEC_ELEM_NUM) {
+            vec_op::FP32Vec8 A_v((input_vec_t(A_hd + n)));
+            vec_op::FP32Vec8 B_v((input_vec_t(B_g + n)));
+            vec_op::FP32Vec8 C_v((input_vec_t(C_g + n)));
+
+            vec_op::FP32Vec8 s_v((state_vec_t(s_hd + n)));
+
+            vec_op::FP32Vec8 dA = (A_v * dt_vec).exp();
+            vec_op::FP32Vec8 dBx = B_v * x_vec * dt_vec;
+            vec_op::FP32Vec8 s_new = s_v * dA + dBx;
+
+            state_vec_t(s_new).save(s_hd + n);
+            out_vec = out_vec + s_new * C_v;
+          }
+
+          float out_val = out_vec.reduce_sum();
+          // Scalar tail for dstate not divisible by the vector width.
+          for (; n < dstate; ++n) {
+            float dA = std::exp(static_cast<float>(A_hd[n]) * dt_val);
+            float dBx = static_cast<float>(B_g[n]) * x_val * dt_val;
+            float s_new = static_cast<float>(s_hd[n]) * dA + dBx;
+            s_hd[n] = static_cast<state_t>(s_new);
+            out_val += s_new * static_cast<float>(C_g[n]);
+          }
+
+          if (D_h != nullptr) out_val += x_val * static_cast<float>(D_h[d]);
+          if (z_h != nullptr) {
+            float z_val = static_cast<float>(z_h[d]);
+            float sigmoid = (z_val >= 0)
+                                ? 1.0f / (1.0f + std::exp(-z_val))
+                                : std::exp(z_val) / (1.0f + std::exp(z_val));
+            out_val *= z_val * sigmoid;
+          }
+          out_h[d] = out_val;
+        }
+      }
+
+      if (num_accepted_tokens != nullptr &&
+          dst_state_batch_indices != nullptr) {
+        int64_t token_dst_idx = dst_state_batch_indices[seq_idx * seq_len + t];
+        if (token_dst_idx != null_block_id &&
+            token_dst_idx != state_read_idx) {
+          state_t* dst_s = state_ptr + token_dst_idx * stride_state_n;
+          std::memmove(dst_s, s, nheads * stride_state_h * sizeof(state_t));
+        }
+      }
+    }
+
+    if (num_accepted_tokens == nullptr && state_write_idx != null_block_id &&
+        state_write_idx != state_read_idx) {
+      state_t* dst_s = state_ptr + state_write_idx * stride_state_n;
+      std::memmove(dst_s, s, nheads * stride_state_h * sizeof(state_t));
+    }
+  }
+}
+
+}  // namespace mamba_cpu
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index 89c5765ee646..9f895ac6bb4a 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -144,6 +144,24 @@ void compute_slot_mapping_kernel_impl(const torch::Tensor query_start_loc,
                                       torch::Tensor slot_mapping,
                                       const int64_t block_size);
 
+at::Tensor causal_conv1d_update_cpu_impl(
+    at::Tensor& x, at::Tensor& conv_state, const at::Tensor& weight,
+    const c10::optional<at::Tensor>& bias,
+    const c10::optional<std::string>& activation,
+    const c10::optional<at::Tensor>& conv_state_indices,
+    const c10::optional<at::Tensor>& query_start_loc, int64_t pad_slot_id);
+
+void selective_state_update_cpu_impl(
+    at::Tensor& state, const at::Tensor& x, const at::Tensor& dt,
+    const at::Tensor& A, const at::Tensor& B, const at::Tensor& C,
+    const c10::optional<at::Tensor>& D, const c10::optional<at::Tensor>& z,
+    const c10::optional<at::Tensor>& dt_bias, bool dt_softplus,
+    const c10::optional<at::Tensor>& state_batch_indices,
+    const c10::optional<at::Tensor>& dst_state_batch_indices,
+    int64_t null_block_id, at::Tensor& out,
+    const c10::optional<at::Tensor>& num_accepted_tokens,
+    const c10::optional<at::Tensor>& cu_seqlens);
+
 void init_cpu_memory_env(std::vector<int64_t> node_ids);
 
 namespace cpu_utils {
@@ -440,6 +458,23 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "block_size) -> ()",
       &compute_slot_mapping_kernel_impl);
 
+  // Mamba CPU kernels
+  ops.def(
+      "causal_conv1d_update_cpu("
+      "Tensor(a0!) x, Tensor(a1!) conv_state, Tensor weight, "
+      "Tensor? bias, str? activation, Tensor? conv_state_indices, "
+      "Tensor? query_start_loc, SymInt pad_slot_id) -> Tensor",
+      &causal_conv1d_update_cpu_impl);
+
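+  // Schema note: the (aN!) annotations mark arguments the op mutates in
+  // place (x / conv_state above, state / out below), so the dispatcher does
+  // not assume purely functional semantics for these ops.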
+  ops.def(
+      "selective_state_update_cpu("
+      "Tensor(a0!) state, Tensor x, Tensor dt, Tensor A, Tensor B, Tensor C, "
+      "Tensor? D, Tensor? z, Tensor? dt_bias, bool dt_softplus, "
+      "Tensor? state_batch_indices, Tensor? dst_state_batch_indices, "
+      "SymInt null_block_id, Tensor(a13!) out, "
+      "Tensor? num_accepted_tokens, Tensor? cu_seqlens) -> ()",
+      &selective_state_update_cpu_impl);
+
   ops.def("init_cpu_memory_env(SymInt[] node_ids) -> ()", &init_cpu_memory_env);
 
   // Speculative decoding kernels
diff --git a/scratch/verify_cpu_ssu.py b/scratch/verify_cpu_ssu.py
new file mode 100644
index 000000000000..c5ff0fe276bc
--- /dev/null
+++ b/scratch/verify_cpu_ssu.py
@@ -0,0 +1,107 @@
+import time
+
+import torch
+import torch.nn.functional as F
+
+from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update
+
+
+def selective_state_update_ref(
+    state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
+):
+    has_heads = state.dim() > 3
+    if state.dim() == 3:
+        state = state.unsqueeze(1)
+    if x.dim() == 2:
+        x = x.unsqueeze(1)
+    if dt.dim() == 2:
+        dt = dt.unsqueeze(1)
+    if A.dim() == 2:
+        A = A.unsqueeze(0)
+    if B.dim() == 2:
+        B = B.unsqueeze(1)
+    if C.dim() == 2:
+        C = C.unsqueeze(1)
+    if D is not None and D.dim() == 1:
+        D = D.unsqueeze(0)
+    if z is not None and z.dim() == 2:
+        z = z.unsqueeze(1)
+    if dt_bias is not None and dt_bias.dim() == 1:
+        dt_bias = dt_bias.unsqueeze(0)
+    batch, nheads, dim, dstate = state.shape
+
+    dt = dt.float()
+    if dt_bias is not None:
+        dt = dt + dt_bias.float()
+    dt = F.softplus(dt) if dt_softplus else dt
+
+    dA = torch.exp(dt.unsqueeze(-1) * A.float())  # (batch, nheads, dim, dstate)
+    ngroups = B.shape[1]
+    B = B.repeat_interleave(nheads // ngroups, dim=1)
+    C = C.repeat_interleave(nheads // ngroups, dim=1)
+
+    dBx = (dt.unsqueeze(-1) * B.float().unsqueeze(2)) * x.float().unsqueeze(-1)
+
+    state.copy_(state.float() * dA + dBx)
+
+    out = torch.einsum("bhdn,bhn->bhd", state.float(), C.float())
+    if D is not None:
+        out += x.float() * D.float()
+
+    if z is not None:
+        out = out * F.silu(z.float())
+
+    if not has_heads:
+        out = out.squeeze(1)
+    return out.to(x.dtype)
+
+
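+# The test below compares the fused CPU kernel against the fp32 reference
+# above. Both update `state` in place, so the reference gets a clone.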
+def test_cpu_ssu():
+    device = "cpu"
+    torch.manual_seed(42)
+
+    # Mamba-2 style dimensions
+    batch_size = 2
+    nheads = 128
+    dim = 64
+    dstate = 128
+    itype = torch.bfloat16
+
+    print(f"Testing CPU selective_state_update with {itype}...")
+
+    state = torch.randn(batch_size, nheads, dim, dstate, dtype=itype, device=device)
+    x = torch.randn(batch_size, nheads, dim, device=device, dtype=itype)
+    dt = torch.randn(batch_size, nheads, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(nheads, dim, device=device) - 4.0
+    A = -torch.rand(nheads, dim, dstate, device=device) - 1.0
+    B = torch.randn(batch_size, 1, dstate, device=device, dtype=itype)
+    C = torch.randn(batch_size, 1, dstate, device=device, dtype=itype)
+    D = torch.randn(nheads, dim, device=device)
+    z = torch.randn_like(x)
+
+    out = torch.empty_like(x)
+    state_ref = state.detach().clone()
+
+    start = time.time()
+    selective_state_update(
+        state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True, out=out
+    )
+    end = time.time()
+    print(f"Kernel time: {end - start:.4f}s")
+
+    start = time.time()
+    out_ref = selective_state_update_ref(
+        state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True
+    )
+    end = time.time()
+    print(f"Ref time: {end - start:.4f}s")
+
+    state_diff = (state.float() - state_ref.float()).abs().max().item()
+    out_diff = (out.float() - out_ref.float()).abs().max().item()
+
+    print(f"State max diff: {state_diff}")
+    print(f"Out max diff: {out_diff}")
+
+    if state_diff < 1e-2 and out_diff < 1e-2:
+        print("SUCCESS")
+    else:
+        print("FAILURE")
+
+
+if __name__ == "__main__":
+    test_cpu_ssu()
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index ffea35de8811..f23f69a5c34c 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -2305,6 +2305,66 @@ def selective_scan_fwd(
     )
 
 
+def causal_conv1d_update_cpu(
+    x: torch.Tensor,
+    conv_state: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None = None,
+    activation: str | None = None,
+    conv_state_indices: torch.Tensor | None = None,
+    query_start_loc: torch.Tensor | None = None,
+    pad_slot_id: int = 0,
+) -> torch.Tensor:
+    return torch.ops._C.causal_conv1d_update_cpu(
+        x,
+        conv_state,
+        weight,
+        bias,
+        activation,
+        conv_state_indices,
+        query_start_loc,
+        pad_slot_id,
+    )
+
+
+def selective_state_update_cpu(
+    state: torch.Tensor,
+    x: torch.Tensor,
+    dt: torch.Tensor,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    D: torch.Tensor | None,
+    z: torch.Tensor | None,
+    dt_bias: torch.Tensor | None,
+    dt_softplus: bool,
+    state_batch_indices: torch.Tensor | None,
+    dst_state_batch_indices: torch.Tensor | None,
+    null_block_id: int,
+    out: torch.Tensor,
+    num_accepted_tokens: torch.Tensor | None,
+    cu_seqlens: torch.Tensor | None,
+):
+    torch.ops._C.selective_state_update_cpu(
+        state,
+        x,
+        dt,
+        A,
+        B,
+        C,
+        D,
+        z,
+        dt_bias,
+        dt_softplus,
+        state_batch_indices,
+        dst_state_batch_indices,
+        null_block_id,
+        out,
+        num_accepted_tokens,
+        cu_seqlens,
+    )
+
+
 # ROCm skinny gemms
 def LLMM1(a: torch.Tensor, b: torch.Tensor, rows_per_block: int) -> torch.Tensor:
     return torch.ops._rocm_C.LLMM1(a, b, rows_per_block)
diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
index 985f33e10098..3609d350da73 100644
--- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -8,7 +8,6 @@
 
 from vllm import _custom_ops as ops
 from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.activation import MoEActivation
 from vllm.model_executor.layers.quantization.utils.layer_utils import replace_parameter
 from vllm.utils.torch_utils import direct_register_custom_op
@@ -41,11 +40,18 @@ def _gelu_and_mul(
     return F.gelu(x[..., :d], approximate="none") * x[..., d:]
 
 
+def _silu_and_mul(x: torch.Tensor) -> torch.Tensor:
+    """Standalone SiluAndMul forward to avoid instantiating CustomOp
+    (which calls get_current_vllm_config()) at model-forward time."""
+    d = x.shape[-1] // 2
+    return F.silu(x[..., :d]) * x[..., d:]
+
+
 # Map activation names to their native forward functions.
 # Uses static methods or standalone functions to avoid instantiating CustomOp
 # classes, which would call get_current_vllm_config() before config is set.
 _CPU_MOE_ACT_FN: dict[MoEActivation, Callable[[torch.Tensor], torch.Tensor]] = {
-    MoEActivation.SILU: lambda x: SiluAndMul(compile_native=False).forward_native(x),
+    MoEActivation.SILU: _silu_and_mul,
     MoEActivation.SWIGLUOAI: _swigluoai_forward_native,
     MoEActivation.GELU: _gelu_and_mul,
 }
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py
index 0e476755201e..7800a44eed0e 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -369,6 +369,7 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
         time_proj_bias = self._time_proj_bias()
 
         # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
+        # Keep D in its native dtype to avoid expensive conversions.
         scan_out_p = selective_scan_fn(
             conv_out_p,
             ssm_state,
@@ -376,7 +377,7 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
             self.A,
             B_p.transpose(-2, -1),
             C_p.transpose(-2, -1),
-            self.D.float(),
+            self.D,
             gate_p,
             time_proj_bias,
             delta_softplus=True,
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 2b4b1934f9b3..6ce8242dca25 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -876,10 +876,11 @@ def conv_ssm_forward(
 
         # 3. State Space Model sequence transformation
         n_groups = self.n_groups // self.tp_size
+        # Keep A in its native dtype to avoid expensive conversions;
+        # the SSU kernel handles mixed dtypes efficiently.
         A_d = (
             self.A[:, None, ...][:, :, None]
             .expand(-1, self.head_dim, self.ssm_state_size)
-            .to(dtype=torch.float32)
         )
         dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
         dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
index 1160105ad101..fad11e67384d 100644
--- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -8,9 +8,13 @@
 import numpy as np
 import torch
 
+from vllm import _custom_ops as ops
 from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backends.utils import NULL_BLOCK_ID, PAD_SLOT_ID
 
+from .cpu_fallbacks import _causal_conv1d_fn_cpu, _causal_conv1d_update_cpu
+
+
 @triton.jit()
 def _causal_conv1d_fwd_kernel(  # continuous batching
@@ -466,7 +470,7 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
     tl.store(o_ptrs, acc, mask=mask_1d)
 
 
-def causal_conv1d_fn(
+def _causal_conv1d_fn_cuda(
     x: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor | None,
@@ -1068,7 +1072,7 @@ def _causal_conv1d_update_kernel(
     tl.store(o_ptrs, acc, mask=mask_1d)
 
 
-def causal_conv1d_update(
+def _causal_conv1d_update_cuda(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
@@ -1239,3 +1243,86 @@ def grid(META):
     if unsqueeze:
         out = out.squeeze(-1)
     return out.to(original_x_dtype)
+
+
+def causal_conv1d_fn(*args, **kwargs):
+    """Dispatch causal_conv1d_fn to the CPU PyTorch fallback or the CUDA
+    Triton kernel."""
+    x = args[0] if args else kwargs.get("x")
+    if x is not None and x.device.type == "cpu":
+        return _causal_conv1d_fn_cpu(*args, **kwargs)
+    return _causal_conv1d_fn_cuda(*args, **kwargs)
+
+
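+# causal_conv1d_update has three backends: the fused C++ kernel for the plain
+# CPU decode step, the pure-PyTorch fallback for the varlen / speculative
+# decoding paths on CPU, and the Triton kernel on CUDA.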
+def causal_conv1d_update(
+    x,
+    conv_state,
+    weight,
+    bias=None,
+    activation=None,
+    conv_state_indices=None,
+    num_accepted_tokens=None,
+    query_start_loc=None,
+    max_query_len=-1,
+    null_block_id=NULL_BLOCK_ID,
+    block_idx_last_scheduled_token=None,
+    initial_state_idx=None,
+    validate_data=False,
+):
+    """Dispatch causal_conv1d_update to the CPU C++ kernel or the CUDA
+    Triton kernel."""
+    if x.device.type == "cpu":
+        # The C++ kernel handles the standard (non-varlen, non-spec-decoding)
+        # decode path. Fall back to PyTorch for the complex varlen /
+        # spec-decoding paths.
+        if query_start_loc is not None or num_accepted_tokens is not None:
+            return _causal_conv1d_update_cpu(
+                x,
+                conv_state,
+                weight,
+                bias=bias,
+                activation=activation,
+                conv_state_indices=conv_state_indices,
+                query_start_loc=query_start_loc,
+            )
+        # Determine the activation string.
+        act_str = None
+        if isinstance(activation, bool):
+            act_str = "silu" if activation else None
+        elif activation is not None:
+            act_str = activation
+
+        _conv_state_indices = conv_state_indices
+        if _conv_state_indices is not None and _conv_state_indices.dim() == 2:
+            if initial_state_idx is not None:
+                _conv_state_indices = _conv_state_indices.gather(
+                    1, initial_state_idx.unsqueeze(1)
+                ).squeeze(1)
+            else:
+                _conv_state_indices = _conv_state_indices[:, 0]
+
+        pad_slot_id = int(NULL_BLOCK_ID)
+        return ops.causal_conv1d_update_cpu(
+            x,
+            conv_state,
+            weight,
+            bias,
+            act_str,
+            _conv_state_indices,
+            None,  # query_start_loc
+            pad_slot_id,
+        )
+
+    return _causal_conv1d_update_cuda(
+        x,
+        conv_state,
+        weight,
+        bias=bias,
+        activation=activation,
+        conv_state_indices=conv_state_indices,
+        num_accepted_tokens=num_accepted_tokens,
+        query_start_loc=query_start_loc,
+        max_query_len=max_query_len,
+        null_block_id=null_block_id,
+        block_idx_last_scheduled_token=block_idx_last_scheduled_token,
+        initial_state_idx=initial_state_idx,
+        validate_data=validate_data,
+    )
diff --git a/vllm/model_executor/layers/mamba/ops/cpu_fallbacks.py b/vllm/model_executor/layers/mamba/ops/cpu_fallbacks.py
new file mode 100644
index 000000000000..c33a574395ce
--- /dev/null
+++ b/vllm/model_executor/layers/mamba/ops/cpu_fallbacks.py
@@ -0,0 +1,390 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+from vllm.v1.attention.backends.utils import NULL_BLOCK_ID, PAD_SLOT_ID
+
+
+def _causal_conv1d_fn_cpu(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    conv_states: torch.Tensor,
+    query_start_loc: torch.Tensor,
+    cache_indices: torch.Tensor | None = None,
+    has_initial_state: torch.Tensor | None = None,
+    activation: str | None = "silu",
+    pad_slot_id: int = PAD_SLOT_ID,
+    **kwargs,
+) -> torch.Tensor:
+    """Pure PyTorch CPU fallback for causal_conv1d_fwd."""
+    if isinstance(activation, bool) and activation:
+        activation = "silu"
+
+    original_x_dtype = x.dtype
+    x = x.to(conv_states.dtype)
+
+    dim, cu_seqlen = x.shape
+    _, width = weight.shape
+    state_len = width - 1
+
+    out = torch.zeros_like(x)
+
+    batch = query_start_loc.size(0) - 1
+
+    for b in range(batch):
+        seq_start = query_start_loc[b].item()
+        seq_end = query_start_loc[b + 1].item()
+        seq_len = seq_end - seq_start
+
+        if seq_len == 0:
+            continue
+
+        cache_idx = cache_indices[b].item() if cache_indices is not None else b
+
+        if cache_idx == pad_slot_id:
+            continue
+
+        x_seq = x[:, seq_start:seq_end]  # (dim, seq_len)
+
+        if has_initial_state is not None and has_initial_state[b]:
+            state = conv_states[cache_idx].clone()  # (dim, state_len)
+        else:
+            state = torch.zeros((dim, state_len), dtype=x.dtype, device=x.device)
+
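+        # Slide the causal window one sample at a time: `state` always holds
+        # the previous (width - 1) inputs of this sequence.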
+        for t in range(seq_len):
+            x_t = x_seq[:, t]  # (dim,)
+
+            window = torch.cat([state, x_t.unsqueeze(1)], dim=1)  # (dim, width)
+            val = (window * weight).sum(dim=1)  # (dim,)
+
+            if bias is not None:
+                val = val + bias
+            if activation in ["silu", "swish"]:
+                val = val * torch.sigmoid(val)
+
+            out[:, seq_start + t] = val
+
+            if state_len > 1:
+                state[:, :-1] = state[:, 1:].clone()
+            state[:, -1] = x_t
+
+        if seq_len >= state_len:
+            conv_states[cache_idx, :, :state_len] = x_seq[:, -state_len:]
+        else:
+            conv_states[cache_idx, :, : state_len - seq_len] = conv_states[
+                cache_idx, :, seq_len:state_len
+            ].clone()
+            conv_states[cache_idx, :, state_len - seq_len :] = x_seq
+
+    return out.to(original_x_dtype)
+
+
+def _causal_conv1d_update_cpu(
+    x: torch.Tensor,
+    conv_state: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None = None,
+    activation: bool | str | None = None,
+    conv_state_indices: torch.Tensor | None = None,
+    query_start_loc: torch.Tensor | None = None,
+    pad_slot_id: int = PAD_SLOT_ID,
+    **kwargs,
+) -> torch.Tensor:
+    """Pure PyTorch CPU fallback for causal_conv1d_update (decode path)."""
+    if isinstance(activation, bool):
+        activation = "silu" if activation else None
+
+    original_x_dtype = x.dtype
+    x = x.to(conv_state.dtype)
+    _, width = weight.shape
+    state_len = width - 1
+
+    if query_start_loc is None and x.dim() == 2:
+        x = x.unsqueeze(-1)
+        unsqueeze = True
+    else:
+        unsqueeze = False
+
+    if query_start_loc is None:
+        batch, dim, seqlen = x.shape
+
+        if conv_state_indices is not None:
+            cache_idxs = conv_state_indices.flatten()
+            valid_mask = cache_idxs != pad_slot_id
+        else:
+            cache_idxs = torch.arange(batch, device=x.device)
+            valid_mask = torch.ones(batch, dtype=torch.bool, device=x.device)
+
+        for t in range(seqlen):
+            x_t = x[:, :, t].clone()
+
+            states = conv_state[cache_idxs]
+
+            windows = torch.cat([states, x_t.unsqueeze(-1)], dim=-1)
+
+            val = (windows * weight.unsqueeze(0)).sum(dim=-1)
+            if bias is not None:
+                val = val + bias.unsqueeze(0)
+
+            if activation in ["silu", "swish"]:
+                val = val * torch.sigmoid(val)
+
+            val = val * valid_mask.unsqueeze(-1).to(val.dtype)
+            x[:, :, t] = val
+
+            new_state = torch.cat([states[:, :, 1:], x_t.unsqueeze(-1)], dim=-1)
+            conv_state[cache_idxs[valid_mask]] = new_state[valid_mask]
+
+        out = x
+        if unsqueeze:
+            out = out.squeeze(-1)
+        return out.to(original_x_dtype)
+
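+    # Varlen path: one sequence at a time, carrying a local copy of the conv
+    # state across the tokens of that sequence.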
+    assert conv_state_indices is not None
+    assert query_start_loc is not None
+    batch = conv_state_indices.size(0)
+    out = x.clone()
+
+    for b in range(batch):
+        cache_idx = conv_state_indices[b].item()
+        if cache_idx == pad_slot_id:
+            continue
+
+        seq_start = query_start_loc[b].item()
+        seq_end = query_start_loc[b + 1].item()
+        seqlen_b = seq_end - seq_start
+
+        if seqlen_b == 0:
+            continue
+
+        local_state = conv_state[cache_idx].clone()
+
+        for t in range(seqlen_b):
+            x_t = x[seq_start + t, :]
+
+            window = torch.cat([local_state, x_t.unsqueeze(-1)], dim=-1)
+            val = (window * weight).sum(dim=-1)
+            if bias is not None:
+                val = val + bias
+            if activation in ["silu", "swish"]:
+                val = val * torch.sigmoid(val)
+
+            out[seq_start + t, :] = val
+
+            if state_len > 1:
+                local_state[:, :-1] = local_state[:, 1:].clone()
+            local_state[:, -1] = x_t
+
+        conv_state[cache_idx] = local_state
+
+    return out.to(original_x_dtype)
+
+
+def _selective_state_update_cpu(
+    state,
+    x,
+    dt,
+    A,
+    B,
+    C,
+    D=None,
+    z=None,
+    dt_bias=None,
+    dt_softplus=False,
+    state_batch_indices=None,
+    dst_state_batch_indices=None,
+    null_block_id=NULL_BLOCK_ID,
+    out=None,
+    num_accepted_tokens=None,
+    cu_seqlens=None,
+    is_blackwell=False,
+    enable_stochastic_rounding=False,
+    cache_philox_rounds=0,
+    **kwargs,
+):
+    if state.dim() == 3:
+        state = state.unsqueeze(1)
+    if x.dim() == 2:
+        x = x.unsqueeze(1)
+    if dt.dim() == 2:
+        dt = dt.unsqueeze(1)
+    if A.dim() == 2:
+        A = A.unsqueeze(0)
+    if B.dim() == 2:
+        B = B.unsqueeze(1)
+    if C.dim() == 2:
+        C = C.unsqueeze(1)
+    if D is not None and D.dim() == 1:
+        D = D.unsqueeze(0)
+    if z is not None and z.dim() == 2:
+        z = z.unsqueeze(1)
+    if dt_bias is not None and dt_bias.dim() == 1:
+        dt_bias = dt_bias.unsqueeze(0)
+    if out.dim() == 2:
+        out = out.unsqueeze(1)
+    if state_batch_indices is not None and state_batch_indices.dim() == 1:
+        state_batch_indices = state_batch_indices.unsqueeze(1)
+    if dst_state_batch_indices is not None and dst_state_batch_indices.dim() == 1:
+        dst_state_batch_indices = dst_state_batch_indices.unsqueeze(1)
+
+    _, nheads, dim, dstate = state.shape
+    batch = x.shape[0]
+    N = len(cu_seqlens) - 1 if cu_seqlens is not None else batch
+
+    ngroups = B.shape[1]
+    nheads_ngroups_ratio = nheads // ngroups
+
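+    # With num_accepted_tokens set (speculative decoding), every token's
+    # state snapshot is written to its own destination slot; otherwise only
+    # the final state of each sequence is written back.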
+    for seq_idx in range(N):
+        if cu_seqlens is not None:
+            bos = cu_seqlens[seq_idx].item()
+            seq_len = cu_seqlens[seq_idx + 1].item() - bos
+        else:
+            bos = seq_idx
+            seq_len = 1
+
+        if state_batch_indices is not None:
+            state_idx = state_batch_indices[seq_idx, 0].item()
+            if state_idx == null_block_id:
+                continue
+        else:
+            state_idx = seq_idx
+
+        if num_accepted_tokens is None:
+            if dst_state_batch_indices is not None:
+                dst_idx = dst_state_batch_indices[seq_idx, 0].item()
+            else:
+                dst_idx = state_idx
+
+        s = state[state_idx].float()
+
+        for t in range(seq_len):
+            token_idx = bos + t
+
+            x_val = x[token_idx].float()
+            dt_val = dt[token_idx].float()
+
+            if dt_bias is not None:
+                dt_val = dt_val + dt_bias.float()
+            if dt_softplus:
+                dt_val = torch.nn.functional.softplus(dt_val)
+
+            A_val = A.float()
+
+            B_val = B[token_idx].float()
+            B_expanded = B_val.repeat_interleave(nheads_ngroups_ratio, dim=0)
+            C_val = C[token_idx].float()
+            C_expanded = C_val.repeat_interleave(nheads_ngroups_ratio, dim=0)
+
+            dA = torch.exp(A_val * dt_val.unsqueeze(-1))
+            dBx = B_expanded.unsqueeze(1) * (x_val * dt_val).unsqueeze(-1)
+            s = s * dA + dBx
+
+            if num_accepted_tokens is not None:
+                token_dst_idx = dst_state_batch_indices[seq_idx, t].item()
+                if token_dst_idx != null_block_id:
+                    state[token_dst_idx] = s.to(state.dtype)
+
+            out_val = (s * C_expanded.unsqueeze(1)).sum(dim=-1)
+
+            if D is not None:
+                out_val = out_val + x_val * D.float()
+
+            if z is not None:
+                z_val = z[token_idx].float()
+                out_val = out_val * z_val * torch.sigmoid(z_val)
+
+            out[token_idx] = out_val.to(out.dtype)
+
+        if num_accepted_tokens is None and dst_idx != null_block_id:
+            state[dst_idx] = s.to(state.dtype)
+
+
+def _mamba_chunk_scan_combined_fwd_cpu(
+    x,
+    dt,
+    A,
+    B,
+    C,
+    chunk_size,
+    out,
+    D=None,
+    z=None,
+    dt_bias=None,
+    initial_states=None,
+    return_intermediate_states=False,
+    seq_idx=None,
+    cu_seqlens=None,
+    cu_chunk_seqlens=None,
+    last_chunk_indices=None,
+    dt_softplus=False,
+    dt_limit=(0.0, float("inf")),
+    state_dtype=None,
+    **kwargs,
+):
+    seqlen, nheads, headdim = x.shape
+    _, ngroups, dstate = B.shape
+    nheads_per_group = nheads // ngroups
+
+    assert cu_seqlens is not None
+    batch = cu_seqlens.size(0) - 1
+
+    dt_f = dt.float()
+    if dt_bias is not None:
+        dt_f = dt_f + dt_bias.float().unsqueeze(0)
+    if dt_softplus:
+        dt_f = torch.nn.functional.softplus(dt_f)
+    if dt_limit[0] > 0.0 or dt_limit[1] < float("inf"):
+        dt_f = dt_f.clamp(min=dt_limit[0], max=dt_limit[1])
+
+    all_states = torch.zeros(
+        batch, nheads, headdim, dstate, dtype=torch.float32, device=x.device
+    )
+
+    for b_idx in range(batch):
+        seq_start = cu_seqlens[b_idx].item()
+        seq_end = cu_seqlens[b_idx + 1].item()
+
+        if initial_states is not None:
+            state = initial_states[b_idx].float()
+        else:
+            state = torch.zeros(
+                nheads, headdim, dstate, dtype=torch.float32, device=x.device
+            )
+
+        for t in range(seq_start, seq_end):
+            x_t = x[t].float()
+            dt_t = dt_f[t]
+            A_val = A.float()
+
+            dA = torch.exp(A_val * dt_t).unsqueeze(-1).unsqueeze(-1)
+
+            B_expanded = B[t].float().repeat_interleave(nheads_per_group, dim=0)
+            C_expanded = C[t].float().repeat_interleave(nheads_per_group, dim=0)
+
+            xdt = x_t * dt_t.unsqueeze(-1)
+            dBx = xdt.unsqueeze(-1) * B_expanded.unsqueeze(1)
+            state = state * dA + dBx
+
+            y = (state * C_expanded.unsqueeze(1)).sum(dim=-1)
+
+            if D is not None:
+                y = (
+                    y + x_t * D.float().unsqueeze(-1)
+                    if D.dim() == 1
+                    else y + x_t * D.float()
+                )
+
+            if z is not None:
+                z_t = z[t].float()
+                y = y * z_t * torch.sigmoid(z_t)
+
+            out[t] = y.to(out.dtype)
+
+        all_states[b_idx] = state.to(all_states.dtype)
+
+    out_dtype = state_dtype if state_dtype is not None else x.dtype
+    all_states = all_states.to(out_dtype)
+
+    return all_states
diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
index e3c8ba8312f2..324d0ed8460c 100644
--- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
@@ -48,8 +48,9 @@ def convert_rs_fp16x2(x: tl.tensor, rand: tl.tensor) -> tl.tensor:
 @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
 @triton.heuristics(
     {
-        "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
-        is not None
+        "HAS_STATE_BATCH_INDICES": lambda args: (
+            args["state_batch_indices_ptr"] is not None
+        )
     }
 )
 @triton.heuristics(
@@ -316,7 +317,7 @@ def _selective_scan_update_kernel(
     tl.store(dst_state_ptrs, state, mask=mask)
 
 
-def selective_state_update(
+def _selective_state_update_cuda(
     state,
     x,
     dt,
@@ -427,6 +428,29 @@
     if num_accepted_tokens is not None:
         assert num_accepted_tokens.shape == (N,)
 
+    if not HAS_TRITON:
+        return _selective_state_update_cpu(
+            state,
+            x,
+            dt,
+            A,
+            B,
+            C,
+            D,
+            z,
+            dt_bias,
+            dt_softplus,
+            state_batch_indices,
+            dst_state_batch_indices,
+            null_block_id,
+            out,
+            num_accepted_tokens,
+            cu_seqlens,
+            is_blackwell,
+            enable_stochastic_rounding,
+            cache_philox_rounds,
+        )
+
     grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), N, nheads)
     z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
     state_batch_indices_strides = (
@@ -656,3 +680,76 @@ def selective_scan_fn(
         return delta  # output written inplace to delta
     else:
         return z  # output written inplace to z
+
+
+def selective_state_update(
+    state,
+    x,
+    dt,
+    A,
+    B,
+    C,
+    D=None,
+    z=None,
+    dt_bias=None,
+    dt_softplus=False,
+    state_batch_indices=None,
+    dst_state_batch_indices=None,
+    null_block_id=NULL_BLOCK_ID,
+    out=None,
+    num_accepted_tokens=None,
+    cu_seqlens=None,
+    is_blackwell=False,
+    enable_stochastic_rounding=False,
+    cache_philox_rounds=0,
+):
+    """Dispatch selective_state_update to the CPU C++ kernel or the CUDA
+    Triton kernel."""
+    # Ensure the output tensor exists.
+    if out is None:
+        out = torch.empty_like(x)
+
+    if x.device.type == "cpu":
+        # Reshape tensors from (batch, dim) to (batch, 1, dim) if needed;
+        # the C++ kernel expects an (N, nheads, dim) layout.
+        _state = state.unsqueeze(1) if state.dim() == 3 else state
+        _x = x.unsqueeze(1) if x.dim() == 2 else x
+        _dt = dt.unsqueeze(1) if dt.dim() == 2 else dt
+        _A = A.unsqueeze(0) if A.dim() == 2 else A
+        _B = B.unsqueeze(1) if B.dim() == 2 else B
+        _C = C.unsqueeze(1) if C.dim() == 2 else C
+        _D = D.unsqueeze(0) if (D is not None and D.dim() == 1) else D
+        _z = z.unsqueeze(1) if (z is not None and z.dim() == 2) else z
+        _dt_bias = (
+            dt_bias.unsqueeze(0)
+            if (dt_bias is not None and dt_bias.dim() == 1)
+            else dt_bias
+        )
+        _out = out.unsqueeze(1) if out.dim() == 2 else out
+        # state_batch_indices and dst_state_batch_indices are 1D index
+        # arrays; do NOT reshape them.
+        _sbi = state_batch_indices
+        _dsbi = dst_state_batch_indices
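+        # The C++ op writes `_out` in place; it only round-trips through a
+        # float32 scratch buffer when `out` itself is not float32.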
+        ops.selective_state_update_cpu(
+            _state, _x, _dt, _A, _B, _C,
+            _D, _z, _dt_bias, dt_softplus,
+            _sbi, _dsbi,
+            null_block_id, _out,
+            num_accepted_tokens, cu_seqlens,
+        )
+        return _out.squeeze(1) if out.dim() == 2 else _out
+
+    return _selective_state_update_cuda(
+        state, x, dt, A, B, C,
+        D=D, z=z, dt_bias=dt_bias,
+        dt_softplus=dt_softplus,
+        state_batch_indices=state_batch_indices,
+        dst_state_batch_indices=dst_state_batch_indices,
+        null_block_id=null_block_id,
+        out=out,
+        num_accepted_tokens=num_accepted_tokens,
+        cu_seqlens=cu_seqlens,
+        is_blackwell=is_blackwell,
+        enable_stochastic_rounding=enable_stochastic_rounding,
+        cache_philox_rounds=cache_philox_rounds,
+    )
diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py
index 4c93a768b629..4b2ad011f819 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py
@@ -10,21 +10,23 @@
 from einops import rearrange
 from packaging import version
 
-from vllm.triton_utils import triton
+from vllm.model_executor.custom_op import CustomOp
+from vllm.triton_utils import HAS_TRITON, triton
 
+from .cpu_fallbacks import _mamba_chunk_scan_combined_fwd_cpu
 from .ssd_bmm import _bmm_chunk_fwd
 from .ssd_chunk_scan import _chunk_scan_fwd
 from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd
 from .ssd_state_passing import _state_passing_fwd
 
-TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
+TRITON_22 = HAS_TRITON and version.parse(triton.__version__) >= version.parse("2.2.0")
 
 
 def is_int_pow_2(n):
     return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0
 
 
-def _mamba_chunk_scan_combined_fwd(
+def _mamba_chunk_scan_combined_fwd_cuda(
     x,
     dt,
     A,
@@ -154,6 +156,25 @@ def _mamba_chunk_scan_combined_fwd(
     return states[last_chunk_indices]
 
 
+@CustomOp.register("mamba_chunk_scan_combined_fwd")
+class MambaChunkScanCombinedFwdOp(CustomOp):
+    def forward_native(self, *args, **kwargs):
+        return _mamba_chunk_scan_combined_fwd_cpu(*args, **kwargs)
+
+    def forward_cpu(self, *args, **kwargs):
+        return _mamba_chunk_scan_combined_fwd_cpu(*args, **kwargs)
+
+    def forward_cuda(self, *args, **kwargs):
+        return _mamba_chunk_scan_combined_fwd_cuda(*args, **kwargs)
+
+
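+# Module-level instance; the _mamba_chunk_scan_combined_fwd wrapper below
+# keeps the original function signature, so existing callers are unaffected
+# by the CustomOp indirection.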
+_mamba_chunk_scan_combined_fwd_op = MambaChunkScanCombinedFwdOp()
+
+
+def _mamba_chunk_scan_combined_fwd(*args, **kwargs):
+    return _mamba_chunk_scan_combined_fwd_op(*args, **kwargs)
+
+
 def mamba_chunk_scan_combined_varlen(
     x,
     dt,
diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index 63a79f668edf..a8e7b47dc3c3 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -233,6 +233,12 @@ def dispatch_cpu_unquantized_gemm(
         layer.cpu_linear = torch.nn.functional.linear
         return
 
+    # Skip CPU GEMM dispatch for non-2D weights (e.g. MoE 3D expert weights);
+    # these layers are handled by their own specialized methods.
+    if layer.weight.dim() != 2:
+        layer.cpu_linear = torch.nn.functional.linear
+        return
+
     N, K = layer.weight.size()
     dtype = layer.weight.dtype
 
@@ -290,9 +296,7 @@ def dispatch_cpu_unquantized_gemm(
     )
 
     # fallback case
-    layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
-        x, weight, bias
-    )
+    layer.cpu_linear = torch.nn.functional.linear
 
 
 def cpu_unquantized_gemm(
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index b99510a66414..b8186d340a8e 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -294,7 +294,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
     def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
         # TODO: CPU still sets block_size in check_and_update_config.
         # Move that logic here so block_size is chosen by the backend.
-        pass
+        super().update_block_size_for_backend(vllm_config)
 
     @classmethod
     def discover_numa_topology(cls) -> list[list[int]]: