diff --git a/.gitignore b/.gitignore index 088595156b..51268d11a1 100644 --- a/.gitignore +++ b/.gitignore @@ -195,4 +195,4 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ diff --git a/benchmarks/routines/mamba.py b/benchmarks/routines/mamba.py index d5fffd57c8..7e64468aab 100644 --- a/benchmarks/routines/mamba.py +++ b/benchmarks/routines/mamba.py @@ -43,6 +43,9 @@ from triton_reference.selective_state_update import ( selective_state_update_triton as selective_state_update_triton_reference, ) +from triton_reference.selective_state_update_varlen import ( + selective_state_update_varlen_triton, +) # ============================================================================== @@ -151,6 +154,21 @@ def parse_mamba_args(line, parser): default=False, help="Apply softplus to dt before use.", ) + parser.add_argument( + "--varlen", + action="store_true", + default=False, + help="Use varlen mode: 3D flattened inputs with cu_seqlens, 2D state indices, " + "and dst_state_batch_indices. Requires --cache_steps >= 1 (= max_seqlen).", + ) + parser.add_argument( + "--algorithm", + type=str, + required=False, + default="auto", + choices=["auto", "simple", "vertical", "horizontal"], + help="FlashInfer kernel algorithm. Default: auto (picks best for GPU arch).", + ) parser.add_argument( "--backends", type=str, @@ -196,6 +214,16 @@ def parse_mamba_args(line, parser): f"Supported ratios: {supported_ratios}." 
) + if args.varlen: + if args.cache_steps < 1: + raise ValueError( + "--varlen requires --cache_steps >= 1 (specifies max_seqlen)" + ) + if args.cache_steps >= 1 and args.algorithm in ("vertical", "horizontal"): + raise ValueError( + f"MTP/varlen mode only supports 'auto' or 'simple' algorithm, got '{args.algorithm}'" + ) + if args.verbose >= 1: print(f"[INFO] {args = }") return args @@ -245,6 +273,8 @@ def testSelectiveStateUpdate(args): cache_steps = args.cache_steps has_z = args.has_z dt_softplus = args.dt_softplus + is_varlen = args.varlen + algorithm = args.algorithm is_cuda_graph_compatible = not args.no_cuda_graph run_refcheck = args.refcheck res = [] @@ -259,11 +289,16 @@ def testSelectiveStateUpdate(args): weight_dtype = dtype_str_to_torch_dtype(args.weight_dtype) ## Done parsing input arguments - ## Determine STP vs MTP mode - is_mtp = cache_steps >= 1 + ## Determine mode + is_mtp = cache_steps >= 1 and not is_varlen T = cache_steps if is_mtp else None - ## Prepare input tensors (mirrors tests/mamba/utils.py::create_test_inputs) + if is_varlen: + n_seqs = batch_size + max_seqlen = cache_steps + total_tokens = n_seqs * max_seqlen + + ## Prepare input tensors ssm_state_cache_size = max(384, batch_size * 10) # State cache: (total_entries, nheads, dim, dstate) - contiguous @@ -271,37 +306,48 @@ def testSelectiveStateUpdate(args): ssm_state_cache_size, nheads, dim, dstate, dtype=state_dtype, device=device ) - # Input x: (batch_size, [T,] nheads, dim) - if T is not None: + if is_varlen: + # Varlen: 3D flattened inputs (total_tokens, ...) 
+ x = torch.randn(total_tokens, nheads, dim, dtype=input_dtype, device=device) + dt_base = torch.randn(total_tokens, nheads, dtype=weight_dtype, device=device) + dt = dt_base.as_strided((total_tokens, nheads, dim), (nheads, 1, 0)) + B = torch.randn(total_tokens, ngroups, dstate, dtype=input_dtype, device=device) + C = torch.randn(total_tokens, ngroups, dstate, dtype=input_dtype, device=device) + z = None + if has_z: + z = torch.randn(total_tokens, nheads, dim, dtype=input_dtype, device=device) + elif T is not None: + # MTP: 4D inputs (batch_size, T, ...) x = torch.randn(batch_size, T, nheads, dim, dtype=input_dtype, device=device) - else: - x = torch.randn(batch_size, nheads, dim, dtype=input_dtype, device=device) - - # dt: broadcasting across dim (one value per head) - if T is not None: dt_base = torch.randn(batch_size, T, nheads, dtype=weight_dtype, device=device) dt = dt_base.as_strided( (batch_size, T, nheads, dim), (T * nheads, nheads, 1, 0) ) - else: - dt_base = torch.randn(batch_size, nheads, dtype=weight_dtype, device=device) - dt = dt_base.as_strided((batch_size, nheads, dim), (nheads, 1, 0)) - - # A: (nheads, dim, dstate) - negative values, broadcasting (one value per head) - A_base = -torch.rand(nheads, dtype=torch.float32, device=device) - 1.0 - A = A_base.as_strided((nheads, dim, dstate), (1, 0, 0)) - - # B, C: (batch_size, [T,] ngroups, dstate) - if T is not None: B = torch.randn( batch_size, T, ngroups, dstate, dtype=input_dtype, device=device ) C = torch.randn( batch_size, T, ngroups, dstate, dtype=input_dtype, device=device ) + z = None + if has_z: + z = torch.randn( + batch_size, T, nheads, dim, dtype=input_dtype, device=device + ) else: + # STP: 3D inputs (batch_size, ...) 
+ x = torch.randn(batch_size, nheads, dim, dtype=input_dtype, device=device) + dt_base = torch.randn(batch_size, nheads, dtype=weight_dtype, device=device) + dt = dt_base.as_strided((batch_size, nheads, dim), (nheads, 1, 0)) B = torch.randn(batch_size, ngroups, dstate, dtype=input_dtype, device=device) C = torch.randn(batch_size, ngroups, dstate, dtype=input_dtype, device=device) + z = None + if has_z: + z = torch.randn(batch_size, nheads, dim, dtype=input_dtype, device=device) + + # A: (nheads, dim, dstate) - negative values, broadcasting (one value per head) + A_base = -torch.rand(nheads, dtype=torch.float32, device=device) - 1.0 + A = A_base.as_strided((nheads, dim, dstate), (1, 0, 0)) # D: (nheads, dim) - broadcasting (one value per head) D_base = torch.randn(nheads, dtype=weight_dtype, device=device) @@ -311,23 +357,30 @@ def testSelectiveStateUpdate(args): dt_bias_base = torch.rand(nheads, dtype=weight_dtype, device=device) - 4.0 dt_bias = dt_bias_base.as_strided((nheads, dim), (1, 0)) - # Slot indices for state batching - slot_idx = torch.randperm(ssm_state_cache_size, dtype=torch.int64, device=device)[ - :batch_size - ] - - # Optional z tensor for gating - z = None - if has_z: - if T is not None: - z = torch.randn( - batch_size, T, nheads, dim, dtype=input_dtype, device=device - ) - else: - z = torch.randn(batch_size, nheads, dim, dtype=input_dtype, device=device) + # Varlen-specific tensors + if is_varlen: + cu_seqlens = torch.arange( + 0, total_tokens + 1, max_seqlen, device=device, dtype=torch.int32 + ) + perm = torch.randperm(ssm_state_cache_size, device=device).to(torch.int32) + src_indices = perm[: n_seqs * max_seqlen].reshape(n_seqs, max_seqlen) + dst_indices = perm[n_seqs * max_seqlen : 2 * n_seqs * max_seqlen].reshape( + n_seqs, max_seqlen + ) + num_accepted = torch.ones(n_seqs, device=device, dtype=torch.int32) + slot_idx = None + else: + cu_seqlens = None + src_indices = None + dst_indices = None + num_accepted = None + slot_idx = torch.randperm( 
+ ssm_state_cache_size, dtype=torch.int64, device=device + )[:batch_size] if args.verbose >= 2: - print(f"[VVERBOSE] Mode: {'MTP' if is_mtp else 'STP'}") + mode_str = "varlen" if is_varlen else ("MTP" if is_mtp else "STP") + print(f"[VVERBOSE] Mode: {mode_str}") print(f"[VVERBOSE] {state_cache.shape = }, {state_cache.dtype = }") print(f"[VVERBOSE] {x.shape = }, {x.dtype = }") print(f"[VVERBOSE] {dt.shape = }, {dt.dtype = }") @@ -336,7 +389,11 @@ def testSelectiveStateUpdate(args): print(f"[VVERBOSE] {C.shape = }, {C.dtype = }") print(f"[VVERBOSE] {D.shape = }, {D.dtype = }") print(f"[VVERBOSE] {dt_bias.shape = }, {dt_bias.dtype = }") - print(f"[VVERBOSE] {slot_idx.shape = }") + if slot_idx is not None: + print(f"[VVERBOSE] {slot_idx.shape = }") + if cu_seqlens is not None: + print(f"[VVERBOSE] {cu_seqlens.shape = }") + print(f"[VVERBOSE] {src_indices.shape = }, {dst_indices.shape = }") print(f"[VVERBOSE] {has_z = }, {dt_softplus = }") if z is not None: print(f"[VVERBOSE] {z.shape = }, {z.dtype = }") @@ -345,38 +402,79 @@ def testSelectiveStateUpdate(args): triton_cache_steps = cache_steps if cache_steps > 0 else None def run_backend(backend, state, x, dt, A, B, C, D): - if backend == "flashinfer": - return flashinfer.mamba.selective_state_update( - state, - x, - dt, - A, - B, - C, - D, - z=z, - dt_bias=dt_bias, - dt_softplus=dt_softplus, - state_batch_indices=slot_idx, - cache_steps=cache_steps, - ) - elif backend == "triton": - return selective_state_update_triton_reference( - state, - x, - dt, - A, - B, - C, - D, - z=z, - dt_bias=dt_bias, - dt_softplus=dt_softplus, - state_batch_indices=slot_idx, - cache_steps=triton_cache_steps, - ) + if is_varlen: + if backend == "flashinfer": + return flashinfer.mamba.selective_state_update( + state, + x, + dt, + A, + B, + C, + D, + z=z, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + state_batch_indices=src_indices, + algorithm=algorithm, + dst_state_batch_indices=dst_indices, + num_accepted_tokens=num_accepted, + 
cu_seqlens=cu_seqlens, + cache_steps=cache_steps, + ) + elif backend == "triton": + return selective_state_update_varlen_triton( + state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + state_batch_indices=src_indices, + dst_state_batch_indices=dst_indices, + num_accepted_tokens=num_accepted, + cu_seqlens=cu_seqlens, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") else: - raise ValueError(f"Unsupported backend: {backend}") + if backend == "flashinfer": + return flashinfer.mamba.selective_state_update( + state, + x, + dt, + A, + B, + C, + D, + z=z, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + state_batch_indices=slot_idx, + cache_steps=cache_steps, + algorithm=algorithm, + ) + elif backend == "triton": + return selective_state_update_triton_reference( + state, + x, + dt, + A, + B, + C, + D, + z=z, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + state_batch_indices=slot_idx, + cache_steps=triton_cache_steps, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") # Reference check: use Triton as golden reference # Save a clean snapshot of state_cache before any benchmarking, because @@ -386,24 +484,46 @@ def run_backend(backend, state, x, dt, A, B, C, D): clean_state_snapshot = state_cache.clone() if run_refcheck else None if run_refcheck: ref_state = clean_state_snapshot.clone() - reference_output = ( - selective_state_update_triton_reference( - ref_state, - x, - dt, - A, - B, - C, - D, - z=z, - dt_bias=dt_bias, - dt_softplus=dt_softplus, - state_batch_indices=slot_idx, - cache_steps=triton_cache_steps, + if is_varlen: + reference_output = ( + selective_state_update_varlen_triton( + ref_state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + state_batch_indices=src_indices, + dst_state_batch_indices=dst_indices, + num_accepted_tokens=num_accepted, + cu_seqlens=cu_seqlens, + ) + .detach() + .clone() + ) + else: + reference_output = ( + 
selective_state_update_triton_reference( + ref_state, + x, + dt, + A, + B, + C, + D, + z=z, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + state_batch_indices=slot_idx, + cache_steps=triton_cache_steps, + ) + .detach() + .clone() ) - .detach() - .clone() - ) has_reference_output = True # Storage for timing results and outputs @@ -541,6 +661,8 @@ def run_backend(backend, state, x, dt, A, B, C, D): cur_res["weight_dtype"] = str(weight_dtype) cur_res["has_z"] = has_z cur_res["dt_softplus"] = dt_softplus + cur_res["varlen"] = is_varlen + cur_res["algorithm"] = algorithm cur_res["case_tag"] = args.case_tag res.append(cur_res) return res diff --git a/csrc/flashinfer_mamba_binding.cu b/csrc/flashinfer_mamba_binding.cu index 9461cbfa08..fcbc35ac91 100644 --- a/csrc/flashinfer_mamba_binding.cu +++ b/csrc/flashinfer_mamba_binding.cu @@ -24,19 +24,19 @@ void selective_state_update( TensorView state, // (batch, dim, dstate) or (batch, nheads, dim, dstate) TensorView x, // (batch, dim) or (batch, nheads, dim) for single-token // or (batch, T, nheads, dim) for multi-token - TensorView dt, // (batch, dim) or (batch, nheads, dim) for single-token - // or (batch, T, nheads, dim) for multi-token + // or (total_tokens, nheads, dim) for varlen multi-token + TensorView dt, // same layout as x TensorView A, // (dim, dstate) or (nheads, dim, dstate) TensorView B, // (batch, dstate) or (batch, ngroups, dstate) for single-token // or (batch, T, ngroups, dstate) for multi-token - TensorView C, // (batch, dstate) or (batch, ngroups, dstate) for single-token - // or (batch, T, ngroups, dstate) for multi-token + // or (total_tokens, ngroups, dstate) for varlen multi-token + TensorView C, // same layout as B TensorView D, // (dim,) or (nheads, dim) - Optional z, // (batch, dim) or (batch, nheads, dim) for single-token - // or (batch, T, nheads, dim) for multi-token + Optional z, // same layout as x Optional dt_bias, // (dim,) or (nheads, dim) bool dt_softplus, - Optional state_batch_indices, 
// (batch,) + Optional state_batch_indices, // (batch,) or (N, max_seqlen) + Optional dst_state_batch_indices, // (batch,) or (N, max_seqlen) int64_t pad_slot_id, Optional state_scale, // float32: (state_cache_size, nheads, dim) TensorView output, // same as x @@ -46,6 +46,8 @@ void selective_state_update( Optional intermediate_state_scales, // float32: (batch, cache_steps, nheads, dim) Optional rand_seed, // device-side int64 tensor for Philox rounding int64_t cache_steps, + Optional cu_seqlens, // (N + 1,) + Optional num_accepted_tokens, // (N,) int64_t algorithm); // SSUAlgorithm: 0=auto, 1=simple, 2=vertical, 3=horizontal } // namespace flashinfer::mamba diff --git a/csrc/selective_state_update.cu b/csrc/selective_state_update.cu index 3a00f5ba9a..52719b3ae0 100644 --- a/csrc/selective_state_update.cu +++ b/csrc/selective_state_update.cu @@ -74,12 +74,17 @@ inline void validate_dt_bias_tensor(Optional const& dt_bias, int64_t } inline void validate_state_batch_indices(Optional const& state_batch_indices, - int64_t batch) { + int64_t batch, int64_t max_seqlen = 1) { if (!state_batch_indices.has_value()) return; - CHECK_DIM(1, (*state_batch_indices)); - CHECK_CONTIGUOUS((*state_batch_indices)); - FLASHINFER_CHECK(state_batch_indices.value().size(0) == batch, - "state_batch_indices.shape must be (", batch, ")"); + auto const& sbi = state_batch_indices.value(); + FLASHINFER_CHECK(sbi.dim() == 1 || sbi.dim() == 2, "state_batch_indices must be 1D or 2D, got ", + sbi.dim(), "D"); + FLASHINFER_CHECK(sbi.size(0) >= batch, "state_batch_indices.size(0) must be >= batch (", batch, + ")"); + if (sbi.dim() == 2) { + FLASHINFER_CHECK(sbi.size(1) >= max_seqlen, + "state_batch_indices.size(1) must be >= max_seqlen (", max_seqlen, ")"); + } } inline void validate_intermediate_state_indices( @@ -149,6 +154,7 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x TensorView const& C, TensorView const& D, Optional z, Optional dt_bias, bool dt_softplus, 
Optional state_batch_indices, + Optional dst_state_batch_indices, Optional state_scale, int64_t pad_slot_id, Optional out, bool disable_state_update, Optional rand_seed, int64_t algorithm) { @@ -187,6 +193,7 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x validate_A_tensor(A, nheads, dim, dstate); validate_dt_bias_tensor(dt_bias, nheads, dim); validate_state_batch_indices(state_batch_indices, batch); + validate_state_batch_indices(dst_state_batch_indices, batch); // Check B shape and strides CHECK_CUDA(B); @@ -237,6 +244,13 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x // Validate dtype consistency validate_dtype_consistency(state, dt, D, x, B, C, dt_bias, z, out); validate_state_scale(state_scale, state_cache_size, nheads, dim); + if (state_batch_indices.has_value() && dst_state_batch_indices.has_value()) { + DLDataType state_batch_idx_dtype = state_batch_indices.value().dtype(); + DLDataType dst_state_batch_idx_dtype = dst_state_batch_indices.value().dtype(); + FLASHINFER_CHECK(state_batch_idx_dtype.code == dst_state_batch_idx_dtype.code && + state_batch_idx_dtype.bits == dst_state_batch_idx_dtype.bits, + "state_batch_indices and dst_state_batch_indices must have the same dtype"); + } // Initialize params struct SelectiveStateUpdateParams p; @@ -264,7 +278,17 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x } p.state_stride_batch = state.stride(0); if (state_batch_indices.has_value()) { - p.state_batch_indices = const_cast(state_batch_indices.value().data_ptr()); + auto const& sbi = state_batch_indices.value(); + p.state_batch_indices = const_cast(sbi.data_ptr()); + p.state_batch_indices_stride_batch = sbi.stride(0); + p.state_batch_indices_stride_T = (sbi.dim() >= 2) ? 
sbi.stride(1) : 0; + } + if (dst_state_batch_indices.has_value()) { + auto const& dsbi = dst_state_batch_indices.value(); + CHECK_CUDA(dsbi); + p.dst_state_batch_indices = const_cast(dsbi.data_ptr()); + p.dst_state_batch_indices_stride_batch = dsbi.stride(0); + p.dst_state_batch_indices_stride_T = (dsbi.dim() >= 2) ? dsbi.stride(1) : 0; } if (state_scale.has_value()) { p.state_scale = state_scale.value().data_ptr(); @@ -313,97 +337,138 @@ void run_selective_state_update_mtp( TensorView const& state, TensorView const& x, TensorView const& dt, TensorView const& A, TensorView const& B, TensorView const& C, TensorView const& D, Optional z, Optional dt_bias, bool dt_softplus, Optional state_batch_indices, - Optional state_scale, int64_t pad_slot_id, Optional out, - bool disable_state_update, Optional intermediate_states_buffer, + Optional dst_state_batch_indices, Optional state_scale, + int64_t pad_slot_id, Optional out, bool disable_state_update, + Optional intermediate_states_buffer, Optional intermediate_state_indices, Optional intermediate_state_scales, - Optional rand_seed, int64_t cache_steps, int64_t algorithm) { + Optional rand_seed, int64_t cache_steps, Optional cu_seqlens, + Optional num_accepted_tokens, int64_t algorithm) { + bool const is_varlen = (x.dim() == 3 && cu_seqlens.has_value()); // Extract dimensions from input tensors - auto const batch = x.size(0); - auto const ntokens_mtp = x.size(1); + int64_t batch; + int64_t ntokens_mtp; + auto const state_cache_size = state.size(0); auto const nheads = state.size(1); auto const dim = state.size(2); auto const dstate = state.size(3); - auto const ngroups = B.size(2); - - FLASHINFER_CHECK(state_cache_size >= batch, "state.size(0) must be >= x.size(0)"); - FLASHINFER_CHECK(nheads % ngroups == 0, "nheads must be divisible by ngroups"); // Check x shape and strides CHECK_CUDA(x); - CHECK_DIM(4, x); - FLASHINFER_CHECK(x.size(2) == nheads, "x.size(2) must equal nheads"); - FLASHINFER_CHECK(x.size(3) == dim, 
"x.size(3) must equal dim"); - CHECK_LAST_DIM_CONTIGUOUS(x); - FLASHINFER_CHECK(x.stride(2) == dim, "x.stride(2) must equal dim, got ", x.stride(2), - " expected ", dim); + if (is_varlen) { + CHECK_DIM(3, x); // x: {total_tokens, nheads, dim} + FLASHINFER_CHECK(x.size(1) == nheads, "x.size(1) must equal nheads"); + FLASHINFER_CHECK(x.size(2) == dim, "x.size(2) must equal dim"); + CHECK_LAST_DIM_CONTIGUOUS(x); + FLASHINFER_CHECK(x.stride(1) == dim, "x.stride(1) must equal dim"); + batch = cu_seqlens.value().size(0) - 1; + FLASHINFER_CHECK(cache_steps >= 1, + "cache_steps must be >= 1 in varlen mode (specifies max_seqlen)"); + ntokens_mtp = cache_steps; + } else { + CHECK_DIM(4, x); // x: {batch, ntokens_mtp, nheads, dim} + batch = x.size(0); + ntokens_mtp = x.size(1); + FLASHINFER_CHECK(x.size(2) == nheads, "x.size(2) must equal nheads"); + FLASHINFER_CHECK(x.size(3) == dim, "x.size(3) must equal dim"); + CHECK_LAST_DIM_CONTIGUOUS(x); + FLASHINFER_CHECK(x.stride(2) == dim, "x.stride(2) must equal dim, got ", x.stride(2), + " expected ", dim); + } + + auto const ngroups = is_varlen ? 
B.size(1) : B.size(2); + + FLASHINFER_CHECK(state_cache_size >= batch, "state.size(0) must be >= batch"); + FLASHINFER_CHECK(nheads % ngroups == 0, "nheads must be divisible by ngroups"); // Check dt shape and strides CHECK_CUDA(dt); - CHECK_DIM(4, dt); // dt: {batch, ntokens_mtp, nheads, dim} - FLASHINFER_CHECK(dt.size(0) == batch, "dt.size(0) must equal batch =", batch); - FLASHINFER_CHECK(dt.size(1) == ntokens_mtp, "dt.size(1) must equal ntokens_mtp =", ntokens_mtp); - FLASHINFER_CHECK(dt.size(2) == nheads, "dt.size(2) must equal nheads"); - FLASHINFER_CHECK(dt.size(3) == dim, "dt.size(3) must equal dim"); - FLASHINFER_CHECK(dt.stride(2) == 1, "dt.stride(2) must be 1, got ", dt.stride(2)); - FLASHINFER_CHECK(dt.stride(3) == 0, "dt.stride(3) must be 0 (broadcasted), got ", dt.stride(3)); + if (is_varlen) { + CHECK_DIM(3, dt); // dt: {total_tokens, nheads, dim} + FLASHINFER_CHECK(dt.size(1) == nheads, "dt.size(1) must equal nheads"); + FLASHINFER_CHECK(dt.stride(1) == 1, "dt.stride(1) must be 1"); + FLASHINFER_CHECK(dt.stride(2) == 0, "dt.stride(2) must be 0 (broadcasted)"); + } else { + CHECK_DIM(4, dt); // dt: {batch, ntokens_mtp, nheads, dim} + FLASHINFER_CHECK(dt.size(0) == batch, "dt.size(0) must equal batch"); + FLASHINFER_CHECK(dt.size(1) == ntokens_mtp, "dt.size(1) must equal ntokens_mtp"); + FLASHINFER_CHECK(dt.size(2) == nheads, "dt.size(2) must equal nheads"); + FLASHINFER_CHECK(dt.stride(2) == 1, "dt.stride(2) must be 1"); + FLASHINFER_CHECK(dt.stride(3) == 0, "dt.stride(3) must be 0 (broadcasted)"); + } // Validate common tensors using helper functions validate_state_tensor(state); validate_D_tensor(D, nheads, dim); validate_A_tensor(A, nheads, dim, dstate); validate_dt_bias_tensor(dt_bias, nheads, dim); - validate_state_batch_indices(state_batch_indices, batch); + validate_state_batch_indices(state_batch_indices, batch, ntokens_mtp); + validate_state_batch_indices(dst_state_batch_indices, batch, ntokens_mtp); // Check B shape and strides 
CHECK_CUDA(B); - CHECK_DIM(4, B); // B: {batch, ntokens_mtp, ngroups, dstate} - FLASHINFER_CHECK(B.size(0) == batch, "B.size(0) must equal batch =", batch); - FLASHINFER_CHECK(B.size(1) == ntokens_mtp, "B.size(1) must equal ntokens_mtp =", ntokens_mtp); - FLASHINFER_CHECK(B.size(2) == ngroups, "B.size(2) must equal ngroups =", ngroups); - FLASHINFER_CHECK(B.size(3) == dstate, "B.size(3) must equal dstate =", dstate); + if (is_varlen) { + CHECK_DIM(3, B); // B: {total_tokens, ngroups, dstate} + FLASHINFER_CHECK(B.size(1) == ngroups, "B.size(1) must equal ngroups"); + FLASHINFER_CHECK(B.size(2) == dstate, "B.size(2) must equal dstate"); + } else { + CHECK_DIM(4, B); // B: {batch, ntokens_mtp, ngroups, dstate} + FLASHINFER_CHECK(B.size(0) == batch, "B.size(0) must equal batch"); + FLASHINFER_CHECK(B.size(1) == ntokens_mtp, "B.size(1) must equal ntokens_mtp"); + FLASHINFER_CHECK(B.size(2) == ngroups, "B.size(2) must equal ngroups"); + FLASHINFER_CHECK(B.size(3) == dstate, "B.size(3) must equal dstate"); + } CHECK_LAST_DIM_CONTIGUOUS(B); - FLASHINFER_CHECK(B.stride(2) == dstate, "B.stride(2) must equal dstate, got ", B.stride(2), - " expected ", dstate); // Check C shape and strides CHECK_CUDA(C); - CHECK_DIM(4, C); // C: {batch, ntokens_mtp, ngroups, dstate} - FLASHINFER_CHECK(C.size(0) == batch, "C.size(0) must equal batch"); - FLASHINFER_CHECK(C.size(1) == ntokens_mtp, "C.size(1) must equal ntokens_mtp =", ntokens_mtp); - FLASHINFER_CHECK(C.size(2) == ngroups, "C.size(2) must equal ngroups"); - FLASHINFER_CHECK(C.size(3) == dstate, "C.size(3) must equal dstate"); + if (is_varlen) { + CHECK_DIM(3, C); // C: {total_tokens, ngroups, dstate} + FLASHINFER_CHECK(C.size(1) == ngroups, "C.size(1) must equal ngroups"); + FLASHINFER_CHECK(C.size(2) == dstate, "C.size(2) must equal dstate"); + } else { + CHECK_DIM(4, C); // C: {batch, ntokens_mtp, ngroups, dstate} + FLASHINFER_CHECK(C.size(0) == batch, "C.size(0) must equal batch"); + FLASHINFER_CHECK(C.size(1) == ntokens_mtp, 
"C.size(1) must equal ntokens_mtp"); + FLASHINFER_CHECK(C.size(2) == ngroups, "C.size(2) must equal ngroups"); + FLASHINFER_CHECK(C.size(3) == dstate, "C.size(3) must equal dstate"); + } CHECK_LAST_DIM_CONTIGUOUS(C); - FLASHINFER_CHECK(C.stride(2) == dstate, "C.stride(2) must equal dstate, got ", C.stride(2), - " expected ", dstate); // Optional z check if (z.has_value()) { auto& z_tensor = z.value(); CHECK_CUDA(z_tensor); - CHECK_DIM(4, z_tensor); // z: {batch, ntokens_mtp, nheads, dim} - FLASHINFER_CHECK(z_tensor.size(0) == batch, "z.size(0) must equal batch"); - FLASHINFER_CHECK(z_tensor.size(1) == ntokens_mtp, "z.size(1) must equal ntokens_mtp"); - FLASHINFER_CHECK(z_tensor.size(2) == nheads, "z.size(2) must equal nheads"); - FLASHINFER_CHECK(z_tensor.size(3) == dim, "z.size(3) must equal dim"); + if (is_varlen) { + CHECK_DIM(3, z_tensor); // z: {total_tokens, nheads, dim} + FLASHINFER_CHECK(z_tensor.size(1) == nheads, "z.size(1) must equal nheads"); + FLASHINFER_CHECK(z_tensor.size(2) == dim, "z.size(2) must equal dim"); + } else { + CHECK_DIM(4, z_tensor); // z: {batch, ntokens_mtp, nheads, dim} + FLASHINFER_CHECK(z_tensor.size(0) == batch, "z.size(0) must equal batch"); + FLASHINFER_CHECK(z_tensor.size(1) == ntokens_mtp, "z.size(1) must equal ntokens_mtp"); + FLASHINFER_CHECK(z_tensor.size(2) == nheads, "z.size(2) must equal nheads"); + FLASHINFER_CHECK(z_tensor.size(3) == dim, "z.size(3) must equal dim"); + } CHECK_LAST_DIM_CONTIGUOUS(z_tensor); - FLASHINFER_CHECK(z_tensor.stride(2) == dim, "z.stride(2) must equal dim, got ", - z_tensor.stride(2), " expected ", dim); } // Check output tensor if provided if (out.has_value()) { auto& output = out.value(); CHECK_CUDA(output); - CHECK_DIM(4, output); - FLASHINFER_CHECK(output.size(0) == batch, "out.size(0) must equal batch = ", batch); - FLASHINFER_CHECK(output.size(1) == ntokens_mtp, - "out.size(1) must equal ntokens_mtp = ", ntokens_mtp); - FLASHINFER_CHECK(output.size(2) == nheads, "out.size(2) must equal 
nheads = ", nheads); - FLASHINFER_CHECK(output.size(3) == dim, "out.size(3) must equal dim = ", dim); CHECK_LAST_DIM_CONTIGUOUS(output); - FLASHINFER_CHECK(output.stride(2) == dim, "out.stride(2) = ", output.stride(2), - " must equal dim = ", dim); + if (is_varlen) { + CHECK_DIM(3, output); // out: {total_tokens, nheads, dim} + FLASHINFER_CHECK(output.size(1) == nheads, "out.size(1) must equal nheads"); + FLASHINFER_CHECK(output.size(2) == dim, "out.size(2) must equal dim"); + } else { + CHECK_DIM(4, output); // out: {batch, ntokens_mtp, nheads, dim} + FLASHINFER_CHECK(output.size(0) == batch, "out.size(0) must equal batch"); + FLASHINFER_CHECK(output.size(1) == ntokens_mtp, "out.size(1) must equal ntokens_mtp"); + FLASHINFER_CHECK(output.size(2) == nheads, "out.size(2) must equal nheads"); + FLASHINFER_CHECK(output.size(3) == dim, "out.size(3) must equal dim"); + } } // Validate dtype consistency @@ -412,7 +477,7 @@ void run_selective_state_update_mtp( validate_intermediate_states_buffer(intermediate_states_buffer); validate_state_scale(state_scale, state_cache_size, nheads, dim); - // Validate that state_batch_indices and intermediate_state_indices have the same dtype + // Validate that index tensors have consistent dtypes if (state_batch_indices.has_value() && intermediate_state_indices.has_value()) { DLDataType state_batch_idx_dtype = state_batch_indices.value().dtype(); DLDataType intermediate_idx_dtype = intermediate_state_indices.value().dtype(); @@ -420,6 +485,13 @@ void run_selective_state_update_mtp( state_batch_idx_dtype.bits == intermediate_idx_dtype.bits, "state_batch_indices and intermediate_state_indices must have the same dtype"); } + if (state_batch_indices.has_value() && dst_state_batch_indices.has_value()) { + DLDataType state_batch_idx_dtype = state_batch_indices.value().dtype(); + DLDataType dst_state_batch_idx_dtype = dst_state_batch_indices.value().dtype(); + FLASHINFER_CHECK(state_batch_idx_dtype.code == dst_state_batch_idx_dtype.code && + 
state_batch_idx_dtype.bits == dst_state_batch_idx_dtype.bits, + "state_batch_indices and dst_state_batch_indices must have the same dtype"); + } // Validate cache_steps is non-negative FLASHINFER_CHECK(cache_steps >= 0, "cache_steps must be non-negative, got ", cache_steps); @@ -455,19 +527,59 @@ void run_selective_state_update_mtp( p.state_stride_batch = state.stride(0); // Copy MTP strides - p.x_stride_mtp = x.stride(1); - p.dt_stride_mtp = dt.stride(1); - p.B_stride_mtp = B.stride(1); - p.C_stride_mtp = C.stride(1); - if (out.has_value()) { - p.out_stride_mtp = out.value().stride(1); - } else { + if (is_varlen) { + p.x_stride_mtp = 0; + p.dt_stride_mtp = 0; + p.B_stride_mtp = 0; + p.C_stride_mtp = 0; p.out_stride_mtp = 0; + } else { + p.x_stride_mtp = x.stride(1); + p.dt_stride_mtp = dt.stride(1); + p.B_stride_mtp = B.stride(1); + p.C_stride_mtp = C.stride(1); + if (out.has_value()) { + p.out_stride_mtp = out.value().stride(1); + } else { + p.out_stride_mtp = 0; + } } if (state_batch_indices.has_value()) { - p.state_batch_indices = const_cast(state_batch_indices.value().data_ptr()); + auto const& sbi = state_batch_indices.value(); + p.state_batch_indices = const_cast(sbi.data_ptr()); + p.state_batch_indices_stride_batch = sbi.stride(0); + p.state_batch_indices_stride_T = (sbi.dim() >= 2) ? sbi.stride(1) : 0; } + if (dst_state_batch_indices.has_value()) { + auto const& dsbi = dst_state_batch_indices.value(); + CHECK_CUDA(dsbi); + p.dst_state_batch_indices = const_cast(dsbi.data_ptr()); + p.dst_state_batch_indices_stride_batch = dsbi.stride(0); + p.dst_state_batch_indices_stride_T = (dsbi.dim() >= 2) ? 
dsbi.stride(1) : 0; + } + if (cu_seqlens.has_value()) { + auto const& cs = cu_seqlens.value(); + CHECK_CUDA(cs); + CHECK_DIM(1, cs); + CHECK_CONTIGUOUS(cs); + FLASHINFER_CHECK(cs.size(0) == batch + 1, "cu_seqlens.size(0) must equal n_sequences + 1 (", + batch + 1, ")"); + p.cu_seqlens = const_cast(cs.data_ptr()); + } + if (num_accepted_tokens.has_value()) { + auto const& nat = num_accepted_tokens.value(); + CHECK_CUDA(nat); + CHECK_DIM(1, nat); + CHECK_CONTIGUOUS(nat); + FLASHINFER_CHECK(nat.size(0) >= batch, "num_accepted_tokens.size(0) must be >= n_sequences (", + batch, ")"); + FLASHINFER_CHECK(state_batch_indices.has_value(), + "state_batch_indices is required when num_accepted_tokens is provided"); + p.num_accepted_tokens = const_cast(nat.data_ptr()); + } + FLASHINFER_CHECK(!(dst_state_batch_indices.has_value() && intermediate_states_buffer.has_value()), + "dst_state_batch_indices and intermediate_states_buffer are mutually exclusive"); if (state_scale.has_value()) { p.state_scale = state_scale.value().data_ptr(); p.state_scale_stride_batch = state_scale.value().stride(0); @@ -521,7 +633,7 @@ void run_selective_state_update_mtp( if (z.has_value()) { p.z = const_cast(z.value().data_ptr()); p.z_stride_batch = z.value().stride(0); - p.z_stride_mtp = z.value().stride(1); + p.z_stride_mtp = is_varlen ? 
0 : z.value().stride(1); } p.A = const_cast(A.data_ptr()); p.B = const_cast(B.data_ptr()); @@ -543,19 +655,23 @@ void run_selective_state_update_mtp( void selective_state_update( TensorView state, TensorView x, TensorView dt, TensorView A, TensorView B, TensorView C, TensorView D, Optional z, Optional dt_bias, bool dt_softplus, - Optional state_batch_indices, int64_t pad_slot_id, Optional state_scale, - TensorView output, bool disable_state_update, Optional intermediate_states_buffer, + Optional state_batch_indices, Optional dst_state_batch_indices, + int64_t pad_slot_id, Optional state_scale, TensorView output, + bool disable_state_update, Optional intermediate_states_buffer, Optional intermediate_state_indices, Optional intermediate_state_scales, - Optional rand_seed, int64_t cache_steps, int64_t algorithm) { - if (x.dim() == 3) { + Optional rand_seed, int64_t cache_steps, Optional cu_seqlens, + Optional num_accepted_tokens, int64_t algorithm) { + bool const has_cu_seqlens = cu_seqlens.has_value(); + if (x.dim() == 3 && !has_cu_seqlens) { run_selective_state_update_stp(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, - state_batch_indices, state_scale, pad_slot_id, output, - disable_state_update, rand_seed, algorithm); - } else if (x.dim() == 4) { + state_batch_indices, dst_state_batch_indices, state_scale, + pad_slot_id, output, disable_state_update, rand_seed, algorithm); + } else if (x.dim() == 4 || (x.dim() == 3 && has_cu_seqlens)) { run_selective_state_update_mtp( - state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, state_batch_indices, state_scale, - pad_slot_id, output, disable_state_update, intermediate_states_buffer, - intermediate_state_indices, intermediate_state_scales, rand_seed, cache_steps, algorithm); + state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, state_batch_indices, + dst_state_batch_indices, state_scale, pad_slot_id, output, disable_state_update, + intermediate_states_buffer, intermediate_state_indices, intermediate_state_scales, + 
rand_seed, cache_steps, cu_seqlens, num_accepted_tokens, algorithm); } else { FLASHINFER_CHECK(false, "x must have 3 dimensions (single-token) or 4 dimensions (multi-token), got ", diff --git a/csrc/selective_state_update_customize_config.jinja b/csrc/selective_state_update_customize_config.jinja index 4e6fb6caa4..712afbc4d1 100644 --- a/csrc/selective_state_update_customize_config.jinja +++ b/csrc/selective_state_update_customize_config.jinja @@ -8,6 +8,8 @@ using input_t = {{ input_dtype }}; using weight_t = {{ weight_dtype }}; using matrixA_t = {{ matrixA_dtype }}; using stateIndex_t = {{ stateIndex_dtype }}; +using cuSeqlensIndex_t = {{ cu_seqlens_dtype }}; +using numAcceptedIndex_t = {{ num_accepted_tokens_dtype }}; // Type for block-scale decode factors (e.g. float, __half). // void = no scaling (state_t is used as-is). using state_scale_t = {{ state_scale_type }}; diff --git a/flashinfer/aot.py b/flashinfer/aot.py index fc46938231..65d3167f79 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -588,22 +588,36 @@ def gen_all_modules( _ssu_dims = [64] _ssu_dstates = [128] _ssu_ntokens = [1, 4, 6, 8] - for dtype_combo, dim, dstate, ntokens in product( - _ssu_dtype_combos, _ssu_dims, _ssu_dstates, _ssu_ntokens + _ssu_cu_seqlens_dtypes = [torch.int32, torch.int64] + _ssu_num_accepted_dtypes = [torch.int32, torch.int64] + for dtype_combo, dim, dstate, ntokens, cs_dtype, na_dtype in product( + _ssu_dtype_combos, + _ssu_dims, + _ssu_dstates, + _ssu_ntokens, + _ssu_cu_seqlens_dtypes, + _ssu_num_accepted_dtypes, ): jit_specs.append( # false positive: mypy can't resolve the signature because flashinfer.jit deps (filelock etc.) 
# are absent in mypy's isolated env, causing it to infer an incorrect function signature - gen_selective_state_update_module(*dtype_combo, dim, dstate, ntokens) # type: ignore[call-arg] + gen_selective_state_update_module( + *dtype_combo, dim, dstate, ntokens, cs_dtype, na_dtype + ) # type: ignore[call-arg] ) if has_sm90 or has_sm100: - for dtype_combo, dim, dstate, ntokens in product( - _ssu_dtype_combos, _ssu_dims, _ssu_dstates, _ssu_ntokens + for dtype_combo, dim, dstate, ntokens, cs_dtype, na_dtype in product( + _ssu_dtype_combos, + _ssu_dims, + _ssu_dstates, + _ssu_ntokens, + _ssu_cu_seqlens_dtypes, + _ssu_num_accepted_dtypes, ): jit_specs.append( # same false positive as above gen_selective_state_update_sm90_module( # type: ignore[call-arg] - *dtype_combo, dim, dstate, ntokens + *dtype_combo, dim, dstate, ntokens, cs_dtype, na_dtype ) ) jit_specs.append(gen_trtllm_utils_module()) diff --git a/flashinfer/jit/mamba/selective_state_update.py b/flashinfer/jit/mamba/selective_state_update.py index 15b1e13094..e26c24ea90 100644 --- a/flashinfer/jit/mamba/selective_state_update.py +++ b/flashinfer/jit/mamba/selective_state_update.py @@ -56,6 +56,8 @@ def get_selective_state_update_uri( dim: int, dstate: int, ntokens_mtp: int, + cu_seqlens_dtype: torch.dtype, + num_accepted_tokens_dtype: torch.dtype, philox_rounds: int = 0, ) -> str: s = _filename_safe_dtype_map @@ -64,6 +66,7 @@ def get_selective_state_update_uri( f"s_{s[state_dtype]}_i_{s[input_dtype]}_w_{s[weight_dtype]}_" f"a_{s[matrixA_dtype]}_si_{s[stateIndex_dtype]}_" f"d_{dim}_ds_{dstate}_nt_{ntokens_mtp}" + f"_cs_{s[cu_seqlens_dtype]}_na_{s[num_accepted_tokens_dtype]}" ) if state_scale_dtype is not None: uri += f"_sc_{s[state_scale_dtype]}" @@ -83,6 +86,8 @@ def _gen_module( dim: int, dstate: int, ntokens_mtp: int, + cu_seqlens_dtype: torch.dtype, + num_accepted_tokens_dtype: torch.dtype, philox_rounds: int = 0, extra_cuda_cflags: list = None, ) -> JitSpec: @@ -104,6 +109,8 @@ def _gen_module( 
weight_dtype=_dtype_map[weight_dtype], matrixA_dtype=_dtype_map[matrixA_dtype], stateIndex_dtype=_dtype_map[stateIndex_dtype], + cu_seqlens_dtype=_dtype_map[cu_seqlens_dtype], + num_accepted_tokens_dtype=_dtype_map[num_accepted_tokens_dtype], dim=dim, dstate=dstate, ntokens_mtp=ntokens_mtp, @@ -143,6 +150,8 @@ def gen_selective_state_update_module( dim: int, dstate: int, ntokens_mtp: int, + cu_seqlens_dtype: torch.dtype, + num_accepted_tokens_dtype: torch.dtype, philox_rounds: int = 0, ) -> JitSpec: uri = get_selective_state_update_uri( @@ -155,6 +164,8 @@ def gen_selective_state_update_module( dim, dstate, ntokens_mtp, + cu_seqlens_dtype, + num_accepted_tokens_dtype, philox_rounds, ) return _gen_module( @@ -168,6 +179,8 @@ def gen_selective_state_update_module( dim, dstate, ntokens_mtp, + cu_seqlens_dtype, + num_accepted_tokens_dtype, philox_rounds=philox_rounds, extra_cuda_cflags=["-lineinfo"], ) @@ -183,6 +196,8 @@ def gen_selective_state_update_sm90_module( dim: int, dstate: int, ntokens_mtp: int, + cu_seqlens_dtype: torch.dtype, + num_accepted_tokens_dtype: torch.dtype, philox_rounds: int = 0, ) -> JitSpec: uri = ( @@ -196,6 +211,8 @@ def gen_selective_state_update_sm90_module( dim, dstate, ntokens_mtp, + cu_seqlens_dtype, + num_accepted_tokens_dtype, philox_rounds, ) + "_sm90" @@ -216,6 +233,8 @@ def gen_selective_state_update_sm90_module( dim, dstate, ntokens_mtp, + cu_seqlens_dtype, + num_accepted_tokens_dtype, philox_rounds=philox_rounds, extra_cuda_cflags=nvcc_flags, ) diff --git a/flashinfer/mamba/selective_state_update.py b/flashinfer/mamba/selective_state_update.py index 92447a923e..9427f2c610 100644 --- a/flashinfer/mamba/selective_state_update.py +++ b/flashinfer/mamba/selective_state_update.py @@ -37,6 +37,8 @@ def _get_module( dim: int, dstate: int, ntokens_mtp: int, + cu_seqlens_dtype: torch.dtype, + num_accepted_tokens_dtype: torch.dtype, sm_major: int, state_scale_dtype: Optional[torch.dtype] = None, philox_rounds: int = 0, @@ -51,6 +53,8 @@ def 
_get_module( dim, dstate, ntokens_mtp, + cu_seqlens_dtype, + num_accepted_tokens_dtype, philox_rounds, ) if sm_major >= 9: @@ -69,6 +73,8 @@ def get_selective_state_update_module( dim: int, dstate: int, ntokens_mtp: int, + cu_seqlens_dtype: torch.dtype, + num_accepted_tokens_dtype: torch.dtype, state_scale_dtype: Optional[torch.dtype] = None, philox_rounds: int = 0, ): @@ -82,6 +88,8 @@ def get_selective_state_update_module( dim, dstate, ntokens_mtp, + cu_seqlens_dtype, + num_accepted_tokens_dtype, major, state_scale_dtype, philox_rounds, @@ -112,6 +120,9 @@ def selective_state_update( philox_rounds: int = 10, cache_steps: int = 0, algorithm: str = "auto", + dst_state_batch_indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, ) -> torch.Tensor: r"""Selective state update operation for Mamba layers (the generation phase). @@ -120,34 +131,36 @@ def selective_state_update( state : torch.Tensor State tensor with shape (state_cache_size, dim, dstate) or (state_cache_size, nheads, dim, dstate) x : torch.Tensor - Input tensor with shape (batch, dim) or (batch, nheads, dim) for single-token - or (batch, T, nheads, dim) for multi-token + Input tensor with shape (batch, dim) or (batch, nheads, dim) for single-token, + (batch, T, nheads, dim) for multi-token, + or (total_tokens, nheads, dim) for varlen multi-token (with cu_seqlens) dt : torch.Tensor - Delta time tensor with shape (batch, dim) or (batch, nheads, dim) for single-token - or (batch, T, nheads, dim) for multi-token + Delta time tensor, same layout as x A : torch.Tensor A matrix with shape (dim, dstate) or (nheads, dim, dstate) B : torch.Tensor - B matrix with shape (batch, dstate) or (batch, ngroups, dstate) for single-token - or (batch, T, ngroups, dstate) for multi-token + B matrix with shape (batch, dstate) or (batch, ngroups, dstate) for single-token, + (batch, T, ngroups, dstate) for multi-token, + or (total_tokens, ngroups, 
dstate) for varlen multi-token C : torch.Tensor - C matrix with shape (batch, dstate) or (batch, ngroups, dstate) for single-token - or (batch, T, ngroups, dstate) for multi-token + C matrix, same layout as B D : torch.Tensor D vector with shape (dim,) or (nheads, dim) z : Optional[torch.Tensor] - Optional z tensor with shape (batch, dim) or (batch, nheads, dim) for single-token - or (batch, T, nheads, dim) for multi-token + Optional z tensor, same layout as x dt_bias : Optional[torch.Tensor] Optional dt bias with shape (dim,) or (nheads, dim) dt_softplus : bool Whether to apply softplus to dt state_batch_indices : Optional[torch.Tensor] - Optional batch indices for cache processing with shape (batch,) + Batch indices for state cache reading. Shape (batch,) or (N, max_seqlen). + For speculative decoding with num_accepted_tokens, must be 2D. + dst_state_batch_indices : Optional[torch.Tensor] + Destination indices for state cache writing. Shape (batch,) or (N, max_seqlen). + When provided, state is read from state_batch_indices and written to + dst_state_batch_indices (enables separate read/write state slots). pad_slot_id : int - If state_batch_indices is passed, lets the kernel identify padded entries - that will not be processed. For example: state_batch_indices = [pad_slot_id, 1, 20, pad_slot_id] - in this case, the kernel will not process entries at indices 0 and 3 + Sentinel value for padded entries in state_batch_indices state_scale : Optional[torch.Tensor] Optional float32 scale tensor with shape (state_cache_size, nheads, dim) for int16 state quantization with block scaling @@ -162,31 +175,28 @@ def selective_state_update( Optional indices mapping batch elements to intermediate state buffer positions with shape (batch,) rand_seed : Optional[torch.Tensor] - Optional single-element int64 CUDA tensor for stochastic rounding seed (Philox-4x32 PRNG). 
- Using a device-side tensor (rather than a host integer) ensures CUDA graph compatibility, - since the graph captures the pointer and the seed value can be updated between replays. - When provided, state values are stochastically rounded before storing to fp16. - When None, no stochastic rounding is applied (regardless of philox_rounds). - Cannot be used together with state_scale. + Optional single-element int64 CUDA tensor for stochastic rounding seed. philox_rounds : int - Number of Philox-4x32 PRNG rounds for stochastic rounding (default 10, - matching Triton's tl.randint). Only effective when rand_seed is not None; - ignored otherwise. Must be non-negative. + Number of Philox-4x32 PRNG rounds for stochastic rounding (default 10). cache_steps : int - Number of steps/tokens to cache for speculative decoding + Number of steps/tokens to cache for speculative decoding. + For varlen mode (cu_seqlens provided), this specifies max_seqlen. + cu_seqlens : Optional[torch.Tensor] + Cumulative sequence lengths with shape (N + 1,). + When provided, inputs are in varlen format (tokens flattened into batch dim). + num_accepted_tokens : Optional[torch.Tensor] + Number of accepted tokens per sequence with shape (N,). + Determines which state to read as initial state for each sequence. algorithm : str - Algorithm to use: "auto" (default, picks the best kernel based on GPU arch, - data types, and problem size), "simple" (all GPUs), "vertical" and "horizontal" - (SM90+ only). MTP mode only supports "auto" or "simple". 
+ Algorithm to use: "auto" or "simple" (all GPUs), "vertical" or "horizontal" (SM90+ only; not supported in MTP/varlen mode) Returns ------- output : torch.Tensor - Output tensor with shape (batch, dim) or (batch, nheads, dim) for single-token - or (batch, T, nheads, dim) for multi-token + Output tensor with same shape as x """ - # Determine if we're in multi-token mode (more than 1 token) - is_mtp = cache_steps >= 1 + is_varlen = cu_seqlens is not None and x.dim() == 3 + is_mtp = cache_steps >= 1 and not is_varlen if state.dim() == 3: state = state.unsqueeze(1) @@ -197,35 +207,37 @@ def selective_state_update( if dt_bias is not None and dt_bias.dim() == 1: dt_bias = dt_bias.unsqueeze(0) - # Handle x, dt, B, C, z dimensions based on mode - # For single-token: 2D -> 3D (batch, nheads, dim) - # For multi-token: 3D -> 4D (batch, T, nheads, dim) - if x.dim() == 2: - x = x.unsqueeze(1) - if is_mtp and x.dim() == 3: - # Add T dimension for MTP mode: (batch, nheads, dim) -> (batch, T, nheads, dim) - x = x.unsqueeze(1) + if not is_varlen: + # Handle x, dt, B, C, z dimensions based on mode + # For single-token: 2D -> 3D (batch, nheads, dim) + # For multi-token: 3D -> 4D (batch, T, nheads, dim) + if x.dim() == 2: + x = x.unsqueeze(1) + if is_mtp and x.dim() == 3: + # Add T dimension for MTP mode: (batch, nheads, dim) -> (batch, T, nheads, dim) + x = x.unsqueeze(1) - if dt.dim() == 2: - dt = dt.unsqueeze(1) - if is_mtp and dt.dim() == 3: - dt = dt.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if is_mtp and dt.dim() == 3: + dt = dt.unsqueeze(1) - if B.dim() == 2: - B = B.unsqueeze(1) - if is_mtp and B.dim() == 3: - B = B.unsqueeze(1) + if B.dim() == 2: + B = B.unsqueeze(1) + if is_mtp and B.dim() == 3: + B = B.unsqueeze(1) - if C.dim() == 2: - C = C.unsqueeze(1) - if is_mtp and C.dim() == 3: - C = C.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if is_mtp and C.dim() == 3: + C = C.unsqueeze(1) + + if z is not None: + if z.dim() == 2: + z = z.unsqueeze(1) + if is_mtp and z.dim() == 3: + z = z.unsqueeze(1) - if z is not 
None: - if z.dim() == 2: - z = z.unsqueeze(1) - if is_mtp and z.dim() == 3: - z = z.unsqueeze(1) # Normalize state_scale to 3D: (state_cache_size, nheads, dim) if state_scale is not None and state_scale.dim() == 4 and state_scale.size(-1) == 1: state_scale = state_scale.squeeze(-1) @@ -263,13 +275,20 @@ def selective_state_update( stateIndex_dtype = torch.int32 if state_batch_indices is not None: stateIndex_dtype = state_batch_indices.dtype + elif dst_state_batch_indices is not None: + stateIndex_dtype = dst_state_batch_indices.dtype elif intermediate_state_indices is not None: stateIndex_dtype = intermediate_state_indices.dtype # Extract dim/dstate/ntokens for JIT specialization dim = state.size(2) dstate = state.size(3) - ntokens_mtp = x.size(1) if x.dim() == 4 else 1 + if is_varlen: + ntokens_mtp = cache_steps + elif x.dim() == 4: + ntokens_mtp = x.size(1) + else: + ntokens_mtp = 1 if algorithm == "auto": algorithm_int = 0 @@ -294,6 +313,7 @@ def selective_state_update( dt_bias, dt_softplus, state_batch_indices, + dst_state_batch_indices, pad_slot_id, state_scale, output, @@ -303,6 +323,8 @@ def selective_state_update( intermediate_state_scales, rand_seed, cache_steps, + cu_seqlens, + num_accepted_tokens, algorithm_int, philox_rounds, state.dtype, @@ -339,6 +361,7 @@ def _selective_state_update( dt_bias: Optional[torch.Tensor], dt_softplus: bool, state_batch_indices: Optional[torch.Tensor], + dst_state_batch_indices: Optional[torch.Tensor], pad_slot_id: int, state_scale: Optional[torch.Tensor], output: torch.Tensor, @@ -348,6 +371,8 @@ def _selective_state_update( intermediate_state_scales: Optional[torch.Tensor], rand_seed: Optional[torch.Tensor], cache_steps: int, + cu_seqlens: Optional[torch.Tensor], + num_accepted_tokens: Optional[torch.Tensor], algorithm: int, philox_rounds: int, state_dtype: torch.dtype, @@ -370,6 +395,8 @@ def _selective_state_update( dim, dstate, ntokens_mtp, + cu_seqlens.dtype if cu_seqlens is not None else torch.int32, + 
num_accepted_tokens.dtype if num_accepted_tokens is not None else torch.int64, state_scale_dtype=state_scale.dtype if state_scale is not None else None, philox_rounds=philox_rounds, ).selective_state_update( @@ -384,6 +411,7 @@ def _selective_state_update( dt_bias, dt_softplus, state_batch_indices, + dst_state_batch_indices, pad_slot_id, state_scale, output, @@ -393,6 +421,8 @@ def _selective_state_update( intermediate_state_scales, rand_seed, cache_steps, + cu_seqlens, + num_accepted_tokens, algorithm, ) @@ -410,6 +440,7 @@ def _selective_state_update_fake( dt_bias: Optional[torch.Tensor], dt_softplus: bool, state_batch_indices: Optional[torch.Tensor], + dst_state_batch_indices: Optional[torch.Tensor], pad_slot_id: int, state_scale: Optional[torch.Tensor], output: torch.Tensor, @@ -419,6 +450,8 @@ def _selective_state_update_fake( intermediate_state_scales: Optional[torch.Tensor], rand_seed: Optional[torch.Tensor], cache_steps: int, + cu_seqlens: Optional[torch.Tensor], + num_accepted_tokens: Optional[torch.Tensor], algorithm: int, philox_rounds: int, state_dtype: torch.dtype, diff --git a/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh b/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh index 37c96a302e..c097855a97 100644 --- a/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh +++ b/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh @@ -31,7 +31,7 @@ struct SharedStorageSimple { std::conditional_t state_scale[STATE_ROWS]; }; -// Grid: (batch, nheads, cdiv(DIM, ROWS_PER_BLOCK)) +// Grid: (batch_or_n_sequences, nheads, cdiv(DIM, ROWS_PER_BLOCK)) // When ROWS_PER_BLOCK == DIM, degenerates to the non-tiled case (blockIdx.z == 0 always). 
template (params.state_batch_indices); auto const* __restrict__ intermediate_state_indices = reinterpret_cast(params.intermediate_state_indices); + auto const* __restrict__ cu_seqlens = + reinterpret_cast(params.cu_seqlens); + auto const* __restrict__ num_accepted_tokens = + reinterpret_cast(params.num_accepted_tokens); + auto const* __restrict__ dst_state_batch_indices = + reinterpret_cast(params.dst_state_batch_indices); bool const dt_softplus = params.dt_softplus; int const nheads = params.nheads; int const ngroups = params.ngroups; - auto const batch = blockIdx.x; + auto const seq_idx = blockIdx.x; auto const head = blockIdx.y; auto const dim_offset = blockIdx.z * ROWS_PER_BLOCK; auto const group = head / (nheads / ngroups); auto lane = threadIdx.x % warpSize; auto warp = threadIdx.y; + int bos; + int seq_len; + bool const has_cu_seqlens = (cu_seqlens != nullptr); + if (has_cu_seqlens) { + bos = __ldg(&cu_seqlens[seq_idx]); + int eos = __ldg(&cu_seqlens[seq_idx + 1]); + seq_len = eos - bos; + if (seq_len <= 0) return; + } else { + bos = 0; + seq_len = TOKENS_MTP; + } + + int init_token_idx = 0; + if (num_accepted_tokens) { + int num_accepted = __ldg(&num_accepted_tokens[seq_idx]); + init_token_idx = max(num_accepted - 1, 0); + } + // State scale pointer (only used when scaleState == true) [[maybe_unused]] auto* __restrict__ state_scale = reinterpret_cast(params.state_scale); @@ -73,15 +98,47 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams // Load device-side Philox seed once into a register [[maybe_unused]] int64_t const rand_seed = params.rand_seed ? *params.rand_seed : 0; - auto const state_batch = (state_batch_indices) ? 
state_batch_indices[batch] : batch; + int64_t state_batch; + if (state_batch_indices) { + state_batch = static_cast( + state_batch_indices[seq_idx * params.state_batch_indices_stride_batch + + init_token_idx * params.state_batch_indices_stride_T]); + } else { + state_batch = static_cast(seq_idx); + } auto const intermediate_cache_idx = - intermediate_state_indices ? intermediate_state_indices[batch] : state_batch; + intermediate_state_indices ? intermediate_state_indices[seq_idx] : state_batch; auto const state_ptr_offset = state_batch * params.state_stride_batch + head * DIM * DSTATE; state += state_ptr_offset; if constexpr (scaleState) { state_scale += state_batch * params.state_scale_stride_batch + head * DIM; } + int64_t const x_base = has_cu_seqlens ? (int64_t)bos * params.x_stride_batch + : (int64_t)seq_idx * params.x_stride_batch; + int64_t const x_tstride = has_cu_seqlens ? params.x_stride_batch : params.x_stride_mtp; + + int64_t const dt_base = has_cu_seqlens ? (int64_t)bos * params.dt_stride_batch + : (int64_t)seq_idx * params.dt_stride_batch; + int64_t const dt_tstride = has_cu_seqlens ? params.dt_stride_batch : params.dt_stride_mtp; + + int64_t const B_base = has_cu_seqlens ? (int64_t)bos * params.B_stride_batch + : (int64_t)seq_idx * params.B_stride_batch; + int64_t const B_tstride = has_cu_seqlens ? params.B_stride_batch : params.B_stride_mtp; + + int64_t const C_base = has_cu_seqlens ? (int64_t)bos * params.C_stride_batch + : (int64_t)seq_idx * params.C_stride_batch; + int64_t const C_tstride = has_cu_seqlens ? params.C_stride_batch : params.C_stride_mtp; + + int64_t const out_base = has_cu_seqlens ? (int64_t)bos * params.out_stride_batch + : (int64_t)seq_idx * params.out_stride_batch; + int64_t const out_tstride = has_cu_seqlens ? params.out_stride_batch : params.out_stride_mtp; + + int64_t const z_base = z ? (has_cu_seqlens ? (int64_t)bos * params.z_stride_batch + : (int64_t)seq_idx * params.z_stride_batch) + : 0; + int64_t const z_tstride = z ? 
(has_cu_seqlens ? params.z_stride_batch : params.z_stride_mtp) : 0; + constexpr auto stateRowsPerWarpPerStage = 4; constexpr auto stageRows = stateRowsPerWarpPerStage * numWarps; @@ -105,9 +162,12 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams d += warpSize * load_input_t::count) { if (dim_offset + d < DIM) { auto* dst = reinterpret_cast(&sram.x[mtp_step][d]); - *dst = *reinterpret_cast( - &x[batch * params.x_stride_batch + mtp_step * params.x_stride_mtp + head * DIM + - dim_offset + d]); + if (mtp_step < seq_len) { + *dst = *reinterpret_cast( + &x[x_base + mtp_step * x_tstride + head * DIM + dim_offset + d]); + } else { + *dst = make_zeros(); + } } } } @@ -115,9 +175,12 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams for (int mtp_step = 0; mtp_step < TOKENS_MTP; mtp_step++) { for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) { auto* dst = reinterpret_cast(&sram.B[mtp_step][i]); - *dst = *reinterpret_cast( - &B[batch * params.B_stride_batch + mtp_step * params.B_stride_mtp + group * DSTATE + - i]); + if (mtp_step < seq_len) { + *dst = *reinterpret_cast( + &B[B_base + mtp_step * B_tstride + group * DSTATE + i]); + } else { + *dst = make_zeros(); + } } } } else if (warp == 2) { // Load z: gmem -> smem @@ -126,10 +189,12 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams d += warpSize * load_input_t::count) { if (dim_offset + d < DIM) { auto* dst = reinterpret_cast(&sram.z[mtp_step][d]); - *dst = z ? 
*reinterpret_cast( - &z[batch * params.z_stride_batch + mtp_step * params.z_stride_mtp + - head * DIM + dim_offset + d]) - : make_zeros(); + if (z && mtp_step < seq_len) { + *dst = *reinterpret_cast( + &z[z_base + mtp_step * z_tstride + head * DIM + dim_offset + d]); + } else { + *dst = make_zeros(); + } } } } @@ -139,26 +204,34 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams for (int mtp_step = 0; mtp_step < TOKENS_MTP; mtp_step++) { for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) { auto* dst = reinterpret_cast(&sram.C[mtp_step][i]); - *dst = *reinterpret_cast( - &C[batch * params.C_stride_batch + mtp_step * params.C_stride_mtp + group * DSTATE + - i]); + if (mtp_step < seq_len) { + *dst = *reinterpret_cast( + &C[C_base + mtp_step * C_tstride + group * DSTATE + i]); + } else { + *dst = make_zeros(); + } } } } float rdt[TOKENS_MTP]; for (int step = 0; step < TOKENS_MTP; step++) { - auto dt_value = - dt_bias_value + - toFloat(dt[batch * params.dt_stride_batch + step * params.dt_stride_mtp + head]); - if (dt_softplus) { - dt_value = thresholded_softplus(dt_value); + if (step < seq_len) { + auto dt_value = dt_bias_value + toFloat(dt[dt_base + step * dt_tstride + head]); + if (dt_softplus) { + dt_value = thresholded_softplus(dt_value); + } + rdt[step] = dt_value; + } else { + rdt[step] = 0.f; } - rdt[step] = dt_value; } __syncthreads(); + bool const has_dst_indices = (dst_state_batch_indices != nullptr); + bool const has_intermediate = (intermediate_states != nullptr); + for (auto dBegin = 0; dBegin < ROWS_PER_BLOCK; dBegin += stageRows) { // Load state gmem -> smem for (int warpRow = 0; warpRow < stateRowsPerWarpPerStage; warpRow++) { @@ -222,6 +295,8 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams } for (int step = 0; step < TOKENS_MTP; step++) { + if (step >= seq_len) break; + float x_value = toFloat(sram.x[step][d - dim_offset]); float out_value = 
d_value * x_value * int(lane == 0); // first lane has the value @@ -254,7 +329,7 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams // Store intermediate state to smem (non-scaleState path) if constexpr (!scaleState) { if constexpr (sizeof(state_t) == sizeof(input_t)) { - if (intermediate_states) { + if (has_intermediate || has_dst_indices) { using packed_state_t = PackedAligned; packed_state_t rStateOut; // Philox-4x32 produces 4 random ints per call; amortize across packed elements. @@ -277,7 +352,7 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams *reinterpret_cast(&sram.state[dd][base_i]) = rStateOut; } } else { - if (intermediate_states) { + if (has_intermediate || has_dst_indices) { // Philox-4x32 produces 4 random ints per call; amortize across packed elements. [[maybe_unused]] uint32_t rand_ints[4]; #pragma unroll @@ -298,9 +373,9 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams } } - // For scaleState + intermediate_states: quantize rState → sram.state with block scaling + // For scaleState + per-step writes: quantize rState → sram.state with block scaling if constexpr (scaleState) { - if (intermediate_states && state_batch != params.pad_slot_id) { + if ((has_intermediate || has_dst_indices) && state_batch != params.pad_slot_id) { // 2-pass: compute max, then encode float istate_max = std::numeric_limits::lowest(); for (int ii = 0; ii < stateValuesPerThread; ii++) { @@ -333,35 +408,58 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams sram.out[step][d - dim_offset] = out_value; } - if (intermediate_states && state_batch != params.pad_slot_id) { - // Write intermediate state smem → gmem - for (int i = lane * load_state_t::count; i < DSTATE; - i += warpSize * load_state_t::count) { - auto* src = reinterpret_cast(&sram.state[dd][i]); - auto* dst = reinterpret_cast( - &intermediate_states[intermediate_cache_idx * - 
params.intermediate_state_stride_batch + - step * nheads * DIM * DSTATE + head * DIM * DSTATE + - d * DSTATE + i]); - *dst = *src; - } - // Write intermediate state decode scale → gmem - if constexpr (scaleState) { - if (lane == 0) { - auto* iscales = reinterpret_cast(params.intermediate_state_scales); - iscales[intermediate_cache_idx * params.intermediate_state_scales_stride_batch + - step * nheads * DIM + head * DIM + d] = sram.state_scale[dd]; + if (state_batch != params.pad_slot_id) { + if (has_dst_indices) { + auto dst_idx = static_cast( + dst_state_batch_indices[seq_idx * params.dst_state_batch_indices_stride_batch + + step * params.dst_state_batch_indices_stride_T]); + if (dst_idx != params.pad_slot_id) { + auto* dst_state_ptr = reinterpret_cast(params.state); + for (int i = lane * load_state_t::count; i < DSTATE; + i += warpSize * load_state_t::count) { + auto* src = reinterpret_cast(&sram.state[dd][i]); + *reinterpret_cast( + &dst_state_ptr[dst_idx * params.state_stride_batch + head * DIM * DSTATE + + d * DSTATE + i]) = *src; + } + if constexpr (scaleState) { + if (lane == 0) { + auto* dst_scale = reinterpret_cast(params.state_scale); + dst_scale[dst_idx * params.state_scale_stride_batch + head * DIM + d] = + sram.state_scale[dd]; + } + } + } + } else if (has_intermediate) { + // Write intermediate state smem → gmem + for (int i = lane * load_state_t::count; i < DSTATE; + i += warpSize * load_state_t::count) { + auto* src = reinterpret_cast(&sram.state[dd][i]); + auto* dst = reinterpret_cast( + &intermediate_states[intermediate_cache_idx * + params.intermediate_state_stride_batch + + step * nheads * DIM * DSTATE + head * DIM * DSTATE + + d * DSTATE + i]); + *dst = *src; + } + // Write intermediate state decode scale → gmem + if constexpr (scaleState) { + if (lane == 0) { + auto* iscales = reinterpret_cast(params.intermediate_state_scales); + iscales[intermediate_cache_idx * params.intermediate_state_scales_stride_batch + + step * nheads * DIM + head * DIM 
+ d] = sram.state_scale[dd]; + } } } } } // Update state if enabled and not padded - if (params.update_state && state_batch != params.pad_slot_id) { + if (params.update_state && state_batch != params.pad_slot_id && !has_dst_indices) { // When intermediate_states is enabled, sram.state[dd] already holds the // stochastically-rounded (or scaled) state from the last token step's intermediate write. // Skip the redundant Philox PRNG / re-quantization and write directly to gmem. - if (!intermediate_states) { + if (!has_intermediate) { if constexpr (scaleState) { // 2-pass quantization: compute max, then re-encode float new_state_max = std::numeric_limits::lowest(); @@ -417,7 +515,7 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams } // Store state_scale smem -> gmem (contiguous across warpRows) if constexpr (scaleState) { - if (params.update_state && state_batch != params.pad_slot_id) { + if (params.update_state && state_batch != params.pad_slot_id && !has_dst_indices) { for (int warpRow = lane; warpRow < stateRowsPerWarpPerStage; warpRow += warpSize) { auto dd = warp * stateRowsPerWarpPerStage + warpRow; auto d = dim_offset + dBegin + dd; @@ -432,6 +530,7 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams __syncthreads(); for (auto step = warp; step < TOKENS_MTP; step += numWarps) { + if (step >= seq_len) continue; for (auto d = lane; d < ROWS_PER_BLOCK; d += warpSize) { if (dim_offset + d < DIM) { auto out_value = sram.out[step][d]; @@ -442,8 +541,7 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams out_value *= silu_z; } auto* dst = reinterpret_cast( - &output[batch * params.out_stride_batch + step * params.out_stride_mtp + head * DIM + - dim_offset + d]); + &output[out_base + step * out_tstride + head * DIM + dim_offset + d]); convertAndStore(dst, out_value); } } diff --git a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh 
b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh index a9efc232ee..da66986e7d 100644 --- a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh +++ b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh @@ -113,12 +113,27 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams auto lane = threadIdx.x % warpSize; auto warp = threadIdx.y; - auto const state_batch = (state_batch_indices) ? state_batch_indices[batch] : batch; + auto const state_batch = + state_batch_indices + ? static_cast( + state_batch_indices[batch * params.state_batch_indices_stride_batch]) + : static_cast(batch); + auto const* __restrict__ dst_sbi = + reinterpret_cast(params.dst_state_batch_indices); + auto const dst_state_batch = + dst_sbi ? static_cast(dst_sbi[batch * params.dst_state_batch_indices_stride_batch]) + : state_batch; auto const state_ptr_offset = state_batch * params.state_stride_batch + head * DIM * DSTATE; state += state_ptr_offset; + auto* __restrict__ dst_state = reinterpret_cast(params.state) + + dst_state_batch * params.state_stride_batch + head * DIM * DSTATE; if constexpr (scaleState) { state_scale += state_batch * params.state_scale_stride_batch + head * DIM; } + [[maybe_unused]] auto* __restrict__ dst_state_scale = + scaleState ? 
reinterpret_cast(params.state_scale) + + dst_state_batch * params.state_scale_stride_batch + head * DIM + : nullptr; __shared__ SharedStorageSimple sram; @@ -227,7 +242,7 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams out_value += new_state * C_value; } if (!scaleState && params.update_state && state_batch != params.pad_slot_id) { - *reinterpret_cast(&state[d * DSTATE + i]) = rState; + *reinterpret_cast(&dst_state[d * DSTATE + i]) = rState; } } out_value = warpReduceSum(out_value); @@ -252,7 +267,7 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams convertAndStore(&rState.val[ii], rNewState[iter * load_state_t::count + ii] * new_state_encode_scale); } - *reinterpret_cast(&state[d * DSTATE + i]) = rState; + *reinterpret_cast(&dst_state[d * DSTATE + i]) = rState; } if (lane == 0) { @@ -284,7 +299,7 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams auto _d = warp * rowsPerWarp + l; auto d = dim_offset + _d; if (d < DIM) { - state_scale[d] = sram.state_scale[_d]; + dst_state_scale[d] = sram.state_scale[_d]; } } } @@ -315,7 +330,7 @@ template ; #ifdef FLASHINFER_MAMBA_ENABLE_SM90 namespace cde = cuda::device::experimental; @@ -402,7 +417,7 @@ __device__ __forceinline__ void producer_func_vertical( cde::fence_proxy_async_shared_cta(); // Writeback if constexpr (writeState) { - cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, 0, d_write, head, batch, + cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, 0, d_write, head, dst_batch, &sram.state[stage][0]); cde::cp_async_bulk_commit_group(); @@ -433,7 +448,7 @@ __device__ __forceinline__ void producer_func_vertical( if constexpr (writeState) { // Unblock async proxy for writeback cde::fence_proxy_async_shared_cta(); - cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, 0, d_write, head, batch, + cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, 0, d_write, head, dst_batch, 
&sram.state[stage][0]); cde::cp_async_bulk_commit_group(); @@ -652,9 +667,21 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical( auto lane = threadIdx.x % warpSize; auto warp = threadIdx.y; - auto const state_batch = (state_batch_indices) ? __ldg(&state_batch_indices[batch]) : batch; + auto const state_batch = + state_batch_indices + ? static_cast( + __ldg(&state_batch_indices[batch * params.state_batch_indices_stride_batch])) + : static_cast(batch); + auto const* __restrict__ dst_sbi = + reinterpret_cast(params.dst_state_batch_indices); + auto const dst_state_batch = + dst_sbi ? static_cast( + __ldg(&dst_sbi[batch * params.dst_state_batch_indices_stride_batch])) + : state_batch; auto const state_ptr_offset = static_cast(state_batch) * params.state_stride_batch + head * DIM * DSTATE; + auto const dst_state_ptr_offset = + static_cast(dst_state_batch) * params.state_stride_batch + head * DIM * DSTATE; extern __shared__ uint8_t sbuffer[]; using sram_t = SharedStorageVertical( sram, tensorState, x_global_ptr, B_global_ptr, C_global_ptr, - hasZ ? z_global_ptr : nullptr, state_scale_ptr, state_batch, head); + hasZ ? 
z_global_ptr : nullptr, state_scale_ptr, state_batch, dst_state_batch, head); }; auto const dispatch_state = [&]() { if (read_state && write_state) @@ -760,7 +787,7 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical( if constexpr (scaleState) { if (params.update_state && state_batch != params.pad_slot_id) { if (d < DIM) { - state_scale[state_batch * params.state_scale_stride_batch + head * DIM + d] = + state_scale[dst_state_batch * params.state_scale_stride_batch + head * DIM + d] = sram.state_scale[d]; } } @@ -789,7 +816,7 @@ template __device__ __forceinline__ void producer_func_horizontal(SramT& sram, CUtensorMap const& tensorState, int batch, - int head) { + int dst_batch, int head) { namespace cde = cuda::device::experimental; auto constexpr stagesReadOnly = numStages; @@ -830,7 +857,7 @@ __device__ __forceinline__ void producer_func_horizontal(SramT& sram, cde::fence_proxy_async_shared_cta(); // Writeback if constexpr (writeState) { - cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, i_write, 0, head, batch, + cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, i_write, 0, head, dst_batch, &sram.state[stage][0]); cde::cp_async_bulk_commit_group(); cde::cp_async_bulk_wait_group_read<0>(); @@ -860,7 +887,7 @@ __device__ __forceinline__ void producer_func_horizontal(SramT& sram, if constexpr (writeState) { // Unblock async proxy for writeback cde::fence_proxy_async_shared_cta(); - cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, i_write, 0, head, batch, + cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, i_write, 0, head, dst_batch, &sram.state[stage][0]); cde::cp_async_bulk_commit_group(); cde::cp_async_bulk_wait_group_read<0>(); @@ -1028,7 +1055,16 @@ __global__ void selective_state_update_kernel_producer_consumer_horizontal( auto lane = threadIdx.x % warpSize; auto warp = threadIdx.y; - auto const state_batch = (state_batch_indices) ? 
state_batch_indices[batch] : batch; + auto const state_batch = + state_batch_indices + ? static_cast( + state_batch_indices[batch * params.state_batch_indices_stride_batch]) + : static_cast(batch); + auto const* __restrict__ dst_sbi = + reinterpret_cast(params.dst_state_batch_indices); + auto const dst_state_batch = + dst_sbi ? static_cast(dst_sbi[batch * params.dst_state_batch_indices_stride_batch]) + : state_batch; auto const state_ptr_offset = static_cast(state_batch) * params.state_stride_batch + head * DIM * DSTATE; @@ -1061,13 +1097,13 @@ __global__ void selective_state_update_kernel_producer_consumer_horizontal( cg::invoke_one(cg::coalesced_threads(), [&]() { if (read_state && write_state) producer_func_horizontal( - sram, tensorState, state_batch, head); + sram, tensorState, state_batch, dst_state_batch, head); else if (read_state && !write_state) producer_func_horizontal( - sram, tensorState, state_batch, head); + sram, tensorState, state_batch, dst_state_batch, head); else producer_func_horizontal( - sram, tensorState, state_batch, head); + sram, tensorState, state_batch, dst_state_batch, head); }); } else { // consumers diff --git a/include/flashinfer/mamba/selective_state_update.cuh b/include/flashinfer/mamba/selective_state_update.cuh index 83eda2712c..55aef94b47 100644 --- a/include/flashinfer/mamba/selective_state_update.cuh +++ b/include/flashinfer/mamba/selective_state_update.cuh @@ -66,6 +66,14 @@ struct SelectiveStateUpdateParams { // Block-scale decode factors for quantized state: float32 (state_cache_size, nheads, dim, 1) void* __restrict__ state_scale{nullptr}; + void* __restrict__ dst_state_batch_indices{nullptr}; + + // stride_T=0 means 1D (broadcast), stride_T>0 means 2D indexing + int64_t state_batch_indices_stride_batch{1}; + int64_t state_batch_indices_stride_T{0}; + int64_t dst_state_batch_indices_stride_batch{0}; + int64_t dst_state_batch_indices_stride_T{0}; + bool dt_softplus{false}; bool update_state{true}; @@ -90,6 +98,9 @@ struct 
SelectiveStateMTPParams : public SelectiveStateUpdateParams { void* __restrict__ intermediate_state_indices{nullptr}; // (batch,) void* __restrict__ intermediate_state_scales{ nullptr}; // float: (batch, cache_steps, nheads, dim) + + void* __restrict__ cu_seqlens{nullptr}; // (n_sequences + 1,) + void* __restrict__ num_accepted_tokens{nullptr}; // (n_sequences,) }; } // namespace mtp diff --git a/tests/mamba/test_selective_state_update_varlen.py b/tests/mamba/test_selective_state_update_varlen.py new file mode 100644 index 0000000000..620ec48e2b --- /dev/null +++ b/tests/mamba/test_selective_state_update_varlen.py @@ -0,0 +1,539 @@ +""" +Varlen / speculative-decoding tests for selective_state_update. + +Tests the new features for speculative decoding integration: +- dst_state_batch_indices: separate read/write state cache slots +- cu_seqlens: variable-length sequences (tokens flattened into batch dim) +- num_accepted_tokens: initial state selection for speculative decoding +- 2D state_batch_indices: (N, max_seqlen) index tensors +""" + +import numpy as np +import pytest +import torch + +import flashinfer + +from .triton_reference.selective_state_update_varlen import ( + selective_state_update_varlen_triton, +) + + +PAD_SLOT_ID = -1 + + +def _make_base_tensors( + total_tokens, + nheads, + dim, + dstate, + ngroups, + state_cache_size, + input_dtype=torch.bfloat16, + weight_dtype=torch.float32, + matrixA_dtype=torch.float32, + state_dtype=torch.bfloat16, + device="cuda", +): + """Create base input tensors for total_tokens in varlen (3D) layout.""" + x = torch.randn(total_tokens, nheads, dim, device=device, dtype=input_dtype) + + dt_base = torch.randn(total_tokens, nheads, device=device, dtype=weight_dtype) + dt = dt_base.as_strided((total_tokens, nheads, dim), (nheads, 1, 0)) + + A_base = -torch.rand(nheads, device=device, dtype=matrixA_dtype) - 1.0 + A = A_base.as_strided((nheads, dim, dstate), (1, 0, 0)) + + B = torch.randn(total_tokens, ngroups, dstate, 
device=device, dtype=input_dtype) + C = torch.randn(total_tokens, ngroups, dstate, device=device, dtype=input_dtype) + + D_base = torch.randn(nheads, device=device, dtype=weight_dtype) + D = D_base.as_strided((nheads, dim), (1, 0)) + + dt_bias_base = torch.rand(nheads, device=device, dtype=weight_dtype) - 4.0 + dt_bias = dt_bias_base.as_strided((nheads, dim), (1, 0)) + + state = torch.randn( + state_cache_size, nheads, dim, dstate, device=device, dtype=state_dtype + ) + + return dict(state=state, x=x, dt=dt, A=A, B=B, C=C, D=D, dt_bias=dt_bias) + + +def _assert_match(ref, test, name, atol=1e-3, rtol=1e-2): + """Assert tensors match with detailed error reporting.""" + match = torch.allclose(ref, test, atol=atol, rtol=rtol) + if match: + print(f" {name}: PASSED") + else: + ref_np = ref.detach().cpu().float().numpy() + test_np = test.detach().cpu().float().numpy() + mismatch = ~np.isclose(ref_np, test_np, atol=atol, rtol=rtol) + num_mismatch = np.sum(mismatch) + print( + f" {name}: FAILED ({num_mismatch}/{ref_np.size} elements differ, " + f"max diff = {(ref - test).abs().max().item():.6e})" + ) + assert match, f"{name} mismatch: max diff = {(ref - test).abs().max().item()}" + + +class TestSelectiveStateUpdateDstIndices: + """Test dst_state_batch_indices in single-token (STP) path.""" + + ATOL = 1e-3 + RTOL = 1e-2 + NHEADS = 64 + DIM = 64 + DSTATE = 128 + NGROUPS = 8 + STATE_CACHE_SIZE = 256 + + @pytest.mark.parametrize("batch", [1, 4, 32, 64]) + def test_dst_different_from_src(self, batch): + """State is read from src slots and written to disjoint dst slots.""" + torch.manual_seed(42) + tensors = _make_base_tensors( + batch, + self.NHEADS, + self.DIM, + self.DSTATE, + self.NGROUPS, + self.STATE_CACHE_SIZE, + ) + out = torch.empty( + batch, self.NHEADS, self.DIM, device="cuda", dtype=torch.bfloat16 + ) + + perm = torch.randperm(self.STATE_CACHE_SIZE, device="cuda") + src_indices = perm[:batch].to(torch.int32) + dst_indices = perm[batch : 2 * batch].to(torch.int32) + + 
src_2d = src_indices.unsqueeze(1) + dst_2d = dst_indices.unsqueeze(1) + + state_ref = tensors["state"].clone() + out_ref = out.clone() + selective_state_update_varlen_triton( + state_ref, + tensors["x"], + tensors["dt"], + tensors["A"], + tensors["B"], + tensors["C"], + D=tensors["D"], + dt_bias=tensors["dt_bias"], + dt_softplus=True, + state_batch_indices=src_2d, + dst_state_batch_indices=dst_2d, + pad_slot_id=PAD_SLOT_ID, + out=out_ref, + ) + + state_test = tensors["state"].clone() + out_test = out.clone() + flashinfer.mamba.selective_state_update( + state_test, + tensors["x"], + tensors["dt"], + tensors["A"], + tensors["B"], + tensors["C"], + D=tensors["D"], + dt_bias=tensors["dt_bias"], + dt_softplus=True, + state_batch_indices=src_2d, + dst_state_batch_indices=dst_2d, + pad_slot_id=PAD_SLOT_ID, + out=out_test, + ) + + _assert_match(out_ref, out_test, "output", self.ATOL, self.RTOL) + _assert_match( + state_ref[dst_indices.long()], + state_test[dst_indices.long()], + "dst_state", + self.ATOL, + self.RTOL, + ) + + src_orig = tensors["state"][src_indices.long()] + src_after = state_test[src_indices.long()] + assert torch.equal(src_orig, src_after), "Source state slots were modified" + + +class TestSelectiveStateUpdateDstIndices2D: + """Test 2D state_batch_indices with shape (batch, 1).""" + + ATOL = 1e-3 + RTOL = 1e-2 + NHEADS = 64 + DIM = 64 + DSTATE = 128 + NGROUPS = 8 + STATE_CACHE_SIZE = 256 + + @pytest.mark.parametrize("batch", [1, 16, 64]) + def test_2d_indices_seqlen1(self, batch): + """2D indices with max_seqlen=1 should behave identically to STP.""" + torch.manual_seed(42) + tensors = _make_base_tensors( + batch, + self.NHEADS, + self.DIM, + self.DSTATE, + self.NGROUPS, + self.STATE_CACHE_SIZE, + ) + out = torch.empty( + batch, self.NHEADS, self.DIM, device="cuda", dtype=torch.bfloat16 + ) + + perm = torch.randperm(self.STATE_CACHE_SIZE, device="cuda") + src_indices = perm[:batch].to(torch.int32).unsqueeze(1) + dst_indices = perm[batch : 2 * 
batch].to(torch.int32).unsqueeze(1) + + state_ref = tensors["state"].clone() + out_ref = out.clone() + selective_state_update_varlen_triton( + state_ref, + tensors["x"], + tensors["dt"], + tensors["A"], + tensors["B"], + tensors["C"], + D=tensors["D"], + dt_bias=tensors["dt_bias"], + dt_softplus=True, + state_batch_indices=src_indices, + dst_state_batch_indices=dst_indices, + pad_slot_id=PAD_SLOT_ID, + out=out_ref, + ) + + state_test = tensors["state"].clone() + out_test = out.clone() + flashinfer.mamba.selective_state_update( + state_test, + tensors["x"], + tensors["dt"], + tensors["A"], + tensors["B"], + tensors["C"], + D=tensors["D"], + dt_bias=tensors["dt_bias"], + dt_softplus=True, + state_batch_indices=src_indices, + dst_state_batch_indices=dst_indices, + pad_slot_id=PAD_SLOT_ID, + out=out_test, + ) + + _assert_match(out_ref, out_test, "output", self.ATOL, self.RTOL) + + dst_long = dst_indices.squeeze(1).long() + _assert_match( + state_ref[dst_long], + state_test[dst_long], + "dst_state", + self.ATOL, + self.RTOL, + ) + + +class TestSelectiveStateUpdateVarlen: + """Test varlen (cu_seqlens) multi-token support.""" + + ATOL = 1e-3 + RTOL = 1e-2 + NHEADS = 64 + DIM = 64 + DSTATE = 128 + NGROUPS = 8 + STATE_CACHE_SIZE = 512 + + @pytest.mark.parametrize( + "n_seqs,max_seqlen", + [ + (1, 1), + (1, 4), + (4, 1), + (4, 2), + (4, 4), + (16, 2), + (16, 4), + ], + ) + def test_varlen_uniform(self, n_seqs, max_seqlen): + """All sequences have the same length.""" + torch.manual_seed(42) + total_tokens = n_seqs * max_seqlen + tensors = _make_base_tensors( + total_tokens, + self.NHEADS, + self.DIM, + self.DSTATE, + self.NGROUPS, + self.STATE_CACHE_SIZE, + ) + + cu_seqlens = torch.arange( + 0, total_tokens + 1, max_seqlen, device="cuda", dtype=torch.int32 + ) + + perm = torch.randperm(self.STATE_CACHE_SIZE, device="cuda") + src_indices = ( + perm[: n_seqs * max_seqlen].reshape(n_seqs, max_seqlen).to(torch.int32) + ) + dst_indices = ( + perm[n_seqs * max_seqlen : 2 * n_seqs * 
max_seqlen] + .reshape(n_seqs, max_seqlen) + .to(torch.int32) + ) + + num_accepted = torch.ones(n_seqs, device="cuda", dtype=torch.int64) + out = torch.empty( + total_tokens, self.NHEADS, self.DIM, device="cuda", dtype=torch.bfloat16 + ) + + state_ref = tensors["state"].clone() + out_ref = out.clone() + selective_state_update_varlen_triton( + state_ref, + tensors["x"], + tensors["dt"], + tensors["A"], + tensors["B"], + tensors["C"], + D=tensors["D"], + dt_bias=tensors["dt_bias"], + dt_softplus=True, + state_batch_indices=src_indices, + dst_state_batch_indices=dst_indices, + pad_slot_id=PAD_SLOT_ID, + out=out_ref, + num_accepted_tokens=num_accepted, + cu_seqlens=cu_seqlens, + ) + + state_test = tensors["state"].clone() + out_test = out.clone() + flashinfer.mamba.selective_state_update( + state_test, + tensors["x"], + tensors["dt"], + tensors["A"], + tensors["B"], + tensors["C"], + D=tensors["D"], + dt_bias=tensors["dt_bias"], + dt_softplus=True, + state_batch_indices=src_indices, + dst_state_batch_indices=dst_indices, + pad_slot_id=PAD_SLOT_ID, + out=out_test, + num_accepted_tokens=num_accepted, + cu_seqlens=cu_seqlens, + cache_steps=max_seqlen, + ) + + _assert_match(out_ref, out_test, "output", self.ATOL, self.RTOL) + + for s in range(n_seqs): + for t in range(max_seqlen): + dst_slot = dst_indices[s, t].long().item() + _assert_match( + state_ref[dst_slot], + state_test[dst_slot], + f"state[seq={s},token={t},slot={dst_slot}]", + self.ATOL, + self.RTOL, + ) + + @pytest.mark.parametrize("n_seqs", [4, 8]) + def test_varlen_variable_lengths(self, n_seqs): + """Sequences have different lengths (padded with PAD_SLOT_ID).""" + max_seqlen = 6 + torch.manual_seed(42) + + seq_lens = torch.randint(1, max_seqlen + 1, (n_seqs,), device="cuda") + total_tokens = seq_lens.sum().item() + cu_seqlens = torch.zeros(n_seqs + 1, device="cuda", dtype=torch.int32) + cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0).to(torch.int32) + + tensors = _make_base_tensors( + total_tokens, + 
self.NHEADS, + self.DIM, + self.DSTATE, + self.NGROUPS, + self.STATE_CACHE_SIZE, + ) + + src_indices = torch.full( + (n_seqs, max_seqlen), PAD_SLOT_ID, device="cuda", dtype=torch.int32 + ) + dst_indices = torch.full( + (n_seqs, max_seqlen), PAD_SLOT_ID, device="cuda", dtype=torch.int32 + ) + perm = torch.randperm(self.STATE_CACHE_SIZE, device="cuda").to(torch.int32) + slot_offset = 0 + for s in range(n_seqs): + sl = seq_lens[s].item() + src_indices[s, :sl] = perm[slot_offset : slot_offset + sl] + dst_indices[s, :sl] = perm[ + slot_offset + self.STATE_CACHE_SIZE // 2 : slot_offset + + self.STATE_CACHE_SIZE // 2 + + sl + ] + slot_offset += sl + + num_accepted = torch.ones(n_seqs, device="cuda", dtype=torch.int64) + out = torch.empty( + total_tokens, self.NHEADS, self.DIM, device="cuda", dtype=torch.bfloat16 + ) + + state_ref = tensors["state"].clone() + out_ref = out.clone() + selective_state_update_varlen_triton( + state_ref, + tensors["x"], + tensors["dt"], + tensors["A"], + tensors["B"], + tensors["C"], + D=tensors["D"], + dt_bias=tensors["dt_bias"], + dt_softplus=True, + state_batch_indices=src_indices, + dst_state_batch_indices=dst_indices, + pad_slot_id=PAD_SLOT_ID, + out=out_ref, + num_accepted_tokens=num_accepted, + cu_seqlens=cu_seqlens, + ) + + state_test = tensors["state"].clone() + out_test = out.clone() + flashinfer.mamba.selective_state_update( + state_test, + tensors["x"], + tensors["dt"], + tensors["A"], + tensors["B"], + tensors["C"], + D=tensors["D"], + dt_bias=tensors["dt_bias"], + dt_softplus=True, + state_batch_indices=src_indices, + dst_state_batch_indices=dst_indices, + pad_slot_id=PAD_SLOT_ID, + out=out_test, + num_accepted_tokens=num_accepted, + cu_seqlens=cu_seqlens, + cache_steps=max_seqlen, + ) + + _assert_match(out_ref, out_test, "output", self.ATOL, self.RTOL) + + +class TestSelectiveStateUpdateNumAcceptedTokens: + """Test num_accepted_tokens for initial state selection.""" + + ATOL = 1e-3 + RTOL = 1e-2 + NHEADS = 64 + DIM = 64 + DSTATE 
= 128 + NGROUPS = 8 + STATE_CACHE_SIZE = 512 + + @pytest.mark.parametrize("n_seqs", [4, 8, 16]) + @pytest.mark.parametrize("num_accepted_dtype", [torch.int32, torch.int64]) + def test_num_accepted_selects_initial_state(self, n_seqs, num_accepted_dtype): + """num_accepted_tokens controls which state slot to read as initial.""" + max_seqlen = 4 + total_tokens = n_seqs * max_seqlen + torch.manual_seed(42) + + tensors = _make_base_tensors( + total_tokens, + self.NHEADS, + self.DIM, + self.DSTATE, + self.NGROUPS, + self.STATE_CACHE_SIZE, + ) + + cu_seqlens = torch.arange( + 0, total_tokens + 1, max_seqlen, device="cuda", dtype=torch.int32 + ) + + num_accepted = torch.randint( + 1, max_seqlen + 1, (n_seqs,), device="cuda", dtype=num_accepted_dtype + ) + + perm = torch.randperm(self.STATE_CACHE_SIZE, device="cuda").to(torch.int32) + src_indices = perm[: n_seqs * max_seqlen].reshape(n_seqs, max_seqlen) + dst_indices = perm[n_seqs * max_seqlen : 2 * n_seqs * max_seqlen].reshape( + n_seqs, max_seqlen + ) + + out = torch.empty( + total_tokens, self.NHEADS, self.DIM, device="cuda", dtype=torch.bfloat16 + ) + + state_ref = tensors["state"].clone() + out_ref = out.clone() + selective_state_update_varlen_triton( + state_ref, + tensors["x"], + tensors["dt"], + tensors["A"], + tensors["B"], + tensors["C"], + D=tensors["D"], + dt_bias=tensors["dt_bias"], + dt_softplus=True, + state_batch_indices=src_indices, + dst_state_batch_indices=dst_indices, + pad_slot_id=PAD_SLOT_ID, + out=out_ref, + num_accepted_tokens=num_accepted, + cu_seqlens=cu_seqlens, + ) + + state_test = tensors["state"].clone() + out_test = out.clone() + flashinfer.mamba.selective_state_update( + state_test, + tensors["x"], + tensors["dt"], + tensors["A"], + tensors["B"], + tensors["C"], + D=tensors["D"], + dt_bias=tensors["dt_bias"], + dt_softplus=True, + state_batch_indices=src_indices, + dst_state_batch_indices=dst_indices, + pad_slot_id=PAD_SLOT_ID, + out=out_test, + num_accepted_tokens=num_accepted, + 
cu_seqlens=cu_seqlens, + cache_steps=max_seqlen, + ) + + _assert_match(out_ref, out_test, "output", self.ATOL, self.RTOL) + + for s in range(n_seqs): + for t in range(max_seqlen): + dst_slot = dst_indices[s, t].long().item() + if dst_slot == PAD_SLOT_ID: + continue + _assert_match( + state_ref[dst_slot], + state_test[dst_slot], + f"state[seq={s},token={t},num_accepted={num_accepted[s].item()}]", + self.ATOL, + self.RTOL, + ) diff --git a/tests/mamba/triton_reference/selective_state_update_varlen.py b/tests/mamba/triton_reference/selective_state_update_varlen.py new file mode 100644 index 0000000000..11a59f341e --- /dev/null +++ b/tests/mamba/triton_reference/selective_state_update_varlen.py @@ -0,0 +1,467 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py + +import torch +from packaging import version + +import triton +import triton.language as tl + +PAD_SLOT_ID = -1 + +TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") + +if TRITON3: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) + return dt +else: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + return dt + + +@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) +@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) +@triton.heuristics( + { + "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"] + is not None + } +) +@triton.heuristics( + {"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens_ptr"] is not None} +) +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens_ptr"] is not None}) +@triton.heuristics( + {"BLOCK_SIZE_DSTATE": 
lambda args: triton.next_power_of_2(args["dstate"])} +) +@triton.jit(do_not_specialize=["N"]) +def _selective_scan_update_kernel( + # Pointers to matrices + state_ptr, + x_ptr, + dt_ptr, + dt_bias_ptr, + A_ptr, + B_ptr, + C_ptr, + D_ptr, + z_ptr, + out_ptr, + state_batch_indices_ptr, + dst_state_batch_indices_ptr, + pad_slot_id, + num_accepted_tokens_ptr, + cu_seqlens_ptr, + # Matrix dimensions + N, + nheads, + dim, + dstate, + nheads_ngroups_ratio, + # Strides + stride_state_batch, + stride_state_head, + stride_state_dim, + stride_state_dstate, + stride_x_batch, + stride_x_head, + stride_x_dim, + stride_dt_batch, + stride_dt_head, + stride_dt_dim, + stride_dt_bias_head, + stride_dt_bias_dim, + stride_A_head, + stride_A_dim, + stride_A_dstate, + stride_B_batch, + stride_B_group, + stride_B_dstate, + stride_C_batch, + stride_C_group, + stride_C_dstate, + stride_D_head, + stride_D_dim, + stride_z_batch, + stride_z_head, + stride_z_dim, + stride_out_batch, + stride_out_head, + stride_out_dim, + stride_state_indices_batch, + stride_state_indices_T, + stride_dst_state_indices_batch, + stride_dst_state_indices_T, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + TIE_HDIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_STATE_BATCH_INDICES: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + IS_VARLEN: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + + if IS_VARLEN: + bos = tl.load(cu_seqlens_ptr + pid_b).to(tl.int64) + eos = tl.load(cu_seqlens_ptr + pid_b + 1).to(tl.int64) + seq_len = eos - bos + + if seq_len == 0: + return + else: + bos = pid_b + seq_len = 1 + + state_ptr_base = state_ptr + + if HAS_STATE_BATCH_INDICES: + if IS_SPEC_DECODING: + num_accepted = tl.load(num_accepted_tokens_ptr + pid_b).to(tl.int64) + init_token_idx = tl.maximum(num_accepted - 1, 0) + else: + init_token_idx = 0 + + 
dst_state_batch_indices_ptr += pid_b * stride_dst_state_indices_batch + if not IS_SPEC_DECODING: + dst_state_batch_idx = tl.load( + dst_state_batch_indices_ptr + + init_token_idx * stride_dst_state_indices_T + ).to(tl.int64) + dst_state_ptr = state_ptr + ( + dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head + ) + + state_batch_indices_ptr += ( + pid_b * stride_state_indices_batch + init_token_idx * stride_state_indices_T + ) + state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) + state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head + else: + dst_state_ptr = ( + state_ptr + pid_b * stride_state_batch + pid_h * stride_state_head + ) + state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + + x_ptr += bos * stride_x_batch + pid_h * stride_x_head + dt_ptr += bos * stride_dt_batch + pid_h * stride_dt_head + if HAS_DT_BIAS: + dt_bias_ptr += pid_h * stride_dt_bias_head + A_ptr += pid_h * stride_A_head + B_ptr += bos * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group + C_ptr += bos * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group + if HAS_Z: + z_ptr += bos * stride_z_batch + pid_h * stride_z_head + out_ptr += bos * stride_out_batch + pid_h * stride_out_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + state_ptrs = state_ptr + ( + offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate + ) + if not IS_SPEC_DECODING: + dst_state_ptrs = dst_state_ptr + ( + offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate + ) + + mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) + if HAS_STATE_BATCH_INDICES: + mask &= state_batch_idx != pad_slot_id + state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32) + + if HAS_DT_BIAS: + dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim + if HAS_D: + D_ptr += pid_h * stride_D_head + D_ptrs = D_ptr + offs_m * 
stride_D_dim + A_ptrs = A_ptr + offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate + + for i_t in range(seq_len): + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if not TIE_HDIM: + dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load( + A_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0, + ).to(tl.float32) + dA = tl.exp(A * dt[:, None]) + else: + dt = tl.load(dt_ptr).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptr).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptr).to(tl.float32) + dA = tl.exp(A * dt) + + B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_Z: + z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt + state = state * dA + dB * x[:, None] + + if IS_SPEC_DECODING: + dst_idx_ptr = dst_state_batch_indices_ptr + i_t * stride_dst_state_indices_T + token_dst_idx = tl.load(dst_idx_ptr).to(tl.int64) + if token_dst_idx != pad_slot_id: + token_dst_ptrs = ( + state_ptr_base + + token_dst_idx * stride_state_batch + + pid_h * stride_state_head + + offs_m[:, None] * stride_state_dim + + offs_n[None, :] * stride_state_dstate + ) + tl.store( + token_dst_ptrs, + state.to(token_dst_ptrs.dtype.element_ty), + mask=mask, + ) + + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if 
HAS_Z: + out *= z * tl.sigmoid(z) + tl.store(out_ptrs, out, mask=offs_m < dim) + + x_ptr += stride_x_batch + dt_ptr += stride_dt_batch + B_ptr += stride_B_batch + C_ptr += stride_C_batch + out_ptr += stride_out_batch + if HAS_Z: + z_ptr += stride_z_batch + + if not IS_SPEC_DECODING: + tl.store(dst_state_ptrs, state.to(dst_state_ptrs.dtype.element_ty), mask=mask) + + +def selective_state_update_varlen_triton( + state, + x, + dt, + A, + B, + C, + D=None, + dt_bias=None, + z=None, + dt_softplus=False, + state_batch_indices=None, + dst_state_batch_indices=None, + pad_slot_id=PAD_SLOT_ID, + out=None, + num_accepted_tokens=None, + cu_seqlens=None, +): + """ + Selective state update with varlen / speculative decoding support. + + Arguments: + state: (state_cache_size, nheads, dim, dstate) + x: (total_tokens, nheads, dim) + dt: (total_tokens, nheads, dim) + A: (nheads, dim, dstate) + B: (total_tokens, ngroups, dstate) + C: (total_tokens, ngroups, dstate) + D: (nheads, dim) + dt_bias: (nheads, dim) + z: (total_tokens, nheads, dim) + state_batch_indices: (N, max_seqlen) — source state cache indices + dst_state_batch_indices: (N, max_seqlen) — destination state cache indices + num_accepted_tokens: (N,) — determines initial state index per sequence + cu_seqlens: (N + 1,) — cumulative sequence lengths + """ + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + if out is not None and out.dim() == 2: + out = out.unsqueeze(1) + if num_accepted_tokens is not None: + assert state_batch_indices is not None and state_batch_indices.dim() == 2 + assert dst_state_batch_indices is None or 
dst_state_batch_indices.dim() == 2 + if state_batch_indices is not None and state_batch_indices.dim() == 1: + state_batch_indices = state_batch_indices.unsqueeze(1) + if dst_state_batch_indices is not None and dst_state_batch_indices.dim() == 1: + dst_state_batch_indices = dst_state_batch_indices.unsqueeze(1) + + _, nheads, dim, dstate = state.shape + batch = x.shape[0] + if cu_seqlens is not None: + N = len(cu_seqlens) - 1 + max_seqlen = ( + state_batch_indices.size(-1) if state_batch_indices is not None else 1 + ) + else: + N = batch + max_seqlen = 1 + + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + if state_batch_indices is not None: + assert state_batch_indices.shape[0] >= N + assert state_batch_indices.shape[1] >= max_seqlen + if dst_state_batch_indices is not None: + assert dst_state_batch_indices.shape[0] >= N + assert dst_state_batch_indices.shape[1] >= max_seqlen + else: + dst_state_batch_indices = state_batch_indices + if out is None: + out = torch.empty_like(x) + assert out.shape == x.shape + if num_accepted_tokens is not None: + assert num_accepted_tokens.shape == (N,) + + grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), N, nheads) + z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0) + state_batch_indices_strides = ( + (state_batch_indices.stride(0), state_batch_indices.stride(1)) + if state_batch_indices is not None + else (0, 0) + ) + dst_state_batch_indices_strides = ( + (dst_state_batch_indices.stride(0), dst_state_batch_indices.stride(1)) + if dst_state_batch_indices is not None + else (0, 0) + ) + + 
BLOCK_SIZE_M, num_warps = 4, 8 + if dstate <= 16: + BLOCK_SIZE_M, num_warps = 32, 4 + elif dstate <= 32: + BLOCK_SIZE_M, num_warps = 16, 4 + elif dstate <= 64: + BLOCK_SIZE_M, num_warps = 8, 4 + elif dstate <= 128: + BLOCK_SIZE_M, num_warps = 4, 4 + + dt_bias_strides = ( + (dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0) + ) + + tie_hdim = ( + A.stride(-1) == 0 + and A.stride(-2) == 0 + and dt.stride(-1) == 0 + and (dt_bias is None or dt_bias.stride(-1) == 0) + ) + with torch.cuda.device(x.device.index): + _selective_scan_update_kernel[grid]( + state, + x, + dt, + dt_bias, + A, + B, + C, + D, + z, + out, + state_batch_indices, + dst_state_batch_indices, + pad_slot_id, + num_accepted_tokens, + cu_seqlens, + N, + nheads, + dim, + dstate, + nheads // ngroups, + state.stride(0), + state.stride(1), + state.stride(2), + state.stride(3), + x.stride(0), + x.stride(1), + x.stride(2), + dt.stride(0), + dt.stride(1), + dt.stride(2), + dt_bias_strides[0], + dt_bias_strides[1], + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + C.stride(0), + C.stride(1), + C.stride(2), + D.stride(0) if D is not None else 0, + D.stride(1) if D is not None else 0, + z_strides[0], + z_strides[1], + z_strides[2], + out.stride(0), + out.stride(1), + out.stride(2), + state_batch_indices_strides[0], + state_batch_indices_strides[1], + dst_state_batch_indices_strides[0], + dst_state_batch_indices_strides[1], + dt_softplus, + tie_hdim, + BLOCK_SIZE_M, + num_warps=num_warps, + ) + return out