diff --git a/benchmarks/bench_ssu_sweep_mtp.py b/benchmarks/bench_ssu_sweep_mtp.py index 2510985cf4..65a7c373de 100644 --- a/benchmarks/bench_ssu_sweep_mtp.py +++ b/benchmarks/bench_ssu_sweep_mtp.py @@ -57,7 +57,14 @@ def create_benchmark_inputs( def benchmark_kernel( - name, kernel_fn, inputs, cache_steps=0, ncu=False, rand_seed=None, philox_rounds=10 + name, + kernel_fn, + inputs, + cache_steps=0, + ncu=False, + rand_seed=None, + philox_rounds=10, + repeat_time_ms=1000, ): """Benchmark a single kernel and return median time in ms.""" print(f"\n Benchmarking {name}...") @@ -87,6 +94,7 @@ def benchmark_kernel( cache_steps=cache_steps, rand_seed=rand_seed, philox_rounds=philox_rounds, + disable_state_update=True, ) if ncu: @@ -99,7 +107,7 @@ def benchmark_kernel( measurements = bench_gpu_time( lambda: kernel_fn(**kwargs), dry_run_time_ms=100, - repeat_time_ms=1000, + repeat_time_ms=repeat_time_ms, ) except RuntimeError as e: print(f" Kernel failed: {e}") @@ -132,6 +140,7 @@ def wrapper( cache_steps, rand_seed, philox_rounds, + disable_state_update, ): selective_state_update_triton( state=state, @@ -151,6 +160,8 @@ def wrapper( cache_steps=cache_steps, intermediate_state_indices=intermediate_state_indices, rand_seed=rand_seed, + philox_rounds=philox_rounds, + disable_state_update=disable_state_update, ) return wrapper @@ -178,6 +189,7 @@ def wrapper( cache_steps, rand_seed, philox_rounds, + disable_state_update, ): flashinfer_selective_state_update( state=state, @@ -199,6 +211,7 @@ def wrapper( algorithm=algorithm, rand_seed=rand_seed, philox_rounds=philox_rounds, + disable_state_update=disable_state_update, ) return wrapper @@ -215,6 +228,7 @@ def run_measurement( generate_intermediate_states_buffer=False, ncu=False, philox_rounds=None, + repeat_time_ms=1000, ): """Run benchmarks on all kernels and return results dict.""" inputs = create_benchmark_inputs( @@ -242,9 +256,6 @@ def run_measurement( "flashinfer_simple": make_flashinfer_wrapper(algorithm="simple"), "flashinfer_vertical": make_flashinfer_wrapper(algorithm="vertical"), "flashinfer_horizontal": make_flashinfer_wrapper(algorithm="horizontal"), - "flashinfer_async_horizontal": make_flashinfer_wrapper( - algorithm="async_horizontal" - ), "flashinfer_auto": make_flashinfer_wrapper(algorithm="auto"), } @@ -258,6 +269,7 @@ def run_measurement( ncu=ncu, rand_seed=rand_seed, philox_rounds=effective_philox_rounds, + repeat_time_ms=repeat_time_ms, ) results[name] = median_time @@ -312,6 +324,13 @@ def run_measurement( default=6, help="Number of MTP (cache) steps (default: 6)", ) +parser.add_argument( + "-r", + "--repeat", + type=int, + default=1000, + help="Repeat time in milliseconds for benchmarking (default: 1000)", +) args = parser.parse_args() # Powers of two from 1 to 2048 @@ -373,6 +392,7 @@ def parse_dtype_spec(spec): generate_intermediate_states_buffer=True, ncu=args.ncu, philox_rounds=philox_rounds, + repeat_time_ms=args.repeat, ) if not results: diff --git a/benchmarks/bench_ssu_sweep_sol.py b/benchmarks/bench_ssu_sweep_sol.py new file mode 100644 index 0000000000..0976908be2 --- /dev/null +++ b/benchmarks/bench_ssu_sweep_sol.py @@ -0,0 +1,682 @@ +#!/usr/bin/env python3 +""" +Benchmark selective_state_update (MTP mode) — % of Speed-of-Light (SOL). + +Measures FlashInfer kernel achieved memory bandwidth as a percentage of the +GPU's peak HBM bandwidth. This is the right metric for memory-bound kernels. + +Methodology follows benchmarks/routines/mamba.py: problem_bytes (read + write) +divided by kernel time gives achieved TB/s, then SOL% = achieved / peak * 100. +""" + +import argparse +import re +import sys +from pathlib import Path + +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import torch +from flashinfer.testing import bench_gpu_time + +# Add tests directory to path for create_test_inputs +sys.path.insert(0, str(Path(__file__).parent.parent / "tests" / "mamba")) + +from utils import create_test_inputs, clone_preserving_strides +from flashinfer.mamba import selective_state_update as flashinfer_selective_state_update + + +# Peak HBM bandwidth in TB/s for known GPUs (bidirectional) +# Source: NVIDIA product specs +_PEAK_BW_TB_S = { + "H100 SXM": 3.35, + "H100 PCIe": 2.0, + "H100 NVL": 3.35, + "H200": 4.8, + "A100 SXM": 2.0, + "A100 PCIe": 1.555, + "A100-SXM4-80GB": 2.0, + "A100-SXM4-40GB": 1.555, + "B200": 8.0, + "B100": 8.0, + "L40S": 0.864, + "L40": 0.864, + "A10": 0.6, +} + +# Peak SIMT (non-tensor-core) FP32 throughput in TFLOPS +# Source: NVIDIA product specs — these are the CUDA core (SIMT) numbers, +# NOT tensor core numbers. +_PEAK_SIMT_FP32_TFLOPS = { + "H100 SXM": 67.0, + "H100 PCIe": 51.2, + "H100 NVL": 67.0, + "H200": 67.0, # same die as H100 + "A100 SXM": 19.5, + "A100 PCIe": 19.5, + "A100-SXM4-80GB": 19.5, + "A100-SXM4-40GB": 19.5, + "B200": 90.0, + "B100": 60.0, + "L40S": 36.6, + "L40": 36.6, + "A10": 31.2, +} + + +def _lookup_gpu(table, gpu_name, override, unit_name): + """Look up a GPU spec from a table, with optional override.""" + if override is not None: + return override + for key, val in table.items(): + if key.lower() in gpu_name.lower(): + return val + raise ValueError( + f"Unknown GPU '{gpu_name}'. Please specify the override flag. " + f"Known GPUs: {list(table.keys())}" + ) + + +def get_peak_bandwidth_tb_s(gpu_name, override=None): + """Return peak HBM bandwidth in TB/s.""" + return _lookup_gpu(_PEAK_BW_TB_S, gpu_name, override, "TB/s") + + +def get_peak_simt_fp32_tflops(gpu_name, override=None): + """Return peak SIMT FP32 throughput in TFLOPS.""" + return _lookup_gpu(_PEAK_SIMT_FP32_TFLOPS, gpu_name, override, "TFLOPS") + + +def tensor_size_bytes(t): + """Return the number of physical bytes backing tensor *t*. + + Dimensions with stride 0 are broadcasts — the actual storage is only 1 + element along that axis, so we count those dimensions as size 1. + """ + n_elems = 1 + for size, stride in zip(t.shape, t.stride(), strict=True): + n_elems *= size if stride != 0 else 1 + return n_elems * t.element_size() + + +def compute_problem_bytes(inputs): + """Compute total bytes read + written from actual tensors. + + The kernel reads only the state_cache rows selected by slot_idx. + When intermediate_states_buffer is present, the kernel writes intermediate + states (indexed by intermediate_slot_idx) instead of writing back to + state_cache. Otherwise it writes state_cache back in-place. + + Read: state_cache[slot_idx], x, dt, A, B, C, D, dt_bias, + slot_idx, intermediate_slot_idx (if present), z (if present) + Write: output (same shape as x), + + intermediate_states_buffer[intermediate_slot_idx] if present, + + state_cache[slot_idx] otherwise (written back in-place) + """ + state_cache = inputs["state_cache"] + slot_idx = inputs["slot_idx"] + # Only the rows selected by slot_idx are accessed (not the full cache). + state_accessed = state_cache[slot_idx] # (batch_size, nheads, dim, dstate) + state_read_bytes = tensor_size_bytes(state_accessed) + + read_bytes = state_read_bytes + for k in ["x", "dt", "A", "B", "C", "D", "dt_bias", "slot_idx"]: + read_bytes += tensor_size_bytes(inputs[k]) + if inputs.get("z") is not None: + read_bytes += tensor_size_bytes(inputs["z"]) + if inputs.get("intermediate_slot_idx") is not None: + read_bytes += tensor_size_bytes(inputs["intermediate_slot_idx"]) + + write_bytes = tensor_size_bytes(inputs["x"]) # output (same shape/dtype as x) + if inputs.get("intermediate_states_buffer") is not None: + # Kernel writes to intermediate_states_buffer[intermediate_slot_idx], + # not back to state_cache. + istate = inputs["intermediate_states_buffer"] + islot = inputs["intermediate_slot_idx"] + write_bytes += tensor_size_bytes(istate[islot]) + else: + # No intermediate buffer: state is written back in-place + write_bytes += state_read_bytes + + return read_bytes + write_bytes + + +def compute_problem_flops(inputs): + """Count FP32 FLOPs for the SSU kernel (SIMT, not tensor-core). + + Equations (per batch, step, head, dim_row): + Pre-compute: + dt_val = dt + dt_bias 1 add + dt_val = softplus(dt_val) 3 ops (exp, add, log) + dA = exp(A * dt_val) 1 mul + 1 exp = 2 + dtx = dt_val * x 1 mul + total: 7 per (B,T,H,D) + + State update (per dstate element): + h = h * dA + B * dtx 2 mul + 1 add = 3 + total: 3 per (B,T,H,D,N) + + Output reduction: + y += C[n] * h[n] for n in dstate 1 mul + 1 add = 2 per N + y += D * x 1 mul + 1 add = 2 + total: 2*N + 2 per (B,T,H,D) + + Optional gating: + sig_z = 1 / (1 + exp(-z)) 3 (exp, add, div) + y = y * z * sig_z 2 mul + total: 5 per (B,T,H,D) + """ + x = inputs["x"] + has_z = inputs.get("z") is not None + + # x shape: (batch, [T,] nheads, dim) — T dimension present in MTP mode + if x.dim() == 4: + batch_size, T_val, nheads, dim = x.shape + else: + batch_size, nheads, dim = x.shape + T_val = 1 + + state_cache = inputs["state_cache"] + dstate = state_cache.shape[-1] + + outer = batch_size * T_val * nheads * dim # (B, T, H, D) iterations + + flops_precompute = 7 * outer + flops_state = 3 * outer * dstate + flops_output = (2 * dstate + 2) * outer + flops_gating = 5 * outer if has_z else 0 + + return flops_precompute + flops_state + flops_output + flops_gating + + +def create_benchmark_inputs( + batch_size, + nheads, + dim, + dstate, + ngroups, + state_dtype, + device="cuda", + mtp=0, + generate_intermediate_states_buffer=False, +): + """Create test inputs for benchmarking.""" + cache_steps = None if mtp == 0 else mtp + return create_test_inputs( + batch_size=batch_size, + nheads=nheads, + dim=dim, + dstate=dstate, + ngroups=ngroups, + input_dtype=torch.bfloat16, + weight_dtype=torch.float32, + state_dtype=state_dtype, + matrixA_dtype=torch.float32, + generate_z=False, + generate_intermediate_states_buffer=generate_intermediate_states_buffer, + cache_steps=cache_steps, + device=device, + seed=0, + ) + + +def benchmark_kernel( + name, + kernel_fn, + inputs, + cache_steps=0, + ncu=False, + rand_seed=None, + philox_rounds=10, + repeat_time_ms=1000, +): + """Benchmark a single kernel and return median time in ms.""" + print(f"\n Benchmarking {name}...") + + state = clone_preserving_strides(inputs["state_cache"]) + out = torch.empty_like(inputs["x"]) + + intermediate_states_buffer = inputs.get("intermediate_states_buffer", None) + intermediate_slot_idx = inputs.get("intermediate_slot_idx", None) + + kwargs = dict( + state=state, + x=inputs["x"], + dt=inputs["dt"], + A=inputs["A"], + B=inputs["B"], + C=inputs["C"], + D=inputs["D"], + z=None, + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + out=out, + intermediate_states_buffer=intermediate_states_buffer, + intermediate_state_indices=intermediate_slot_idx, + cache_steps=cache_steps, + rand_seed=rand_seed, + philox_rounds=philox_rounds, + disable_state_update=True, + ) + + if ncu: + kernel_fn(**kwargs) + torch.cuda.synchronize() + print(" Single invocation done (ncu mode)") + return 0.0 + + try: + measurements = bench_gpu_time( + lambda: kernel_fn(**kwargs), + dry_run_time_ms=100, + repeat_time_ms=repeat_time_ms, + ) + except RuntimeError as e: + print(f" Kernel failed: {e}") + return float("inf") + + median_time = np.median(measurements) + print(f" Median time: {median_time:.3f} ms") + return median_time + + +def make_flashinfer_wrapper(algorithm="auto"): + """Wrap FlashInfer's selective_state_update to match the common benchmark interface.""" + + def wrapper( + state, + x, + dt, + A, + B, + C, + D, + z, + dt_bias, + dt_softplus, + state_batch_indices, + pad_slot_id, + out, + intermediate_states_buffer, + intermediate_state_indices, + cache_steps, + rand_seed, + philox_rounds, + disable_state_update, + ): + flashinfer_selective_state_update( + state=state, + x=x, + dt=dt, + A=A, + B=B, + C=C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + state_batch_indices=state_batch_indices, + pad_slot_id=pad_slot_id, + out=out, + intermediate_states_buffer=intermediate_states_buffer, + cache_steps=cache_steps, + intermediate_state_indices=intermediate_state_indices, + algorithm=algorithm, + rand_seed=rand_seed, + philox_rounds=philox_rounds, + disable_state_update=disable_state_update, + ) + + return wrapper + + +def run_measurement( + batch_size, + nheads, + dim, + ngroups, + dstate, + state_dtype, + mtp=0, + generate_intermediate_states_buffer=False, + ncu=False, + philox_rounds=None, + repeat_time_ms=1000, +): + """Run benchmarks on all kernels and return results dict.""" + inputs = create_benchmark_inputs( + batch_size=batch_size, + nheads=nheads, + dim=dim, + dstate=dstate, + ngroups=ngroups, + state_dtype=state_dtype, + mtp=mtp, + generate_intermediate_states_buffer=generate_intermediate_states_buffer, + ) + + cache_steps = inputs.get("cache_steps", 0) + + # Stochastic rounding: create rand_seed tensor when philox_rounds is set + rand_seed = None + effective_philox_rounds = 10 + if philox_rounds is not None: + rand_seed = torch.tensor(42, dtype=torch.int64, device="cuda") + effective_philox_rounds = philox_rounds + + kernels = { + "flashinfer_simple": make_flashinfer_wrapper(algorithm="simple"), + "flashinfer_vertical": make_flashinfer_wrapper(algorithm="vertical"), + "flashinfer_horizontal": make_flashinfer_wrapper(algorithm="horizontal"), + "flashinfer_auto": make_flashinfer_wrapper(algorithm="auto"), + } + + results = {} + for name, fn in kernels.items(): + median_time = benchmark_kernel( + name, + fn, + inputs, + cache_steps=cache_steps, + ncu=ncu, + rand_seed=rand_seed, + philox_rounds=effective_philox_rounds, + repeat_time_ms=repeat_time_ms, + ) + results[name] = median_time + + return results, inputs + + +# -- Main -- + +parser = argparse.ArgumentParser( + description="Benchmark selective_state_update — % of Speed-of-Light (SOL)" +) +parser.add_argument( + "--ncu", + action="store_true", + help="NCU profiling mode: single invocation per kernel, no warmup or timing", +) +parser.add_argument( + "-b", + "--batch", + type=int, + nargs="+", + default=None, + help="Batch size(s) to benchmark (default: powers of 2 from 1 to 2048)", +) +parser.add_argument( + "-o", + "--output", + type=str, + default=None, + help="Output image file path (default: auto-generated in benchmarks/img/)", +) +parser.add_argument( + "--dtype", + type=str, + nargs="+", + default=["bf16", "f32"], + help=( + "State dtype(s) to benchmark (default: bf16 f32). " + "Supported: bf16, f16, f32, or f16-philox-N for stochastic rounding " + "with N philox rounds (e.g. f16-philox-5)" + ), +) +parser.add_argument( + "--dstate", + type=int, + default=128, + help="State dimension (default: 128)", +) +parser.add_argument( + "--mtp", + type=int, + default=6, + help="Number of MTP (cache) steps (default: 6)", +) +parser.add_argument( + "-r", + "--repeat", + type=int, + default=1000, + help="Repeat time in milliseconds for benchmarking (default: 1000)", +) +parser.add_argument( + "--peak-bw", + type=float, + default=None, + help="GPU peak HBM bandwidth in TB/s (auto-detected from GPU name if omitted)", +) +parser.add_argument( + "--peak-flops", + type=float, + default=None, + help="GPU peak SIMT FP32 throughput in TFLOPS (auto-detected from GPU name if omitted)", +) +args = parser.parse_args() + +# Powers of two from 1 to 2048 +batch_sizes = args.batch if args.batch is not None else [2**i for i in range(12)] + +_dtype_name_to_torch = { + "bf16": torch.bfloat16, + "f16": torch.float16, + "f32": torch.float32, +} + + +def parse_dtype_spec(spec): + """Parse a dtype spec like 'bf16', 'f16', or 'f16-philox-5'. + + Returns (display_name, torch_dtype, philox_rounds_or_None). + """ + m = re.match(r"^(bf16|f16|f32)-philox-(\d+)$", spec) + if m: + base, rounds = m.group(1), int(m.group(2)) + return spec, _dtype_name_to_torch[base], rounds + if spec not in _dtype_name_to_torch: + raise ValueError( + f"Unknown dtype spec '{spec}'. " + "Expected bf16, f16, f32, or -philox-" + ) + return spec, _dtype_name_to_torch[spec], None + + +state_dtypes = [parse_dtype_spec(s) for s in args.dtype] +mtp_value = args.mtp + +# Fixed kernel parameters (matching bench_ssu_sweep_mtp.py) +NHEADS = 64 +DIM = 64 +NGROUPS = 8 + +# Resolve peak specs +gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "Unknown GPU" +peak_bw_tb_s = get_peak_bandwidth_tb_s(gpu_name, args.peak_bw) +peak_flops_tflops = get_peak_simt_fp32_tflops(gpu_name, args.peak_flops) + +all_results = [] + +print("=" * 80) +print("COLLECTING BENCHMARK RESULTS — SOL% (MTP ENABLED)") +print("=" * 80) +print(f"GPU: {gpu_name}") +print(f"Peak HBM bandwidth: {peak_bw_tb_s:.2f} TB/s") +print(f"Peak SIMT FP32: {peak_flops_tflops:.1f} TFLOPS") +print(f"Batch sizes to test: {batch_sizes}") +print(f"State dtypes: {[name for name, _, _ in state_dtypes]}") +print(f"MTP (cache_steps): {mtp_value}, dstate: {args.dstate}") +if args.ncu: + print("NCU mode: single invocation, no warmup/timing") +print("=" * 80) + +for state_dtype_name, state_dtype_torch, philox_rounds in state_dtypes: + for batch_size in batch_sizes: + print( + f"\n Running benchmark for batch_size={batch_size}, " + f"state_dtype={state_dtype_name}, mtp={mtp_value}, dstate={args.dstate}" + ) + + results, inputs = run_measurement( + batch_size=batch_size, + nheads=NHEADS, + dim=DIM, + ngroups=NGROUPS, + dstate=args.dstate, + state_dtype=state_dtype_torch, + mtp=mtp_value, + generate_intermediate_states_buffer=True, + ncu=args.ncu, + philox_rounds=philox_rounds, + repeat_time_ms=args.repeat, + ) + + if not results: + print(f" Warning: No results returned for batch_size={batch_size}") + continue + + problem_bytes = compute_problem_bytes(inputs) + problem_flops = compute_problem_flops(inputs) + + # SOL time = memory time + compute time + # Memory and compute are not overlapped in this kernel, so total + # ideal time is the sum of both. + sol_mem_time_ms = problem_bytes / (peak_bw_tb_s * 1e9) # TB/s → bytes/ms + sol_compute_time_ms = problem_flops / ( + peak_flops_tflops * 1e9 + ) # TFLOPS → FLOPs/ms + sol_time_ms = sol_mem_time_ms + sol_compute_time_ms + + for kernel_name, median_time_ms in results.items(): + if median_time_ms > 0 and median_time_ms != float("inf"): + sol_pct = sol_time_ms / median_time_ms * 100.0 + else: + sol_pct = 0.0 + + all_results.append( + { + "batch_size": batch_size, + "state_dtype": state_dtype_name, + "kernel": kernel_name, + "avg_time_ms": median_time_ms, + "sol_pct": sol_pct, + } + ) + +# Create DataFrame +df = pd.DataFrame(all_results) + +print("\n" + "=" * 80) +print("BENCHMARK RESULTS SUMMARY") +print("=" * 80) + +if args.ncu or df.empty: + if df.empty: + print("No results collected!") + sys.exit(0) +else: + print(f"\nGPU: {gpu_name}") + print(f"Peak HBM bandwidth: {peak_bw_tb_s:.2f} TB/s") + print(f"Peak SIMT FP32: {peak_flops_tflops:.1f} TFLOPS") + print(f"State dtypes: {[name for name, _, _ in state_dtypes]}") + print(f"MTP (cache_steps): {mtp_value}") + + unique_dtypes = df["state_dtype"].unique() + num_dtypes = len(unique_dtypes) + fig, axes = plt.subplots(num_dtypes, 1, figsize=(10, 5 * num_dtypes), squeeze=False) + + for dtype_idx, dtype_name in enumerate(unique_dtypes): + df_dtype = df[df["state_dtype"] == dtype_name] + + # Print time table + df_time_pivot = df_dtype.pivot( + index="batch_size", columns="kernel", values="avg_time_ms" + ) + print( + f"\nMedian time (ms) by batch size and kernel (state_dtype={dtype_name}):" + ) + print(df_time_pivot.to_csv()) + + # Print SOL% table + df_sol_pivot = df_dtype.pivot( + index="batch_size", columns="kernel", values="sol_pct" + ) + print(f"\nSOL% by batch size and kernel (state_dtype={dtype_name}):") + print(df_sol_pivot.to_csv()) + + ax = axes[dtype_idx, 0] + prop_cycle = plt.rcParams["axes.prop_cycle"] + default_colors = prop_cycle.by_key()["color"] + + df_plot = df_sol_pivot.reset_index() + kernel_columns = [col for col in df_plot.columns if col != "batch_size"] + + x_positions = np.arange(len(df_plot["batch_size"])) + x_tick_labels = [f"{x}" for x in df_plot["batch_size"].values] + num_cols = len(kernel_columns) + bar_width = 0.8 / max(num_cols, 1) + + for idx, col in enumerate(kernel_columns): + sol_vals = df_plot[col] + offset = (idx - num_cols / 2 + 0.5) * bar_width + bars = ax.bar( + x_positions + offset, + sol_vals, + bar_width, + color=default_colors[idx % len(default_colors)], + label=col, + alpha=0.7, + ) + for bar, y in zip(bars, sol_vals, strict=True): + if y > 0: + ax.text( + bar.get_x() + bar.get_width() / 2, + y, + f"{y:.0f}%", + ha="center", + va="bottom", + fontsize=7, + rotation=0, + ) + + ax.set_xlabel("Batch Size") + ax.set_ylabel("% SOL") + ax.set_xticks(x_positions) + ax.set_xticklabels(x_tick_labels) + ax.grid(True, alpha=0.3, axis="y") + ax.axhline(y=100, color="red", linestyle="--", alpha=0.5, label="100% SOL") + dstate_subtitle = f", dstate={args.dstate}" if args.dstate != 128 else "" + ax.set_title( + f"State dtype: {dtype_name}, MTP={mtp_value}{dstate_subtitle} " + f"— Peak BW: {peak_bw_tb_s:.2f} TB/s, Peak SIMT FP32: {peak_flops_tflops:.0f} TFLOPS" + ) + ax.set_ylim([0, None]) + ax.legend(loc="best", fontsize=8) + + dstate_title = f", dstate={args.dstate}" if args.dstate != 128 else "" + fig.suptitle( + f"SSU % of Speed-of-Light (MTP={mtp_value}{dstate_title}) [{gpu_name}]", + fontsize=14, + fontweight="bold", + ) + plt.tight_layout(rect=[0, 0, 1, 0.97]) + + if args.output: + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + else: + gpu_name_clean = gpu_name.replace(" ", "_").replace("/", "_") + dtype_str = "_".join(name for name, _, _ in state_dtypes) + output_filename = ( + f"sol_vs_batch_size_mtp{mtp_value}_{dtype_str}_{gpu_name_clean}.png" + ) + img_dir = Path(__file__).parent / "img" + img_dir.mkdir(exist_ok=True) + output_path = img_dir / output_filename + plt.savefig(output_path, dpi=300) + print(f"\nPlot saved to: {output_path}") diff --git a/flashinfer/mamba/selective_state_update.py b/flashinfer/mamba/selective_state_update.py index 95124ea109..e1df4d6542 100644 --- a/flashinfer/mamba/selective_state_update.py +++ b/flashinfer/mamba/selective_state_update.py @@ -269,6 +269,11 @@ def selective_state_update( # No stochastic rounding when rand_seed is None philox_rounds = 0 + if intermediate_states_buffer is not None and dst_state_batch_indices is not None: + raise ValueError( + "intermediate_states_buffer and dst_state_batch_indices are mutually exclusive" + ) + if out is None: output = torch.empty_like(x) else: @@ -302,7 +307,8 @@ def selective_state_update( elif algorithm == "horizontal": algorithm_int = 3 elif algorithm == "async_horizontal": - algorithm_int = 4 + # Backward compat: async_horizontal is now merged into simple + algorithm_int = 1 else: raise ValueError(f"Unknown algorithm: {algorithm}") diff --git a/include/flashinfer/mamba/common.cuh b/include/flashinfer/mamba/common.cuh index ad4bcdfb86..6c61c56ef1 100644 --- a/include/flashinfer/mamba/common.cuh +++ b/include/flashinfer/mamba/common.cuh @@ -40,6 +40,9 @@ inline constexpr unsigned largestPow2Divisor(unsigned v) { return v ? (v & (~v + // so it is always valid even when N * sizeof(T) is not a power of 2 (e.g. 3 × 2 = 6). template struct alignas(largestPow2Divisor(N * sizeof(T))) PackedAligned { + static_assert(N > 0, + "PackedAligned instantiated with N == 0; " + "ensure getVectorLoadSizeForFullUtilization() returns > 0"); T val[N]; static constexpr int count = N; using dtype = T; diff --git a/include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh b/include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh index 94734a15e5..cb7c77f760 100644 --- a/include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh +++ b/include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh @@ -4,7 +4,6 @@ #include #include -#include #include #include "../utils.cuh" @@ -12,7 +11,6 @@ #include "common.cuh" #include "conversion.cuh" #include "create_tensor_map.cuh" -#include "kernel_selective_state_update_mtp_async_horizontal.cuh" #include "kernel_selective_state_update_mtp_simple.cuh" #ifdef FLASHINFER_MAMBA_ENABLE_SM100 #include "kernel_selective_state_update_mtp_horizontal.cuh" @@ -23,6 +21,24 @@ namespace flashinfer::mamba::mtp { using namespace conversion; +// Dispatch to the largest CTAS_PER_HEAD in the sequence where +// (a) kDim / CTAS >= kMinRows (compile-time) and (b) ctas_per_head >= CTAS (runtime). +// The sequence must be in descending order and end with 1 to guarantee a match. +template +__host__ void dispatchCtasPerHead(int ctas_per_head, F&& launch, + std::integer_sequence) { + if constexpr (kDim / CTAS >= kMinRows) { + if (ctas_per_head >= CTAS) { + launch.template operator()(); + return; + } + } + if constexpr (sizeof...(Rest) > 0) { + dispatchCtasPerHead(ctas_per_head, std::forward(launch), + std::integer_sequence{}); + } +} + template void invokeSelectiveStateUpdateMTP(SelectiveStateMTPParams& params, SSUAlgorithm algorithm, @@ -40,6 +56,10 @@ void invokeSelectiveStateUpdateMTP(SelectiveStateMTPParams& params, SSUAlgorithm "MTP selective_state_update only supports 'auto', 'simple', 'vertical', " "'horizontal', or 'async_horizontal' algorithm, got ", static_cast(algorithm)); + // kAsyncHorizontal is now merged into kSimple + if (algorithm == SSUAlgorithm::kAsyncHorizontal) { + algorithm = SSUAlgorithm::kSimple; + } // ── Auto algorithm selection ────────────────────────────────────────────── if (algorithm == SSUAlgorithm::kAuto) { #ifdef FLASHINFER_MAMBA_ENABLE_SM100 @@ -73,73 +93,6 @@ void invokeSelectiveStateUpdateMTP(SelectiveStateMTPParams& params, SSUAlgorithm FLASHINFER_CHECK_ALIGNMENT(params.intermediate_states, alignof(load_state_t)); } - // ── Async Horizontal MTP kernel (SM80+, no TMA) ───────────────────────── - if (algorithm == SSUAlgorithm::kAsyncHorizontal) { - constexpr int NUM_WARPS = 4; - constexpr int kRowsPerPass = NUM_WARPS * async_horiz::ROWS_PER_WARP; - - FLASHINFER_CHECK(params.nheads % params.ngroups == 0, "nheads (", params.nheads, - ") must be divisible by ngroups (", params.ngroups, - ") for async_horizontal algorithm"); - FLASHINFER_CHECK(!scaleState, - "async_horizontal algorithm does not support scaled (quantized) state"); - FLASHINFER_CHECK(params.cu_seqlens == nullptr, - "async_horizontal algorithm does not support varlen (cu_seqlens)"); - - // Determine CTAS_PER_HEAD: split DIM across grid.z for more parallelism at small batch - int const total_tiles = params.batch * params.nheads; - int const num_sms = GetCudaMultiProcessorCount(); - - // Pick CTAS_PER_HEAD to saturate the GPU: ratio = target_ctas / total_tiles, - // clamped to [1, max_ctas]. DIM_PER_CTA must be >= ROWS_PER_PASS. - // With 128 threads and 48 regs/thread, registers limit to 10 blocks/SM. - constexpr int kBlocksPerSM = 10; - constexpr int kMaxCtas = DIM / kRowsPerPass; - int const target_ctas = num_sms * kBlocksPerSM; - int const ctas_per_head = std::clamp(target_ctas / max(total_tiles, 1), 1, kMaxCtas); - - auto launch = [&]() { - constexpr int DIM_PER_CTA = DIM / CTAS_PER_HEAD; - static_assert(DIM % CTAS_PER_HEAD == 0); - static_assert(DIM_PER_CTA % kRowsPerPass == 0); - - dispatchRatio( - params, std::integer_sequence{}, [&]() { - constexpr int DSTATE_PAD = padDstate(DSTATE); - using sram_t = AsyncHorizontalStorage; - constexpr size_t smem_size = sizeof(sram_t); - - auto func = selective_state_update_kernel_async_horizontal_mtp< - input_t, weight_t, matrixA_t, state_t, stateIndex_t, NTOKENS_MTP, DIM, DSTATE, - HEADS_PER_GROUP, PHILOX_ROUNDS, NUM_WARPS, CTAS_PER_HEAD>; - - dim3 grid(params.batch, params.nheads, CTAS_PER_HEAD); - dim3 block(warpSize, NUM_WARPS); - - FLASHINFER_CUDA_CHECK( - cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - func<<>>(params); - }); - }; - - // Dispatch to the largest instantiated CTAS_PER_HEAD <= ctas_per_head. - // Use if constexpr to avoid compiling invalid template instantiations. - if constexpr (DIM / 4 >= kRowsPerPass) { - if (ctas_per_head >= 4) { - launch.template operator()<4>(); - return; - } - } - if constexpr (DIM / 2 >= kRowsPerPass) { - if (ctas_per_head >= 2) { - launch.template operator()<2>(); - return; - } - } - launch.template operator()<1>(); - return; - } - // ── Vertical MTP kernel (SM100+ only) ──────────────────────────────────── #ifdef FLASHINFER_MAMBA_ENABLE_SM100 if (algorithm == SSUAlgorithm::kVertical) { @@ -149,6 +102,9 @@ void invokeSelectiveStateUpdateMTP(SelectiveStateMTPParams& params, SSUAlgorithm constexpr int kVerticalDimAlignment = warpSize; // epilogue: elemsPerThread = DIM / warpSize FLASHINFER_CHECK(DIM % kVerticalDimAlignment == 0, "Vertical kernel requires DIM divisible by 32 (warpSize), got DIM=", DIM); + FLASHINFER_CHECK( + DSTATE % warpSize == 0, + "Vertical kernel requires DSTATE divisible by 32 (warpSize), got DSTATE=", DSTATE); FLASHINFER_CHECK(!scaleState, "vertical algorithm does not support scaled (quantized) state"); FLASHINFER_CHECK(params.cu_seqlens == nullptr, "vertical algorithm does not support varlen (cu_seqlens)"); @@ -212,19 +168,15 @@ void invokeSelectiveStateUpdateMTP(SelectiveStateMTPParams& params, SSUAlgorithm FLASHINFER_CHECK(params.nheads % params.ngroups == 0, "nheads (", params.nheads, ") must be divisible by ngroups (", params.ngroups, ") for horizontal algorithm"); - constexpr int kHorizontalDimAlignment = - horiz::NUM_COMPUTE_WARPS_PER_GROUP * horiz::ROWS_PER_WARP; - FLASHINFER_CHECK(DIM % kHorizontalDimAlignment == 0, - "Horizontal kernel requires DIM divisible by ", kHorizontalDimAlignment, - " (NUM_COMPUTE_WARPS_PER_GROUP * ROWS_PER_WARP), got DIM=", DIM); - FLASHINFER_CHECK(!scaleState, "horizontal algorithm does not support scaled (quantized) state"); - FLASHINFER_CHECK(params.cu_seqlens == nullptr, - "horizontal algorithm does not support varlen (cu_seqlens)"); - constexpr int NUM_IN_STAGES = 2; // TMA_STATE_ROWS: rows of DIM per TMA transaction. Must be a multiple of ROWS_PER_PASS. // Larger values = fewer barrier syncs but more smem per pipeline stage. constexpr int TMA_STATE_ROWS = 2 * horiz::ROWS_PER_PASS; + FLASHINFER_CHECK(DIM % TMA_STATE_ROWS == 0, "Horizontal kernel requires DIM divisible by ", + TMA_STATE_ROWS, " (TMA_STATE_ROWS = 2 * ROWS_PER_PASS), got DIM=", DIM); + FLASHINFER_CHECK(!scaleState, "horizontal algorithm does not support scaled (quantized) state"); + FLASHINFER_CHECK(params.cu_seqlens == nullptr, + "horizontal algorithm does not support varlen (cu_seqlens)"); dispatchRatio( params, std::integer_sequence{}, [&]() { @@ -246,34 +198,30 @@ void invokeSelectiveStateUpdateMTP(SelectiveStateMTPParams& params, SSUAlgorithm dim3 grid(params.batch, params.nheads / HEADS_PER_CTA); dim3 block(warpSize, horiz::NUM_WARPS); - // TMA state descriptor: tile by (BANK_CYCLE_ELEMS, TMA_STATE_ROWS). - // Sub-tile-major smem layout eliminates bank conflicts for non-power-of-2 DSTATE. - // TMA's OOB fill zeros out partial tiles (e.g. cols 96–127 when DSTATE=96). - constexpr int BANK_CYCLE_ELEMS = - 32 * (int)sizeof(uint32_t) / (int)sizeof(state_t); // 64 for f16/bf16 - constexpr int DSTATE_SMEM = sram_t::DSTATE_SMEM; - // State/B/C tensor maps use FILL_NONE: OOB elements are NOT written - // to smem, so pre-zeroed padding columns remain zero. + // TMA state descriptor: single wide tile of DSTATE_PAD columns. + // DSTATE_PAD is DSTATE rounded up to 128 bytes (32 banks), eliminating + // bank conflicts. OOB padding is handled in registers, not smem. + constexpr int DSTATE_PAD = sram_t::DSTATE_PAD; auto state_tensor = tma::buildNdDescriptor( typeid(state_t), /*shapes*/ {DSTATE, DIM, params.nheads, params.state_cache_size}, /*strides*/ {1, DSTATE, DSTATE * DIM, params.state_stride_batch}, - /*tiles*/ {BANK_CYCLE_ELEMS, TMA_STATE_ROWS, 1, 1}, params.state); + /*tiles*/ {DSTATE_PAD, TMA_STATE_ROWS, 1, 1}, params.state); - // B/C: tile by DSTATE_SMEM to match padded smem layout. + // B/C: tile by DSTATE_PAD to match padded smem layout. auto B_tensor = tma::buildNdDescriptor( typeid(input_t), {(uint64_t)DSTATE, (uint64_t)params.ngroups, (uint64_t)params.ntokens_mtp, (uint64_t)params.batch}, {1, (uint64_t)DSTATE, (uint64_t)params.B_stride_mtp, (uint64_t)params.B_stride_batch}, - {DSTATE_SMEM, 1, NTOKENS_MTP, 1}, params.B); + {DSTATE_PAD, 1, NTOKENS_MTP, 1}, params.B); auto C_tensor = tma::buildNdDescriptor( typeid(input_t), {(uint64_t)DSTATE, (uint64_t)params.ngroups, (uint64_t)params.ntokens_mtp, (uint64_t)params.batch}, {1, (uint64_t)DSTATE, (uint64_t)params.C_stride_mtp, (uint64_t)params.C_stride_batch}, - {DSTATE_SMEM, 1, NTOKENS_MTP, 1}, params.C); + {DSTATE_PAD, 1, NTOKENS_MTP, 1}, params.C); auto x_tensor = tma::buildNdDescriptor( typeid(input_t), @@ -297,41 +245,57 @@ void invokeSelectiveStateUpdateMTP(SelectiveStateMTPParams& params, SSUAlgorithm "recompile with FLASHINFER_MAMBA_ENABLE_SM100"); #endif - // ── Simple MTP kernel ──────────────────────────────────────────────────── - constexpr int numWarps = 4; - constexpr int stateRowsPerWarpPerStage = 4; - constexpr int stateRowsPerBlockPerStage = stateRowsPerWarpPerStage * numWarps; - int const total_tiles = params.batch * params.nheads; - int const num_sms = GetCudaMultiProcessorCount(); - - dim3 block(warpSize, numWarps); - if (total_tiles < num_sms * 2) { - // Small tile per CTA (stateRowsPerBlockPerStage * DSTATE): split dim across grid.z for GPU - // occupancy - int const dim_tiles = (DIM + stateRowsPerBlockPerStage - 1) / stateRowsPerBlockPerStage; - dim3 grid(params.batch, params.nheads, dim_tiles); - auto func = selective_state_update_kernel_simple_mtp< - input_t, weight_t, matrixA_t, state_t, stateIndex_t, state_scale_t, NTOKENS_MTP, DIM, - DSTATE, stateRowsPerBlockPerStage, PHILOX_ROUNDS, numWarps>; - using sram_t = - SharedStorageSimple; - constexpr size_t smem_size = sizeof(sram_t); - FLASHINFER_CUDA_CHECK( - cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - func<<>>(params); - } else { - // Full tile per CTA (DIM * DSTATE): enough blocks for occupancy, no dim splitting needed - dim3 grid(params.batch, params.nheads); - auto func = selective_state_update_kernel_simple_mtp; - using sram_t = SharedStorageSimple; - constexpr size_t smem_size = sizeof(sram_t); - FLASHINFER_CUDA_CHECK( - cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - func<<>>(params); + // ── Simple MTP kernel (SM80+, cp.async, no TMA) ───────────────────────── + { + constexpr int NUM_WARPS = 4; + constexpr int kRowsPerPass = NUM_WARPS * simple_horiz::ROWS_PER_WARP; + + FLASHINFER_CHECK(params.nheads % params.ngroups == 0, "nheads (", params.nheads, + ") must be divisible by ngroups (", params.ngroups, ") for simple algorithm"); + // Determine CTAS_PER_HEAD: split DIM across grid.z for more parallelism at small batch + int const total_tiles = params.batch * params.nheads; + int const num_sms = GetCudaMultiProcessorCount(); + + // Pick CTAS_PER_HEAD to saturate the GPU: ratio = target_ctas / total_tiles, + // clamped to [1, max_ctas]. DIM_PER_CTA must be >= ROWS_PER_PASS. + // With 128 threads and 48 regs/thread, registers limit to 10 blocks/SM. + constexpr int kBlocksPerSM = 10; + constexpr int kMaxCtas = DIM / kRowsPerPass; + int const target_ctas = num_sms * kBlocksPerSM; + int const ctas_per_head = std::clamp(target_ctas / std::max(total_tiles, 1), 1, kMaxCtas); + + auto launch = [&]() { + constexpr int DIM_PER_CTA = DIM / CTAS_PER_HEAD; + static_assert(DIM % CTAS_PER_HEAD == 0); + static_assert(DIM_PER_CTA % kRowsPerPass == 0); + + dispatchRatio( + params, std::integer_sequence{}, [&]() { + constexpr int DSTATE_PAD = padDstate(DSTATE); + constexpr int kRowsPerPassLocal = NUM_WARPS * simple_horiz::ROWS_PER_WARP; + constexpr int kNumPasses = DIM_PER_CTA / kRowsPerPassLocal; + constexpr int kStateStages = (kNumPasses == 1) ? 1 : 2; + using sram_t = SimpleStorage; + constexpr size_t smem_size = sizeof(sram_t); + + auto func = selective_state_update_kernel_simple_mtp< + input_t, weight_t, matrixA_t, state_t, stateIndex_t, state_scale_t, NTOKENS_MTP, + DIM, DSTATE, HEADS_PER_GROUP, PHILOX_ROUNDS, NUM_WARPS, CTAS_PER_HEAD>; + + dim3 grid(params.batch, params.nheads, CTAS_PER_HEAD); + dim3 block(warpSize, NUM_WARPS); + + FLASHINFER_CUDA_CHECK( + cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + func<<>>(params); + }); + }; + + // Dispatch to the largest instantiated CTAS_PER_HEAD <= ctas_per_head. + // if constexpr inside dispatchCtasPerHead avoids compiling invalid instantiations. + dispatchCtasPerHead(ctas_per_head, launch, + std::integer_sequence{}); } } diff --git a/include/flashinfer/mamba/kernel_selective_state_update_mtp_horizontal.cuh b/include/flashinfer/mamba/kernel_selective_state_update_mtp_horizontal.cuh index 0fc82d539d..328a8dcbe6 100644 --- a/include/flashinfer/mamba/kernel_selective_state_update_mtp_horizontal.cuh +++ b/include/flashinfer/mamba/kernel_selective_state_update_mtp_horizontal.cuh @@ -63,16 +63,16 @@ static constexpr int ROWS_PER_PASS = NUM_COMPUTE_WARPS_PER_GROUP * ROWS_PER_WARP template struct GroupStorageHorizontal { - // Sub-tile-major smem layout: each sub-tile of BANK_CYCLE_ELEMS columns is stored - // contiguously per row, eliminating bank conflicts for non-power-of-2 DSTATE. - static constexpr int BANK_CYCLE_ELEMS = 32 * sizeof(uint32_t) / sizeof(state_t); - static constexpr int NUM_STATE_SUBTILES = (DSTATE + BANK_CYCLE_ELEMS - 1) / BANK_CYCLE_ELEMS; - static constexpr int DSTATE_SMEM = NUM_STATE_SUBTILES * BANK_CYCLE_ELEMS; - - alignas(128) input_t B[TOKENS_MTP][DSTATE_SMEM]; - alignas(128) input_t C[TOKENS_MTP][DSTATE_SMEM]; - alignas( - 128) state_t state_in[NUM_IN_STAGES][NUM_STATE_SUBTILES * TMA_STATE_ROWS * BANK_CYCLE_ELEMS]; + // Pad DSTATE to next multiple of 32 banks (128 bytes) to eliminate bank conflicts. + // TMA loads a single wide tile of DSTATE_PAD columns; OOB columns (DSTATE..DSTATE_PAD-1) + // are skipped by TMA (FILL_NONE) and zeroed in registers at load time. + static constexpr int BANK_CYCLE_BYTES = 32 * sizeof(uint32_t); // 128 bytes + static constexpr int DSTATE_PAD = (DSTATE * (int)sizeof(state_t) + BANK_CYCLE_BYTES - 1) / + BANK_CYCLE_BYTES * BANK_CYCLE_BYTES / (int)sizeof(state_t); + + alignas(128) input_t B[TOKENS_MTP][DSTATE_PAD]; + alignas(128) input_t C[TOKENS_MTP][DSTATE_PAD]; + alignas(128) state_t state_in[NUM_IN_STAGES][TMA_STATE_ROWS * DSTATE_PAD]; alignas(128) input_t x[HEADS_PER_CTA][TOKENS_MTP][DIM]; float dt[HEADS_PER_CTA][TOKENS_MTP]; float out[TOKENS_MTP][DIM]; @@ -102,31 +102,28 @@ __device__ __forceinline__ void role_load_horizontal( // These are merged with the first state tile into a single barrier transaction, // eliminating a separate bar_input_full barrier and letting TMA instructions // stream back-to-back without serialization. - constexpr int DSTATE_SMEM = SramT::DSTATE_SMEM; - constexpr int bytesBCX = 2 * NTOKENS * DSTATE_SMEM * (int)sizeof(input_t) + + constexpr int DSTATE_PAD = SramT::DSTATE_PAD; + constexpr int bytesBCX = 2 * NTOKENS * DSTATE_PAD * (int)sizeof(input_t) + HEADS_PER_CTA * NTOKENS * DIM * (int)sizeof(input_t); + // B/C/x TMA loads are always issued (even for pad slots — output must be valid). + // They are merged into bar_state_in_full[0] alongside the first state tile. if (lane == 0) { - if constexpr (!IS_PAD) { - cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.B[0][0], &tensorB, 0, kv_group, 0, batch, - sram.bar_state_in_full[0]); - cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.C[0][0], &tensorC, 0, kv_group, 0, batch, - sram.bar_state_in_full[0]); + cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.B[0][0], &tensorB, 0, kv_group, 0, batch, + sram.bar_state_in_full[0]); + cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.C[0][0], &tensorC, 0, kv_group, 0, batch, + sram.bar_state_in_full[0]); #pragma unroll - for (int h = 0; h < HEADS_PER_CTA; h++) { - cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.x[h][0][0], &tensorX, 0, base_head + h, - 0, batch, sram.bar_state_in_full[0]); - } + for (int h = 0; h < HEADS_PER_CTA; h++) { + cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.x[h][0][0], &tensorX, 0, base_head + h, 0, + batch, sram.bar_state_in_full[0]); } } // ── Pipeline state_in loads (TMA_STATE_ROWS per transaction) ────────── - // Sub-tile-major layout: each chunk issues numStateSubTiles TMA loads of - // (BANK_CYCLE_ELEMS, TMA_STATE_ROWS), stored contiguously per sub-tile. - // This gives bank-cycle-aligned rows, eliminating bank conflicts. - constexpr int BANK_CYCLE_ELEMS = SramT::BANK_CYCLE_ELEMS; - constexpr int numStateSubTiles = SramT::NUM_STATE_SUBTILES; - constexpr int bytesChunk = - numStateSubTiles * TMA_STATE_ROWS * BANK_CYCLE_ELEMS * (int)sizeof(state_t); + // Single wide TMA load of DSTATE_PAD columns per chunk. + // OOB padding columns are handled in registers, not smem. + // State is only loaded for non-pad slots; pad slots use zero state in registers. + constexpr int bytesChunk = TMA_STATE_ROWS * DSTATE_PAD * (int)sizeof(state_t); uint32_t parity_empty[NUM_IN_STAGES] = {}; // all start at phase 0 #pragma unroll for (int h = 0; h < HEADS_PER_CTA; h++) { @@ -139,17 +136,15 @@ __device__ __forceinline__ void role_load_horizontal( if (lane == 0) { if constexpr (!IS_PAD) { -#pragma unroll - for (int st = 0; st < numStateSubTiles; st++) { - cde::cp_async_bulk_tensor_4d_global_to_shared( - &sram.state_in[slot][st * TMA_STATE_ROWS * BANK_CYCLE_ELEMS], &tensorState, - st * BANK_CYCLE_ELEMS, tl * TMA_STATE_ROWS, head, state_batch, - sram.bar_state_in_full[slot]); - } + cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state_in[slot][0], &tensorState, 0, + tl * TMA_STATE_ROWS, head, state_batch, + sram.bar_state_in_full[slot]); int const bytes = (h == 0 && tl == 0) ? bytesBCX + bytesChunk : bytesChunk; cuda::device::barrier_arrive_tx(sram.bar_state_in_full[slot], 1, bytes); } else { - cuda::device::barrier_arrive_tx(sram.bar_state_in_full[slot], 1, 0); + // Pad slot: no state TMA, but first iteration includes the B/C/x bytes + int const bytes = (h == 0 && tl == 0) ? bytesBCX : 0; + cuda::device::barrier_arrive_tx(sram.bar_state_in_full[slot], 1, bytes); } } } @@ -174,8 +169,8 @@ __device__ __forceinline__ void role_update_state_horizontal(SramT& sram, int la constexpr int rowsPerWarp = horiz::ROWS_PER_WARP; constexpr int numTmaLoads = DIM / TMA_STATE_ROWS; constexpr int subPassesPerTma = TMA_STATE_ROWS / horiz::ROWS_PER_PASS; - constexpr int DSTATE_SMEM = SramT::DSTATE_SMEM; - constexpr int stateValuesPerThread = DSTATE_SMEM / lanesPerRow; + constexpr int DSTATE_PAD = SramT::DSTATE_PAD; + constexpr int stateValuesPerThread = DSTATE_PAD / lanesPerRow; static_assert(DSTATE % lanesPerRow == 0, "DSTATE must be divisible by lanesPerRow"); static_assert(DIM % TMA_STATE_ROWS == 0, "DIM must be divisible by TMA_STATE_ROWS"); static_assert(TMA_STATE_ROWS % horiz::ROWS_PER_PASS == 0, @@ -211,17 +206,11 @@ __device__ __forceinline__ void role_update_state_horizontal(SramT& sram, int la auto const icache_idx = intermediate_state_indices ? (int64_t)intermediate_state_indices[batch] : state_batch; - // Logical column within DSTATE_SMEM (for B/C/state smem access and global store bounds) + // Logical column within DSTATE_PAD (for B/C/state smem access and global store bounds) auto baseCol = [&](int t, int e) -> int { return t * elemsPerTile + member * elemsPerTileMember + e; }; - // Smem state index: sub-tile-major layout [subtile][row][col_within_subtile] - auto smemStateIdx = [&](int t, int sram_row, int e) -> int { - return t * TMA_STATE_ROWS * elemsPerTile + sram_row * elemsPerTile + - member * elemsPerTileMember + e; - }; - // Output pointers (for epilogue) auto* __restrict__ output = reinterpret_cast(params.output); auto const* __restrict__ z_ptr = reinterpret_cast(params.z); @@ -290,13 +279,18 @@ __device__ __forceinline__ void role_update_state_horizontal(SramT& sram, int la int const sram_row = sp * horiz::ROWS_PER_PASS + compute_warp * rowsPerWarp + group; int const dd = tl * TMA_STATE_ROWS + sram_row; // global DIM row - // Load state from smem (TMA zero-fills columns beyond DSTATE) + // Load state from smem; zero padding beyond DSTATE in registers float2 rState[numTiles][pairsPerTileMember]; #pragma unroll for (int t = 0; t < numTiles; t++) { #pragma unroll for (int p = 0; p < pairsPerTileMember; p++) { - rState[t][p] = toFloat2(&sram.state_in[slot][smemStateIdx(t, sram_row, p * 2)]); + int const c0 = baseCol(t, p * 2); + if (c0 >= DSTATE || IS_PAD) { + rState[t][p] = {0.f, 0.f}; + } else { + rState[t][p] = toFloat2(&sram.state_in[slot][sram_row * DSTATE_PAD + c0]); + } } } @@ -331,6 +325,7 @@ __device__ __forceinline__ void role_update_state_horizontal(SramT& sram, int la #pragma unroll for (int p = 0; p < pairsPerTileMember; p++) { int const c0 = baseCol(t, p * 2); + if (c0 >= DSTATE) continue; float2 const B2 = toFloat2(&B_step[c0]); float2 const C2 = toFloat2(&C_step[c0]); float2 dBx; @@ -352,8 +347,8 @@ __device__ __forceinline__ void role_update_state_horizontal(SramT& sram, int la } // Advance step pointers (addition instead of multiply) - B_step += DSTATE_SMEM; - C_step += DSTATE_SMEM; + B_step += DSTATE_PAD; + C_step += DSTATE_PAD; x_step += DIM; dt_step += 1; out_step += DIM; @@ -494,41 +489,6 @@ __global__ void __launch_bounds__(horiz::NUM_WARPS * 32, 6) } __syncthreads(); - // ── Zero-fill smem padding for DSTATE < DSTATE_SMEM ──────────────────── - // TMA uses FILL_NONE: OOB elements are NOT written to smem. Pre-zeroing - // the padding columns ensures they read as zero, eliminating OOB divergence. - if constexpr (sram_t::DSTATE_SMEM > DSTATE) { - constexpr int PAD = sram_t::DSTATE_SMEM - DSTATE; - int const tid = warp * warpSize + lane; - int const numThreads = horiz::NUM_WARPS * warpSize; - - // Zero B/C padding: B[step][DSTATE..DSTATE_SMEM-1], same for C - constexpr int bc_pad_total = NTOKENS * PAD; - for (int i = tid; i < bc_pad_total; i += numThreads) { - int const step = i / PAD; - int const col = DSTATE + i % PAD; - sram.B[step][col] = input_t(0); - sram.C[step][col] = input_t(0); - } - - // Zero state_in padding: sub-tile-major layout, last sub-tile's OOB columns - // Layout: [slot][subtile][row][BANK_CYCLE_ELEMS] - // OOB columns are in the last sub-tile at offset DSTATE % BANK_CYCLE_ELEMS - constexpr int BCE = sram_t::BANK_CYCLE_ELEMS; - constexpr int lastSubTileOffset = DSTATE % BCE; // first OOB column within last sub-tile - constexpr int padPerRow = BCE - lastSubTileOffset; - constexpr int lastST = sram_t::NUM_STATE_SUBTILES - 1; - constexpr int state_pad_total = NUM_IN_STAGES * TMA_STATE_ROWS * padPerRow; - for (int i = tid; i < state_pad_total; i += numThreads) { - int const slot = i / (TMA_STATE_ROWS * padPerRow); - int const rem = i % (TMA_STATE_ROWS * padPerRow); - int const row = rem / padPerRow; - int const col = lastSubTileOffset + rem % padPerRow; - sram.state_in[slot][lastST * TMA_STATE_ROWS * BCE + row * BCE + col] = state_t(0); - } - } - __syncthreads(); - // ── Warp role dispatch ───────────────────────────────────────────────── auto dispatch = [&]() { if (warp < horiz::NUM_COMPUTE_WARPS_PER_GROUP) { diff --git a/include/flashinfer/mamba/kernel_selective_state_update_mtp_simple.cuh b/include/flashinfer/mamba/kernel_selective_state_update_mtp_simple.cuh index dd03fb7a18..cb9da757a3 100644 --- a/include/flashinfer/mamba/kernel_selective_state_update_mtp_simple.cuh +++ b/include/flashinfer/mamba/kernel_selective_state_update_mtp_simple.cuh @@ -1,8 +1,37 @@ -#ifndef FLASHINFER_MAMBA_KERNEL_SELECTIVE_STATE_UPDATE_MTP_SIMPLE_CUH_ -#define FLASHINFER_MAMBA_KERNEL_SELECTIVE_STATE_UPDATE_MTP_SIMPLE_CUH_ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Simple MTP kernel for selective_state_update. +// Uses cp.async instead of TMA — works on SM80+ (Ampere, Hopper, Blackwell). +// All warps are compute warps (no dedicated TMA warp). +// State is loaded directly into registers from global memory. +// Shared memory B/C rows are padded to avoid bank conflicts for non-power-of-2 DSTATE. +// +// Execution flow: +// 1. All warps cooperatively cp.async B/C/x/dt into smem. +// 2. Each thread loads its state columns from global memory directly into rState[] registers. +// 3. cp_async_wait_group<0>() + __syncthreads() — single sync. +// 4. Step loop: pure register compute + smem reads for B/C/x. No further syncs until epilogue. + +#pragma once #include #include +#include +#include #include #include @@ -11,542 +40,626 @@ #include "../vec_dtypes.cuh" #include "common.cuh" #include "conversion.cuh" +#include "ssu_mtp_common.cuh" namespace flashinfer::mamba::mtp { using namespace conversion; -template -struct SharedStorageSimple { - static constexpr bool scaleState = !std::is_same_v; - alignas(alignof(PackedAligned)) input_t x[TOKENS_MTP][ROWS_PER_BLOCK]; - alignas(alignof(PackedAligned)) float out[TOKENS_MTP][ROWS_PER_BLOCK]; - alignas(alignof(PackedAligned)) input_t z[TOKENS_MTP][ROWS_PER_BLOCK]; - alignas(alignof(PackedAligned)) input_t B[TOKENS_MTP][DSTATE]; - alignas(alignof(PackedAligned)) input_t C[TOKENS_MTP][DSTATE]; - alignas(alignof(PackedAligned)) state_t state[STATE_ROWS][DSTATE]; - alignas(alignof(PackedAligned)) - std::conditional_t state_scale[STATE_ROWS]; -}; - -// Grid: (batch_or_n_sequences, nheads, cdiv(DIM, ROWS_PER_BLOCK)) -// When ROWS_PER_BLOCK == DIM, degenerates to the non-tiled case (blockIdx.z == 0 always). -template -__global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams params) { - constexpr bool scaleState = !std::is_same_v; - auto* __restrict__ output = reinterpret_cast(params.output); - auto* __restrict__ state = reinterpret_cast(params.state); - auto* __restrict__ intermediate_states = reinterpret_cast(params.intermediate_states); - - auto const* __restrict__ x = reinterpret_cast(params.x); - auto const* __restrict__ dt = reinterpret_cast(params.dt); - auto const* __restrict__ A = reinterpret_cast(params.A); - auto const* __restrict__ B = reinterpret_cast(params.B); - auto const* __restrict__ C = reinterpret_cast(params.C); - auto const* __restrict__ D = reinterpret_cast(params.D); - auto const* __restrict__ dt_bias = reinterpret_cast(params.dt_bias); - auto const* __restrict__ z = reinterpret_cast(params.z); - auto const* __restrict__ state_batch_indices = - reinterpret_cast(params.state_batch_indices); - auto const* __restrict__ intermediate_state_indices = - reinterpret_cast(params.intermediate_state_indices); - auto const* __restrict__ cu_seqlens = - reinterpret_cast(params.cu_seqlens); - auto const* __restrict__ num_accepted_tokens = - reinterpret_cast(params.num_accepted_tokens); - auto const* __restrict__ dst_state_batch_indices = - reinterpret_cast(params.dst_state_batch_indices); - bool const dt_softplus = params.dt_softplus; - - int const nheads = params.nheads; - int const ngroups = params.ngroups; - - auto const seq_idx = blockIdx.x; - auto const head = blockIdx.y; - auto const dim_offset = blockIdx.z * ROWS_PER_BLOCK; - auto const group = head / (nheads / ngroups); - auto lane = threadIdx.x % warpSize; - auto warp = threadIdx.y; - - int bos; - int seq_len; - bool const has_cu_seqlens = (cu_seqlens != nullptr); - if (has_cu_seqlens) { - bos = __ldg(&cu_seqlens[seq_idx]); - int eos = __ldg(&cu_seqlens[seq_idx + 1]); - seq_len = eos - bos; - if (seq_len <= 0) return; - } else { - bos = 0; - seq_len = TOKENS_MTP; - } +// Simple kernel constants. +namespace simple_horiz { +static constexpr int LANES_PER_ROW = 8; +static constexpr int ROWS_PER_WARP = warpSize / LANES_PER_ROW; +static constexpr int64_t SKIP_WRITE_STATE = -1; +} // namespace simple_horiz + +// Pad DSTATE to next multiple of 32 banks (128 bytes) to avoid bank conflicts. +template +constexpr int padDstate(int dstate) { + constexpr int alignment = 128; // 32 banks * 4 bytes/bank + int row_bytes = dstate * (int)sizeof(T); + int padded_bytes = (row_bytes + alignment - 1) / alignment * alignment; + return padded_bytes / (int)sizeof(T); +} - int init_token_idx = 0; - if (num_accepted_tokens) { - int num_accepted = __ldg(&num_accepted_tokens[seq_idx]); - init_token_idx = max(num_accepted - 1, 0); - } +// ============================================================================= +// Shared memory layout for simple kernel. +// Includes state_in buffer for cp.async prefetch from global memory. +// ============================================================================= + +template +struct SimpleStorage { + alignas(128) input_t B[NTOKENS][DSTATE_PAD]; + alignas(128) input_t C[NTOKENS][DSTATE_PAD]; + alignas(128) input_t x[NTOKENS][DIM_PER_CTA]; + float dt[NTOKENS]; + float out[NTOKENS][DIM_PER_CTA]; + // Precomputed per-step destination batch indices for state writes. + // -1 means "skip this step". + int64_t state_dst_slots[NTOKENS]; + // State prefetch buffer: cp.async loads state here before the barrier. + // Single stage for DPC=16 (1 pass), 2 stages for DPC>16 (pipelined). + alignas(128) state_t state_in[STATE_STAGES][ROWS_PER_PASS][DSTATE_PAD]; +}; - // State scale pointer (only used when scaleState == true) - [[maybe_unused]] auto* __restrict__ state_scale = - reinterpret_cast(params.state_scale); +// ============================================================================= +// cp.async helpers: 8-byte and 16-byte async copy from global to shared memory. +// ============================================================================= - // Load device-side Philox seed once into a register - [[maybe_unused]] int64_t const rand_seed = params.rand_seed ? *params.rand_seed : 0; +__device__ __forceinline__ void cp_async_16B(void* __restrict__ smem_dst, + void const* __restrict__ gmem_src) { + unsigned int smem_addr = __cvta_generic_to_shared(smem_dst); + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(smem_addr), "l"(gmem_src) + : "memory"); +} - int64_t state_batch; - if (state_batch_indices) { - state_batch = static_cast( - state_batch_indices[seq_idx * params.state_batch_indices_stride_batch + - init_token_idx * params.state_batch_indices_stride_T]); - } else { - state_batch = static_cast(seq_idx); - } - auto const intermediate_cache_idx = - intermediate_state_indices ? intermediate_state_indices[seq_idx] : state_batch; - auto const state_ptr_offset = state_batch * params.state_stride_batch + head * DIM * DSTATE; - state += state_ptr_offset; - if constexpr (scaleState) { - state_scale += state_batch * params.state_scale_stride_batch + head * DIM; +// ============================================================================= +// Cooperative state cp.async: all threads load state via flat 16-byte chunks. +// ============================================================================= + +template +__device__ __forceinline__ void cp_async_state_cooperative(SramT& sram, int lane, int warp, + int state_stage, int dim_offset, + state_t const* __restrict__ state_ptr, + int64_t state_base) { + constexpr int STATE_PACK = 16 / sizeof(state_t); + constexpr int state_packs_per_row = DSTATE / STATE_PACK; + constexpr int num_state_chunks = ROWS_PER_PASS * DSTATE / STATE_PACK; + int const flat_tid = warp * warpSize + lane; + constexpr int num_threads = NUM_WARPS * warpSize; + + for (int i = flat_tid; i < num_state_chunks; i += num_threads) { + int const row = i / state_packs_per_row; + int const col = (i % state_packs_per_row) * STATE_PACK; + int const dd = dim_offset + row; + cp_async_16B(&sram.state_in[state_stage][row][col], &state_ptr[state_base + dd * DSTATE + col]); } +} - int64_t const x_base = has_cu_seqlens ? (int64_t)bos * params.x_stride_batch - : (int64_t)seq_idx * params.x_stride_batch; - int64_t const x_tstride = has_cu_seqlens ? params.x_stride_batch : params.x_stride_mtp; - - int64_t const dt_base = has_cu_seqlens ? (int64_t)bos * params.dt_stride_batch - : (int64_t)seq_idx * params.dt_stride_batch; - int64_t const dt_tstride = has_cu_seqlens ? params.dt_stride_batch : params.dt_stride_mtp; - +// ============================================================================= +// Cooperative async load: all warps cp.async B, C, x, state into smem. +// dt is loaded via LDG (needs softplus computation). +// ============================================================================= + +template +__device__ __forceinline__ void load_simple(SramT& sram, int lane, int warp, + SelectiveStateMTPParams const& params, int seq_idx, + int head, int kv_group, int dim_offset, int bos, + int seq_len, int64_t state_batch, int state_stage) { + int const flat_tid = warp * warpSize + lane; + constexpr auto num_threads = NUM_WARPS * warpSize; + + auto const* __restrict__ dt_ptr = reinterpret_cast(params.dt); + auto const* __restrict__ dt_bias_ptr = reinterpret_cast(params.dt_bias); + + // Varlen: tokens are at (bos + step) with stride_batch; non-varlen: seq_idx with stride_mtp + bool const has_cu_seqlens = (params.cu_seqlens != nullptr); int64_t const B_base = has_cu_seqlens ? (int64_t)bos * params.B_stride_batch : (int64_t)seq_idx * params.B_stride_batch; int64_t const B_tstride = has_cu_seqlens ? params.B_stride_batch : params.B_stride_mtp; - int64_t const C_base = has_cu_seqlens ? (int64_t)bos * params.C_stride_batch : (int64_t)seq_idx * params.C_stride_batch; int64_t const C_tstride = has_cu_seqlens ? params.C_stride_batch : params.C_stride_mtp; + int64_t const x_base = has_cu_seqlens ? (int64_t)bos * params.x_stride_batch + : (int64_t)seq_idx * params.x_stride_batch; + int64_t const x_tstride = has_cu_seqlens ? params.x_stride_batch : params.x_stride_mtp; + int64_t const dt_base = has_cu_seqlens ? (int64_t)bos * params.dt_stride_batch + : (int64_t)seq_idx * params.dt_stride_batch; + int64_t const dt_tstride = has_cu_seqlens ? params.dt_stride_batch : params.dt_stride_mtp; - int64_t const out_base = has_cu_seqlens ? (int64_t)bos * params.out_stride_batch - : (int64_t)seq_idx * params.out_stride_batch; - int64_t const out_tstride = has_cu_seqlens ? params.out_stride_batch : params.out_stride_mtp; - - int64_t const z_base = z ? (has_cu_seqlens ? (int64_t)bos * params.z_stride_batch - : (int64_t)seq_idx * params.z_stride_batch) - : 0; - int64_t const z_tstride = z ? (has_cu_seqlens ? params.z_stride_batch : params.z_stride_mtp) : 0; - - constexpr auto stateRowsPerWarpPerStage = 4; - constexpr auto stageRows = stateRowsPerWarpPerStage * numWarps; - - extern __shared__ __align__(128) char smem[]; - auto& sram = *reinterpret_cast*>(smem); - - static constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); - using load_state_t = PackedAligned; - using load_input_t = PackedAligned; - using load_weight_t = PackedAligned; - - auto const A_value = toFloat(A[head]); - auto const d_value = D ? toFloat(D[head]) : 0.f; - auto const dt_bias_value = dt_bias ? toFloat(dt_bias[head]) : 0.f; - - // Loop over multiple tokens - if (warp == 0) { // Load x: gmem -> smem - for (int mtp_step = 0; mtp_step < TOKENS_MTP; mtp_step++) { - for (auto d = lane * load_input_t::count; d < ROWS_PER_BLOCK; - d += warpSize * load_input_t::count) { - if (dim_offset + d < DIM) { - auto* dst = reinterpret_cast(&sram.x[mtp_step][d]); - if (mtp_step < seq_len) { - *dst = *reinterpret_cast( - &x[x_base + mtp_step * x_tstride + head * DIM + dim_offset + d]); - } else { - *dst = make_zeros(); - } + { + // ── Per-warp cp.async loads to avoid cross-array bank conflicts ── + // B/C/x/dt are always loaded (even for pad slots — output must be valid). + // State is only loaded for non-pad slots; pad slots use zero state. + constexpr int INPUT_PACK = 16 / sizeof(input_t); // 8 for bf16 + constexpr int STATE_PACK = 16 / sizeof(state_t); // 4 for f32, 8 for f16/bf16 + static_assert(DSTATE % INPUT_PACK == 0, "DSTATE must be divisible by input pack size"); + static_assert(DSTATE % STATE_PACK == 0, "DSTATE must be divisible by state pack size"); + static_assert(DIM_PER_CTA % INPUT_PACK == 0, + "DIM_PER_CTA must be divisible by input pack size"); + + constexpr int B_packs_per_row = DSTATE / INPUT_PACK; + constexpr int C_packs_per_row = DSTATE / INPUT_PACK; + constexpr int x_packs_per_row = DIM_PER_CTA / INPUT_PACK; + constexpr int state_packs_per_row = DSTATE / STATE_PACK; + + // Warp 0: load B[step][DSTATE] → sram.B[step][DSTATE_PAD] + if (warp == 0) { + auto const* __restrict__ B_ptr = reinterpret_cast(params.B); + constexpr int num_B_chunks = NTOKENS * DSTATE / INPUT_PACK; + for (int i = lane; i < num_B_chunks; i += warpSize) { + int const step = i / B_packs_per_row; + int const col = (i % B_packs_per_row) * INPUT_PACK; + if (step < seq_len) { + cp_async_16B(&sram.B[step][col], + &B_ptr[B_base + step * B_tstride + kv_group * DSTATE + col]); } } } - } else if (warp == 1) { // Load B: gmem -> smem - for (int mtp_step = 0; mtp_step < TOKENS_MTP; mtp_step++) { - for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) { - auto* dst = reinterpret_cast(&sram.B[mtp_step][i]); - if (mtp_step < seq_len) { - *dst = *reinterpret_cast( - &B[B_base + mtp_step * B_tstride + group * DSTATE + i]); - } else { - *dst = make_zeros(); + + // Warp 1: load C[step][DSTATE] → sram.C[step][DSTATE_PAD] + if (warp == 1) { + auto const* __restrict__ C_ptr = reinterpret_cast(params.C); + constexpr int num_C_chunks = NTOKENS * DSTATE / INPUT_PACK; + for (int i = lane; i < num_C_chunks; i += warpSize) { + int const step = i / C_packs_per_row; + int const col = (i % C_packs_per_row) * INPUT_PACK; + if (step < seq_len) { + cp_async_16B(&sram.C[step][col], + &C_ptr[C_base + step * C_tstride + kv_group * DSTATE + col]); } } } - } else if (warp == 2) { // Load z: gmem -> smem - for (int mtp_step = 0; mtp_step < TOKENS_MTP; mtp_step++) { - for (auto d = lane * load_input_t::count; d < ROWS_PER_BLOCK; - d += warpSize * load_input_t::count) { - if (dim_offset + d < DIM) { - auto* dst = reinterpret_cast(&sram.z[mtp_step][d]); - if (z && mtp_step < seq_len) { - *dst = *reinterpret_cast( - &z[z_base + mtp_step * z_tstride + head * DIM + dim_offset + d]); - } else { - *dst = make_zeros(); - } + + // All warps load x: each warp handles different tokens to avoid bank conflicts. + // With DPC=16 bf16, each row is 32 bytes (8 banks). One row per warp per iteration + // means zero bank aliasing within any warp-wide cp.async instruction. + { + auto const* __restrict__ x_ptr = reinterpret_cast(params.x); + for (int step = warp; step < seq_len; step += NUM_WARPS) { + for (int col = lane * INPUT_PACK; col < DIM_PER_CTA; col += warpSize * INPUT_PACK) { + cp_async_16B(&sram.x[step][col], + &x_ptr[x_base + step * x_tstride + head * DIM + dim_offset + col]); } } } - } - // Load C: gmem -> smem - else if (warp == 3) { - for (int mtp_step = 0; mtp_step < TOKENS_MTP; mtp_step++) { - for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) { - auto* dst = reinterpret_cast(&sram.C[mtp_step][i]); - if (mtp_step < seq_len) { - *dst = *reinterpret_cast( - &C[C_base + mtp_step * C_tstride + group * DSTATE + i]); - } else { - *dst = make_zeros(); - } - } + + // All 4 warps: tiled state load (bank-conflict-free, 8-byte cp.async) + // Only load state for non-pad slots; pad slots leave state uninitialized + // (zeroed in registers by update_state_simple). + if constexpr (!IS_PAD) { + auto const* __restrict__ state_ptr = reinterpret_cast(params.state); + auto const state_base = state_batch * params.state_stride_batch + head * DIM * DSTATE; + cp_async_state_cooperative( + sram, lane, warp, state_stage, dim_offset, state_ptr, state_base); + } + + // Load dt[step] via LDG (needs softplus computation, can't use cp.async) + if (flat_tid < seq_len) { + int const step = flat_tid; + float dt_bias_val = dt_bias_ptr ? toFloat(dt_bias_ptr[head]) : 0.f; + float dt_val = toFloat(dt_ptr[dt_base + step * dt_tstride + head]); + dt_val += dt_bias_val; + if (params.dt_softplus) dt_val = thresholded_softplus(dt_val); + sram.dt[step] = dt_val; } } - float rdt[TOKENS_MTP]; - for (int step = 0; step < TOKENS_MTP; step++) { - if (step < seq_len) { - auto dt_value = dt_bias_value + toFloat(dt[dt_base + step * dt_tstride + head]); - if (dt_softplus) { - dt_value = thresholded_softplus(dt_value); - } - rdt[step] = dt_value; + // Precompute per-step destination slot indices for state writes. + // Three mutually exclusive modes: + // 1. dst_state_batch_indices → varlen: prefetch per-step indices (or -1 for pad_slot_id) + // 2. intermediate_states → MTP cache: consecutive slots within icache entry + // 3. neither → only write final state at last step + if (flat_tid < NTOKENS) { + int const step = flat_tid; + constexpr int64_t SKIP = simple_horiz::SKIP_WRITE_STATE; + auto const* __restrict__ dst_state_batch_indices = + reinterpret_cast(params.dst_state_batch_indices); + if (IS_PAD || step >= seq_len) { + sram.state_dst_slots[step] = SKIP; + } else if (dst_state_batch_indices) { + // Varlen: read per-step destination index, mark pad slots as SKIP + auto const dst_idx = static_cast( + dst_state_batch_indices[seq_idx * params.dst_state_batch_indices_stride_batch + + step * params.dst_state_batch_indices_stride_T]); + sram.state_dst_slots[step] = (dst_idx == params.pad_slot_id) ? SKIP : dst_idx; + } else if (params.intermediate_states) { + // MTP cache: consecutive step slots within the icache entry + auto const* __restrict__ intermediate_state_indices = + reinterpret_cast(params.intermediate_state_indices); + auto const icache_idx = intermediate_state_indices + ? static_cast(intermediate_state_indices[seq_idx]) + : state_batch; + sram.state_dst_slots[step] = icache_idx * params.cache_steps + step; } else { - rdt[step] = 0.f; + // Final-state only: write at last step + sram.state_dst_slots[step] = + (step == seq_len - 1 && params.update_state) ? state_batch : SKIP; } } - __syncthreads(); + // Commit all cp.async and wait for completion + asm volatile("cp.async.commit_group;\n" ::: "memory"); + asm volatile("cp.async.wait_group 0;\n" ::: "memory"); +} - bool const has_dst_indices = (dst_state_batch_indices != nullptr); - bool const has_intermediate = (intermediate_states != nullptr); - - for (auto dBegin = 0; dBegin < ROWS_PER_BLOCK; dBegin += stageRows) { - // Load state gmem -> smem - for (int warpRow = 0; warpRow < stateRowsPerWarpPerStage; warpRow++) { - auto dd = warp * stateRowsPerWarpPerStage + warpRow; - auto d = dBegin + dd; - if (dim_offset + d < DIM) { - if (state_batch != params.pad_slot_id) { - for (int i = lane * load_state_t::count; i < DSTATE; - i += warpSize * load_state_t::count) { - auto* dst = reinterpret_cast(&sram.state[dd][i]); - *dst = *reinterpret_cast(&state[(dim_offset + d) * DSTATE + i]); - } - } - } - } - // Load state_scale gmem -> smem (contiguous across warpRows) - if constexpr (scaleState) { - for (int warpRow = lane; warpRow < stateRowsPerWarpPerStage; warpRow += warpSize) { - auto dd = warp * stateRowsPerWarpPerStage + warpRow; - auto d = dBegin + dd; - if (dim_offset + d < DIM && state_batch != params.pad_slot_id) { - sram.state_scale[dd] = state_scale[dim_offset + d]; - } - } - } +// ============================================================================= +// State update: DSTATE traversal with state in registers. +// ============================================================================= + +template +__device__ __forceinline__ void update_state_simple(SramT& sram, int lane, int warp, + SelectiveStateMTPParams const& params, + int seq_idx, int head, int dim_offset, + int64_t state_batch, int bos, int seq_len, + float A_val, float D_val) { + constexpr bool scaleState = !std::is_same_v; + constexpr int lanesPerRow = simple_horiz::LANES_PER_ROW; + constexpr int rowsPerWarp = simple_horiz::ROWS_PER_WARP; + constexpr int ROWS_PER_PASS = NUM_WARPS * rowsPerWarp; + constexpr int DSTATE_PADDED = nextPow2(DSTATE); + constexpr int stateValuesPerThread = DSTATE_PADDED / lanesPerRow; + + constexpr int bankSize = sizeof(uint32_t); + constexpr int stateValuesPerBank = bankSize / sizeof(state_t); + constexpr int numBanks = 32; + constexpr int sramReadsPerThreadPerTile = numBanks / lanesPerRow; + constexpr int elemsPerTileMember = sramReadsPerThreadPerTile * stateValuesPerBank; + constexpr int elemsPerTile = elemsPerTileMember * lanesPerRow; + constexpr int numTiles = stateValuesPerThread / elemsPerTileMember; + using packed_tile_t = PackedAligned; + + static_assert(DSTATE % lanesPerRow == 0, "DSTATE must be divisible by lanesPerRow"); + static_assert(DIM_PER_CTA % ROWS_PER_PASS == 0, "DIM_PER_CTA must be divisible by ROWS_PER_PASS"); + static_assert(elemsPerTileMember % 2 == 0, "elemsPerTileMember must be even for f32x2"); + constexpr int pairsPerTileMember = elemsPerTileMember / 2; + + int const member = lane % lanesPerRow; + int const group = lane / lanesPerRow; - // Compute how many input_t elements to pack per SRAM load based on DSTATE/warpSize ratio - constexpr auto stateValuesPerThread = DSTATE / warpSize; - // We will be loading two-banks worth of input_t at a time instead of 1 in order to reduce the - // load on LSU. - constexpr auto maxPackedElements = sizeof(uint64_t) / sizeof(input_t); - constexpr auto packedSramLdInputElements = - (stateValuesPerThread >= maxPackedElements) ? maxPackedElements : stateValuesPerThread; - static_assert(stateValuesPerThread % packedSramLdInputElements == 0, - "stateValuesPerThread must be divisible by packedSramLdInputElements"); - using packed_input_t = PackedAligned; - float rState[stateValuesPerThread]; - packed_input_t rB; - packed_input_t rC; - - for (int warpRow = 0; warpRow < stateRowsPerWarpPerStage; warpRow++) { - auto dd = warp * stateRowsPerWarpPerStage + warpRow; - auto d = dim_offset + dBegin + dd; // global DIM index - - if (d >= DIM) break; - - // Load state smem -> rmem - // There is a bank conflict here, but we are not in a hot loop and we must align the state - // indices with the input indices - float state_decode_scale = 1.f; - if constexpr (scaleState) { - if (state_batch != params.pad_slot_id) state_decode_scale = toFloat(sram.state_scale[dd]); - } - for (int ii = 0; ii < stateValuesPerThread; ii++) { - int i = lane * packed_input_t::count + - (ii / packed_input_t::count) * warpSize * packed_input_t::count + - (ii % packed_input_t::count); - rState[ii] = (state_batch != params.pad_slot_id && i < DSTATE) - ? toFloat(sram.state[dd][i]) * state_decode_scale - : 0.f; - } + [[maybe_unused]] int64_t const rand_seed = params.rand_seed ? *params.rand_seed : 0; - for (int step = 0; step < TOKENS_MTP; step++) { - if (step >= seq_len) break; + auto* __restrict__ state_ptr = reinterpret_cast(params.state); + + // Unified state write path: pick destination pointer and stride based on mode. + // The per-step slot indices are precomputed in sram.state_dst_slots[]. + // For istate: dst_slot = icache_idx * cache_steps + step, so the flat slot + // index works with both state stride (nheads*DIM*DSTATE) and scale stride (nheads*DIM). + auto* __restrict__ write_state_ptr = state_ptr; + int64_t write_state_stride = params.state_stride_batch; + [[maybe_unused]] auto* __restrict__ write_scale_ptr = + reinterpret_cast(params.state_scale); + [[maybe_unused]] int64_t write_scale_stride = params.state_scale_stride_batch; + if (params.intermediate_states) { + write_state_ptr = reinterpret_cast(params.intermediate_states); + write_state_stride = (int64_t)params.nheads * DIM * DSTATE; + write_scale_ptr = reinterpret_cast(params.intermediate_state_scales); + write_scale_stride = (int64_t)params.nheads * DIM; + } + + // Logical column helpers + auto baseCol = [&](int t, int e) -> int { + return t * elemsPerTile + member * elemsPerTileMember + e; + }; - float x_value = toFloat(sram.x[step][d - dim_offset]); - float out_value = d_value * x_value * int(lane == 0); // first lane has the value + // Output pointers (for epilogue) + auto* __restrict__ output = reinterpret_cast(params.output); + auto const* __restrict__ z_ptr = reinterpret_cast(params.z); + // Guard: outputLoadSize is only meaningful when DIM_PER_CTA >= warpSize + constexpr auto outputLoadSize = + DIM_PER_CTA >= warpSize ? getVectorLoadSizeForFullUtilization() : 1; + using load_output_t = PackedAligned; + + // State scale pointer (only used when scaleState == true) + [[maybe_unused]] auto* __restrict__ state_scale_ptr = + reinterpret_cast(params.state_scale); - // Compute dt value for this token - auto dt_value = rdt[step]; - auto const dA = __expf(A_value * dt_value); + auto const state_ptr_offset = state_batch * params.state_stride_batch + head * DIM * DSTATE; - // Process state in groups of packed_input_t::count to match B/C bank-aligned loads - for (int ii = 0; ii < stateValuesPerThread; ii += packed_input_t::count) { - int base_i = lane * packed_input_t::count + - (ii / packed_input_t::count) * warpSize * packed_input_t::count; + constexpr int numPasses = DIM_PER_CTA / ROWS_PER_PASS; - // Bank-aligned load for B and C - rB = *reinterpret_cast(&sram.B[step][base_i]); - rC = *reinterpret_cast(&sram.C[step][base_i]); + constexpr int STATE_STAGES = (numPasses == 1) ? 1 : 2; -#pragma unroll - for (int k = 0; k < packed_input_t::count; k++) { - auto& state_value = rState[ii + k]; - auto B_value = toFloat(rB.val[k]); - auto C_value = toFloat(rC.val[k]); + for (int pass = 0; pass < numPasses; pass++) { + int const pass_row = warp * rowsPerWarp + group; // row within current pass [0, ROWS_PER_PASS) + int const local_row = pass * ROWS_PER_PASS + pass_row; + int const dd = dim_offset + local_row; // global DIM index + int const stage = pass % STATE_STAGES; - auto const dB = B_value * dt_value; - auto const new_state = state_value * dA + dB * x_value; - state_value = new_state; + // Load state from smem (prefetched via cp.async) into registers + float2 rState[numTiles][pairsPerTileMember]; - out_value += new_state * C_value; - } + // Load decode scale for this DIM row (scaleState only) + [[maybe_unused]] float state_decode_scale = 1.f; + if constexpr (scaleState) { + if constexpr (!IS_PAD) { + state_decode_scale = toFloat( + state_scale_ptr[state_batch * params.state_scale_stride_batch + head * DIM + dd]); + } + } - // Store intermediate state to smem (non-scaleState path) - if constexpr (!scaleState) { - if constexpr (sizeof(state_t) == sizeof(input_t)) { - if (has_intermediate || has_dst_indices) { - using packed_state_t = PackedAligned; - packed_state_t rStateOut; - // Philox-4x32 produces 4 random ints per call; amortize across packed elements. - [[maybe_unused]] uint32_t rand_ints[4]; #pragma unroll - for (int k = 0; k < packed_input_t::count; k++) { - if constexpr (PHILOX_ROUNDS > 0) { - // SR only applies to fp16 state, so packed count is always >= 2. - static_assert(packed_input_t::count >= 2, - "Stochastic rounding requires fp16 state (packed count >= 2)"); - if (k % 4 == 0) - philox_randint4x( - rand_seed, state_ptr_offset + d * DSTATE + base_i + k, rand_ints[0], - rand_ints[1], rand_ints[2], rand_ints[3]); - rStateOut.val[k] = cvt_rs_f16_f32(rState[ii + k], rand_ints[k % 4] & 0x1FFFu); - } else { - convertAndStore(&rStateOut.val[k], rState[ii + k]); - } - } - *reinterpret_cast(&sram.state[dd][base_i]) = rStateOut; - } - } else { - if (has_intermediate || has_dst_indices) { - // Philox-4x32 produces 4 random ints per call; amortize across packed elements. - [[maybe_unused]] uint32_t rand_ints[4]; + for (int t = 0; t < numTiles; t++) { #pragma unroll - for (int k = 0; k < packed_input_t::count; k++) { - if constexpr (PHILOX_ROUNDS > 0) { - if (k % 4 == 0) - philox_randint4x( - rand_seed, state_ptr_offset + d * DSTATE + base_i + k, rand_ints[0], - rand_ints[1], rand_ints[2], rand_ints[3]); - sram.state[dd][base_i + k] = - cvt_rs_f16_f32(rState[ii + k], rand_ints[k % 4] & 0x1FFFu); - } else { - convertAndStore(&sram.state[dd][base_i + k], rState[ii + k]); - } - } - } - } + for (int p = 0; p < pairsPerTileMember; p++) { + int const c0 = baseCol(t, p * 2); + if (c0 >= DSTATE || IS_PAD) { + rState[t][p] = {0.f, 0.f}; + } else { + rState[t][p] = toFloat2(&sram.state_in[stage][pass_row][c0]); + if constexpr (scaleState) { + float2 const decode_scale2 = {state_decode_scale, state_decode_scale}; + mul_f32x2(rState[t][p], rState[t][p], decode_scale2); } } + } + } - // For scaleState + per-step writes: quantize rState → sram.state with block scaling - if constexpr (scaleState) { - if ((has_intermediate || has_dst_indices) && state_batch != params.pad_slot_id) { - // 2-pass: compute max, then encode - float istate_max = std::numeric_limits::lowest(); - for (int ii = 0; ii < stateValuesPerThread; ii++) { - istate_max = fmaxf(istate_max, fabsf(rState[ii])); - } - istate_max = warpReduceMax(istate_max); - istate_max = __shfl_sync(UINT32_MAX, istate_max, 0); - float const ie_scale = - (istate_max == 0.f) - ? 1.f - : static_cast(std::numeric_limits::max()) / istate_max; - float const id_scale = 1.f / ie_scale; - - // Encode rState → sram.state - for (int ii = 0; ii < stateValuesPerThread; ii++) { - int i = lane * packed_input_t::count + - (ii / packed_input_t::count) * warpSize * packed_input_t::count + - (ii % packed_input_t::count); - if (i < DSTATE) { - convertAndStore(&sram.state[dd][i], rState[ii] * ie_scale); - } - } - // Store decode scale to smem for later gmem write - if (lane == 0) sram.state_scale[dd] = id_scale; - } - } + // Strength-reduce step-dependent shared memory indexing + auto const* __restrict__ B_step = &sram.B[0][0]; + auto const* __restrict__ C_step = &sram.C[0][0]; + auto const* __restrict__ x_step = &sram.x[0][0]; + float const* __restrict__ dt_step = &sram.dt[0]; + float* __restrict__ out_step = &sram.out[0][0]; + + for (int step = 0; step < NTOKENS; step++) { + if (step >= seq_len) break; + // Prefetch dst_slot early so the LDS latency is hidden by the compute below + int64_t const dst_slot = sram.state_dst_slots[step]; + float const dt_value = *dt_step; + float const dA = __expf(A_val * dt_value); + float const x_value = toFloat(x_step[local_row]); + + // f32x2 packed recurrence + float2 out2 = {0.f, 0.f}; + float2 const dA2 = {dA, dA}; + float const dtx_value = dt_value * x_value; + float2 const dtx2 = {dtx_value, dtx_value}; - out_value = warpReduceSum(out_value); - if (lane == 0) { - sram.out[step][d - dim_offset] = out_value; +#pragma unroll + for (int t = 0; t < numTiles; t++) { +#pragma unroll + for (int p = 0; p < pairsPerTileMember; p++) { + int const c0 = baseCol(t, p * 2); + if (c0 >= DSTATE) continue; + float2 const B2 = toFloat2(&B_step[c0]); + float2 const C2 = toFloat2(&C_step[c0]); + float2 dBx; + mul_f32x2(dBx, B2, dtx2); + fma_f32x2(rState[t][p], dA2, rState[t][p], dBx); + fma_f32x2(out2, rState[t][p], C2, out2); } + } + float out_value = out2.x + out2.y; - if (state_batch != params.pad_slot_id) { - if (has_dst_indices) { - auto dst_idx = static_cast( - dst_state_batch_indices[seq_idx * params.dst_state_batch_indices_stride_batch + - step * params.dst_state_batch_indices_stride_T]); - if (dst_idx != params.pad_slot_id) { - auto* dst_state_ptr = reinterpret_cast(params.state); - for (int i = lane * load_state_t::count; i < DSTATE; - i += warpSize * load_state_t::count) { - auto* src = reinterpret_cast(&sram.state[dd][i]); - *reinterpret_cast( - &dst_state_ptr[dst_idx * params.state_stride_batch + head * DIM * DSTATE + - d * DSTATE + i]) = *src; - } - if constexpr (scaleState) { - if (lane == 0) { - auto* dst_scale = reinterpret_cast(params.state_scale); - dst_scale[dst_idx * params.state_scale_stride_batch + head * DIM + d] = - sram.state_scale[dd]; - } - } - } - } else if (has_intermediate) { - // Write intermediate state smem → gmem - for (int i = lane * load_state_t::count; i < DSTATE; - i += warpSize * load_state_t::count) { - auto* src = reinterpret_cast(&sram.state[dd][i]); - auto* dst = reinterpret_cast( - &intermediate_states[intermediate_cache_idx * - params.intermediate_state_stride_batch + - step * nheads * DIM * DSTATE + head * DIM * DSTATE + - d * DSTATE + i]); - *dst = *src; - } - // Write intermediate state decode scale → gmem - if constexpr (scaleState) { - if (lane == 0) { - auto* iscales = reinterpret_cast(params.intermediate_state_scales); - iscales[intermediate_cache_idx * params.intermediate_state_scales_stride_batch + - step * nheads * DIM + head * DIM + d] = sram.state_scale[dd]; - } - } - } - } + // Reduce across lanesPerRow adjacent lanes +#pragma unroll + for (int offset = lanesPerRow / 2; offset >= 1; offset /= 2) { + out_value += __shfl_down_sync(UINT32_MAX, out_value, offset); + } + + if (member == 0) { + out_step[local_row] = out_value + D_val * x_value; } - // Update state if enabled and not padded - if (params.update_state && state_batch != params.pad_slot_id && !has_dst_indices) { - // When intermediate_states is enabled, sram.state[dd] already holds the - // stochastically-rounded (or scaled) state from the last token step's intermediate write. - // Skip the redundant Philox PRNG / re-quantization and write directly to gmem. - if (!has_intermediate) { + // Advance step pointers + B_step += DSTATE_PAD; + C_step += DSTATE_PAD; + x_step += DIM_PER_CTA; + dt_step += 1; + out_step += DIM_PER_CTA; + + // Unified state write: use precomputed slot index from sram + { + if (dst_slot != simple_horiz::SKIP_WRITE_STATE) { + [[maybe_unused]] float encode_scale = 1.f; if constexpr (scaleState) { - // 2-pass quantization: compute max, then re-encode - float new_state_max = std::numeric_limits::lowest(); - for (int ii = 0; ii < stateValuesPerThread; ii++) { - new_state_max = fmaxf(new_state_max, fabsf(rState[ii])); - } - new_state_max = warpReduceMax(new_state_max); - new_state_max = __shfl_sync(UINT32_MAX, new_state_max, 0); - float const new_encode_scale = - (new_state_max == 0.f) - ? 1.f - : static_cast(std::numeric_limits::max()) / new_state_max; - float const new_decode_scale = 1.f / new_encode_scale; - - // Re-encode state values and store to smem - for (int ii = 0; ii < stateValuesPerThread; ii++) { - int i = lane * packed_input_t::count + - (ii / packed_input_t::count) * warpSize * packed_input_t::count + - (ii % packed_input_t::count); - if (i < DSTATE) { - convertAndStore(&sram.state[dd][i], rState[ii] * new_encode_scale); - } - } - if (lane == 0) convertAndStore(&sram.state_scale[dd], new_decode_scale); - } else { - // Store to rmem -> smem - // Philox-4x32 produces 4 random ints per call; amortize across consecutive elements. + encode_scale = + computeBlockScaleEncode(rState, lane, + member); + } + auto const dst_base = + dst_slot * write_state_stride + (int64_t)head * DIM * DSTATE + (int64_t)dd * DSTATE; +#pragma unroll + for (int t = 0; t < numTiles; t++) { + int const col0 = baseCol(t, 0); + if (col0 >= DSTATE) continue; + packed_tile_t rOut; [[maybe_unused]] uint32_t rand_ints[4]; - for (int ii = 0; ii < stateValuesPerThread; ii++) { - int i = lane * packed_input_t::count + - (ii / packed_input_t::count) * warpSize * packed_input_t::count + - (ii % packed_input_t::count); - if (i < DSTATE) { - if constexpr (PHILOX_ROUNDS > 0) { - if (ii % 4 == 0) - philox_randint4x(rand_seed, state_ptr_offset + d * DSTATE + i, - rand_ints[0], rand_ints[1], rand_ints[2], - rand_ints[3]); - sram.state[dd][i] = cvt_rs_f16_f32(rState[ii], rand_ints[ii % 4] & 0x1FFFu); - } else { - convertAndStore(&sram.state[dd][i], rState[ii]); - } +#pragma unroll + for (int e = 0; e < elemsPerTileMember; e += 2) { + float2 s2 = rState[t][e / 2]; + if constexpr (scaleState) { + float2 const scale2 = {encode_scale, encode_scale}; + mul_f32x2(s2, s2, scale2); } + convertAndStoreSRHorizontal( + rOut.val[e], rOut.val[e + 1], s2.x, s2.y, rand_seed, state_ptr_offset, dd, col0, + e, rand_ints); + } + *reinterpret_cast(&write_state_ptr[dst_base + col0]) = rOut; + } + // Write decode scale + if constexpr (scaleState) { + if (member == 0) { + write_scale_ptr[dst_slot * write_scale_stride + head * DIM + dd] = 1.f / encode_scale; } } - } - // store smem -> gmem - for (int i = lane * load_state_t::count; i < DSTATE; i += warpSize * load_state_t::count) { - auto* src = reinterpret_cast(&sram.state[dd][i]); - *reinterpret_cast(&state[d * DSTATE + i]) = *src; } } - } - // Store state_scale smem -> gmem (contiguous across warpRows) - if constexpr (scaleState) { - if (params.update_state && state_batch != params.pad_slot_id && !has_dst_indices) { - for (int warpRow = lane; warpRow < stateRowsPerWarpPerStage; warpRow += warpSize) { - auto dd = warp * stateRowsPerWarpPerStage + warpRow; - auto d = dim_offset + dBegin + dd; - if (d < DIM) { - state_scale[d] = sram.state_scale[dd]; - } + } // step loop + + // Multi-pass pipeline: prefetch next pass's state into the other smem stage + if constexpr (numPasses > 1) { + if (pass < numPasses - 1) { + int const next_stage = (pass + 1) % STATE_STAGES; + int const next_dim_base = dim_offset + (pass + 1) * ROWS_PER_PASS; + + if constexpr (!IS_PAD) { + auto const* __restrict__ state_ptr_r = reinterpret_cast(params.state); + auto const state_base = state_batch * params.state_stride_batch + head * DIM * DSTATE; + cp_async_state_cooperative( + sram, lane, warp, next_stage, next_dim_base, state_ptr_r, state_base); } + asm volatile("cp.async.commit_group;\n" ::: "memory"); + asm volatile("cp.async.wait_group 0;\n" ::: "memory"); + __syncthreads(); } } - } + } // pass loop + // ── Epilogue: sync all warps, z-gate + vectorized store ─── __syncthreads(); - for (auto step = warp; step < TOKENS_MTP; step += numWarps) { - if (step >= seq_len) continue; - for (auto d = lane; d < ROWS_PER_BLOCK; d += warpSize) { - if (dim_offset + d < DIM) { - auto out_value = sram.out[step][d]; - if (z) { - float z_value = toFloat(sram.z[step][d]); + // Varlen output indexing: (bos + step) * stride_batch; non-varlen: seq_idx * stride_batch + step + // * stride_mtp + bool const has_cu_seqlens = (params.cu_seqlens != nullptr); + auto out_addr = [&](int step) -> int64_t { + return has_cu_seqlens + ? (int64_t)(bos + step) * params.out_stride_batch + head * DIM + dim_offset + : (int64_t)seq_idx * params.out_stride_batch + step * params.out_stride_mtp + + head * DIM + dim_offset; + }; + auto z_addr = [&](int step) -> int64_t { + return has_cu_seqlens ? (int64_t)(bos + step) * params.z_stride_batch + head * DIM + dim_offset + : (int64_t)seq_idx * params.z_stride_batch + step * params.z_stride_mtp + + head * DIM + dim_offset; + }; + + if constexpr (DIM_PER_CTA >= warpSize) { + // Fast path: each lane handles >= 1 element, use vectorized loads/stores + constexpr int elemsPerThreadEpilogue = DIM_PER_CTA / warpSize; + + for (int step = warp; step < seq_len; step += NUM_WARPS) { + int64_t const out_offset = out_addr(step); + int64_t const z_offset = z_addr(step); + + for (int ii = 0; ii < elemsPerThreadEpilogue; ii += load_output_t::count) { + int const d = lane * load_output_t::count + + (ii / load_output_t::count) * warpSize * load_output_t::count; + load_output_t packed_out; + load_output_t packed_z; + if (z_ptr) { + packed_z = *reinterpret_cast(&z_ptr[z_offset + d]); + } +#pragma unroll + for (int k = 0; k < load_output_t::count; k++) { + float out_value = sram.out[step][d + k]; + if (z_ptr) { + float z_value = toFloat(packed_z.val[k]); + float sig_z = __fdividef(1.f, (1.f + __expf(0.f - z_value))); + out_value *= z_value * sig_z; + } + convertAndStore(&packed_out.val[k], out_value); + } + *reinterpret_cast(&output[out_offset + d]) = packed_out; + } + } + } else { + // Narrow path: DIM_PER_CTA < warpSize, only first DIM_PER_CTA lanes participate + for (int step = warp; step < seq_len; step += NUM_WARPS) { + if (lane < DIM_PER_CTA) { + int64_t const out_offset = out_addr(step); + float out_value = sram.out[step][lane]; + if (z_ptr) { + int64_t const z_offset = z_addr(step); + float z_value = toFloat(z_ptr[z_offset + lane]); float sig_z = __fdividef(1.f, (1.f + __expf(0.f - z_value))); - float silu_z = z_value * sig_z; - out_value *= silu_z; + out_value *= z_value * sig_z; } - auto* dst = reinterpret_cast( - &output[out_base + step * out_tstride + head * DIM + dim_offset + d]); - convertAndStore(dst, out_value); + convertAndStore(&output[out_offset + lane], out_value); } } } } -} // namespace flashinfer::mamba::mtp +// ============================================================================= +// Kernel entry point +// Grid: (batch_or_n_sequences, nheads, CTAS_PER_HEAD) +// Block: (32, NUM_WARPS) +// ============================================================================= + +template +__global__ void __launch_bounds__(NUM_WARPS * 32) + selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams params) { + constexpr int DSTATE_PAD = padDstate(DSTATE); + constexpr int DIM_PER_CTA = DIM / CTAS_PER_HEAD; + constexpr int ROWS_PER_PASS = NUM_WARPS * simple_horiz::ROWS_PER_WARP; -#endif // FLASHINFER_MAMBA_KERNEL_SELECTIVE_STATE_UPDATE_MTP_SIMPLE_CUH_ + static_assert(DIM % CTAS_PER_HEAD == 0, "DIM must be divisible by CTAS_PER_HEAD"); + static_assert(DIM_PER_CTA % ROWS_PER_PASS == 0, "DIM_PER_CTA must be divisible by ROWS_PER_PASS"); + + constexpr int numPasses = DIM_PER_CTA / ROWS_PER_PASS; + constexpr int STATE_STAGES = (numPasses == 1) ? 1 : 2; + + extern __shared__ __align__(128) char smem[]; + using sram_t = SimpleStorage; + auto& sram = *reinterpret_cast(smem); + + int const seq_idx = blockIdx.x; + int const head = blockIdx.y; + int const cta_z = blockIdx.z; + int const dim_offset = cta_z * DIM_PER_CTA; + int const lane = threadIdx.x; + int const warp = threadIdx.y; + + // ── Varlen: compute bos, seq_len ── + auto const* __restrict__ cu_seqlens = + reinterpret_cast(params.cu_seqlens); + int bos; + int seq_len; + if (cu_seqlens) { + bos = __ldg(&cu_seqlens[seq_idx]); + int const eos = __ldg(&cu_seqlens[seq_idx + 1]); + seq_len = eos - bos; + if (seq_len <= 0) return; + } else { + bos = 0; + seq_len = NTOKENS; + } + + // ── num_accepted_tokens → init_token_idx ── + auto const* __restrict__ num_accepted_tokens = + reinterpret_cast(params.num_accepted_tokens); + int init_token_idx = 0; + if (num_accepted_tokens) { + int num_accepted = __ldg(&num_accepted_tokens[seq_idx]); + init_token_idx = max(num_accepted - 1, 0); + } + + // ── State batch index: 2D (seq_idx, init_token_idx) or 1D ── + auto const* __restrict__ state_batch_indices = + reinterpret_cast(params.state_batch_indices); + int64_t state_batch; + if (state_batch_indices) { + state_batch = static_cast( + state_batch_indices[seq_idx * params.state_batch_indices_stride_batch + + init_token_idx * params.state_batch_indices_stride_T]); + } else { + state_batch = static_cast(seq_idx); + } + bool const is_pad = (state_batch == (int64_t)params.pad_slot_id); + + int const kv_group = head / HEADS_PER_GROUP; + + // Load A and D before the barrier so global memory latency overlaps with barrier wait + auto const* __restrict__ A_ptr = reinterpret_cast(params.A); + auto const* __restrict__ D_ptr = reinterpret_cast(params.D); + float const A_val = toFloat(A_ptr[head]); + float const D_val = D_ptr ? toFloat(D_ptr[head]) : 0.f; + + auto run = [&]() { + // Phase 1: cooperative cp.async load of B/C/x/dt/state into smem + load_simple( + sram, lane, warp, params, seq_idx, head, kv_group, dim_offset, bos, seq_len, state_batch, + /*state_stage=*/0); + + // Phase 2: single sync — ensures all smem writes (cp.async + LDG dt) are visible + __syncthreads(); + + // Phase 3: compute (state in registers, B/C/x from smem) + update_state_simple( + sram, lane, warp, params, seq_idx, head, dim_offset, state_batch, bos, seq_len, A_val, + D_val); + }; + + if (is_pad) + run.template operator()(); + else + run.template operator()(); +} + +} // namespace flashinfer::mamba::mtp diff --git a/include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh b/include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh index d653b28689..2383b5c774 100644 --- a/include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh +++ b/include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh @@ -19,7 +19,8 @@ // 3 TMA load warps (one per group), and 1 shared epilogue warp. // Each CTA processes up to 3 heads from the flattened head list (across all KV groups). // -// See .plans/three_compute_groups.md for the full design document. + +#pragma once #include #include @@ -99,21 +100,18 @@ __device__ __forceinline__ void role_load(SramT& sram, int lane, CUtensorMap const& tensorC) { namespace cde = cuda::device::experimental; - // ── Load B and C ────────────────────────────────────────────────────── + // ── Load B and C (always, even for pad slots — output must be valid) ── if (lane == 0) { constexpr int bytesBC = 2 * NTOKENS * DSTATE * (int)sizeof(input_t); - if constexpr (!IS_PAD) { - cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.B[0][0], &tensorB, 0, kv_group, 0, batch, - sram.bar_BC_full); - cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.C[0][0], &tensorC, 0, kv_group, 0, batch, - sram.bar_BC_full); - cuda::device::barrier_arrive_tx(sram.bar_BC_full, warpSize, bytesBC); - } else { - cuda::device::barrier_arrive_tx(sram.bar_BC_full, warpSize, 0); - } + cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.B[0][0], &tensorB, 0, kv_group, 0, batch, + sram.bar_BC_full); + cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.C[0][0], &tensorC, 0, kv_group, 0, batch, + sram.bar_BC_full); + cuda::device::barrier_arrive_tx(sram.bar_BC_full, warpSize, bytesBC); } // ── Load state_in + x ──────────────────────────────────────────────── + // x is always loaded; state is only loaded for non-pad slots (pad uses zero state). constexpr int bytesState = DIM * DSTATE * (int)sizeof(state_t); constexpr int bytesX = NTOKENS * DIM * (int)sizeof(input_t); constexpr int in_slot = 0; // single head, single slot @@ -126,15 +124,13 @@ __device__ __forceinline__ void role_load(SramT& sram, int lane, cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state_in[in_slot][0], &tensorState, 0, 0, head, state_batch, sram.bar_state_in_full[in_slot]); + } - cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.x[in_slot][0][0], &tensorX, 0, head, 0, - batch, sram.bar_state_in_full[in_slot]); + cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.x[in_slot][0][0], &tensorX, 0, head, 0, + batch, sram.bar_state_in_full[in_slot]); - auto const _ = - cuda::device::barrier_arrive_tx(sram.bar_state_in_full[in_slot], 1, bytesState + bytesX); - } else { - cuda::device::barrier_arrive_tx(sram.bar_state_in_full[in_slot], 1, 0); - } + int constexpr bytes = IS_PAD ? bytesX : bytesState + bytesX; + auto const _ = cuda::device::barrier_arrive_tx(sram.bar_state_in_full[in_slot], 1, bytes); } } @@ -256,8 +252,12 @@ __device__ __forceinline__ void role_update_state(SramT& sram, int lane, int com for (int wr = 0; wr < rowsPerWarpPerPass; wr++) { int const dd = row_offset + wr; for (int ii = 0; ii < stateValuesPerThread; ii++) { - rState[wr][ii] = - toFloat(sram.state_in[in_slot][dd * DSTATE + lane * stateValuesPerThread + ii]); + if constexpr (IS_PAD) { + rState[wr][ii] = 0.f; + } else { + rState[wr][ii] = + toFloat(sram.state_in[in_slot][dd * DSTATE + lane * stateValuesPerThread + ii]); + } } } diff --git a/include/flashinfer/mamba/ssu_mtp_common.cuh b/include/flashinfer/mamba/ssu_mtp_common.cuh index bb9accba44..b07d3addc4 100644 --- a/include/flashinfer/mamba/ssu_mtp_common.cuh +++ b/include/flashinfer/mamba/ssu_mtp_common.cuh @@ -18,7 +18,10 @@ #pragma once +#include +#include #include +#include #include "conversion.cuh" @@ -141,4 +144,57 @@ __device__ __forceinline__ void convertAndStoreSRHorizontal(state_t& out0, state } } +// ============================================================================= +// computeBlockScaleEncode — compute block-scaling encode_scale for a DIM row. +// +// Finds the max absolute value across all rState pairs held by lanesPerRow +// lanes (sub-warp reduce-max), then returns encode_scale = INT_MAX / max_val. +// The caller writes quantized state as: state_int = round(rState * encode_scale) +// and stores decode_scale = 1 / encode_scale alongside. +// +// Template args: +// state_t — quantized state type (e.g. int16_t) +// numTiles — number of tiles per thread +// pairsPerTileMember — float2 pairs per tile member +// lanesPerRow — lanes cooperating on one DIM row +// elemsPerTile — logical elements per tile +// elemsPerTileMember — elements per tile member +// +// rState — register array of float2[numTiles][pairsPerTileMember] +// lane — threadIdx.x +// baseCol — lambda(t, e) -> logical column index +// ============================================================================= + +template +__device__ __forceinline__ float computeBlockScaleEncode( + float2 const (&rState)[numTiles][pairsPerTileMember], int lane, int member) { + float local_max = std::numeric_limits::lowest(); + int const lane_member_0 = lane & ~(lanesPerRow - 1); + + auto baseCol = [&](int t, int e) -> int { + return t * elemsPerTile + member * elemsPerTileMember + e; + }; + +#pragma unroll + for (int t = 0; t < numTiles; t++) { +#pragma unroll + for (int p = 0; p < pairsPerTileMember; p++) { + int const c0 = baseCol(t, p * 2); + if (c0 < DSTATE) { + local_max = fmaxf(local_max, fmaxf(fabsf(rState[t][p].x), fabsf(rState[t][p].y))); + } + } + } + // Reduce max across lanesPerRow lanes +#pragma unroll + for (int offset = lanesPerRow / 2; offset >= 1; offset /= 2) { + local_max = fmaxf(local_max, __shfl_down_sync(UINT32_MAX, local_max, offset)); + } + // Broadcast from member 0 of each group + local_max = __shfl_sync(UINT32_MAX, local_max, lane_member_0); + return (local_max == 0.f) ? 1.f + : static_cast(std::numeric_limits::max()) / local_max; +} + } // namespace flashinfer::mamba::mtp diff --git a/tests/mamba/test_selective_state_update_mtp.py b/tests/mamba/test_selective_state_update_mtp.py index e126a83997..7ed0c2adb8 100644 --- a/tests/mamba/test_selective_state_update_mtp.py +++ b/tests/mamba/test_selective_state_update_mtp.py @@ -57,7 +57,6 @@ class TestSelectiveStateUpdateMTP: autouse=True, params=[ "simple", - "async_horizontal", pytest.param("vertical", marks=_requires_sm100), pytest.param("horizontal", marks=_requires_sm100), ], @@ -230,6 +229,68 @@ def test_output_correctness( self.assert_outputs_match(y_ref, y_test, msg_prefix=prefix) +class TestSelectiveStateUpdateMTPPadSlots(TestSelectiveStateUpdateMTP): + """Test that pad slots produce correct output (zero state, valid B/C/x). + + When a batch entry maps to pad_slot_id, the kernel should treat the state as + zero but still load B/C/x and compute y = D * x (since dA * 0 + dB * x with + zero initial state reduces through the recurrence). The state for pad slots + must not be written back. + """ + + PAD_SLOT_ID = 99 + + def make_inputs( + self, batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype + ): + inputs = super().make_inputs( + batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype + ) + # Mark ~25% of batch entries as pad slots + num_pad = max(1, batch // 4) + pad_indices = torch.randperm(batch)[:num_pad] + inputs["slot_idx"][pad_indices] = self.PAD_SLOT_ID + inputs["pad_indices"] = pad_indices + return inputs + + def make_reference_output(self, inputs): + state_ref = clone_preserving_strides(inputs["state_cache"]) + y_ref = selective_state_update_triton( + state_ref, + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=self.PAD_SLOT_ID, + ) + return y_ref, state_ref + + def run_kernel(self, inputs, out=None, disable_state_update=False): + return flashinfer.mamba.selective_state_update( + inputs["state_cache"], + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=self.PAD_SLOT_ID, + out=out, + disable_state_update=disable_state_update, + algorithm=self._algo, + ) + + class TestSelectiveStateUpdateMTPWithZ(TestSelectiveStateUpdateMTP): """Test multi-token selective_state_update with z tensor (gating).""" @@ -867,7 +928,7 @@ def test_output_correctness( ) # Vertical/horizontal don't support scaled (quantized) state - if self._algo in ("vertical", "horizontal", "async_horizontal"): + if self._algo in ("vertical", "horizontal"): with pytest.raises(RuntimeError, match="does not support scaled"): self.run_kernel(inputs) return @@ -1006,7 +1067,7 @@ def test_output_correctness( ) # Vertical/horizontal don't support scaled (quantized) state - if self._algo in ("vertical", "horizontal", "async_horizontal"): + if self._algo in ("vertical", "horizontal"): with pytest.raises(RuntimeError, match="does not support scaled"): self.run_kernel_with_intermediate_states(inputs) return diff --git a/tests/mamba/test_selective_state_update_varlen.py b/tests/mamba/test_selective_state_update_varlen.py index 620ec48e2b..f1e4f0e682 100644 --- a/tests/mamba/test_selective_state_update_varlen.py +++ b/tests/mamba/test_selective_state_update_varlen.py @@ -88,9 +88,10 @@ class TestSelectiveStateUpdateDstIndices: NGROUPS = 8 STATE_CACHE_SIZE = 256 + @pytest.mark.parametrize("algorithm", ["simple"]) @pytest.mark.parametrize("batch", [1, 4, 32, 64]) - def test_dst_different_from_src(self, batch): - """State is read from src slots and written to disjoint dst slots.""" + def test_dst_different_from_src(self, batch, algorithm): + """State is read from src slots and written to disjoint dst slots (STP path only).""" torch.manual_seed(42) tensors = _make_base_tensors( batch, @@ -145,6 +146,7 @@ def test_dst_different_from_src(self, batch): dst_state_batch_indices=dst_2d, pad_slot_id=PAD_SLOT_ID, out=out_test, + algorithm=algorithm, ) _assert_match(out_ref, out_test, "output", self.ATOL, self.RTOL) @@ -172,9 +174,10 @@ class TestSelectiveStateUpdateDstIndices2D: NGROUPS = 8 STATE_CACHE_SIZE = 256 + @pytest.mark.parametrize("algorithm", ["simple"]) @pytest.mark.parametrize("batch", [1, 16, 64]) - def test_2d_indices_seqlen1(self, batch): - """2D indices with max_seqlen=1 should behave identically to STP.""" + def test_2d_indices_seqlen1(self, batch, algorithm): + """2D indices with max_seqlen=1 should behave identically to STP (STP path only).""" torch.manual_seed(42) tensors = _make_base_tensors( batch, @@ -226,6 +229,7 @@ def test_2d_indices_seqlen1(self, batch): dst_state_batch_indices=dst_indices, pad_slot_id=PAD_SLOT_ID, out=out_test, + algorithm=algorithm, ) _assert_match(out_ref, out_test, "output", self.ATOL, self.RTOL) @@ -251,6 +255,7 @@ class TestSelectiveStateUpdateVarlen: NGROUPS = 8 STATE_CACHE_SIZE = 512 + @pytest.mark.parametrize("algorithm", ["simple"]) @pytest.mark.parametrize( "n_seqs,max_seqlen", [ @@ -263,7 +268,7 @@ class TestSelectiveStateUpdateVarlen: (16, 4), ], ) - def test_varlen_uniform(self, n_seqs, max_seqlen): + def test_varlen_uniform(self, n_seqs, max_seqlen, algorithm): """All sequences have the same length.""" torch.manual_seed(42) total_tokens = n_seqs * max_seqlen @@ -334,6 +339,7 @@ def test_varlen_uniform(self, n_seqs, max_seqlen): num_accepted_tokens=num_accepted, cu_seqlens=cu_seqlens, cache_steps=max_seqlen, + algorithm=algorithm, ) _assert_match(out_ref, out_test, "output", self.ATOL, self.RTOL) @@ -349,8 +355,9 @@ def test_varlen_uniform(self, n_seqs, max_seqlen): self.RTOL, ) + @pytest.mark.parametrize("algorithm", ["simple"]) @pytest.mark.parametrize("n_seqs", [4, 8]) - def test_varlen_variable_lengths(self, n_seqs): + def test_varlen_variable_lengths(self, n_seqs, algorithm): """Sequences have different lengths (padded with PAD_SLOT_ID).""" max_seqlen = 6 torch.manual_seed(42) @@ -431,6 +438,7 @@ def test_varlen_variable_lengths(self, n_seqs): num_accepted_tokens=num_accepted, cu_seqlens=cu_seqlens, cache_steps=max_seqlen, + algorithm=algorithm, ) _assert_match(out_ref, out_test, "output", self.ATOL, self.RTOL) @@ -447,9 +455,12 @@ class TestSelectiveStateUpdateNumAcceptedTokens: NGROUPS = 8 STATE_CACHE_SIZE = 512 + @pytest.mark.parametrize("algorithm", ["simple"]) @pytest.mark.parametrize("n_seqs", [4, 8, 16]) @pytest.mark.parametrize("num_accepted_dtype", [torch.int32, torch.int64]) - def test_num_accepted_selects_initial_state(self, n_seqs, num_accepted_dtype): + def test_num_accepted_selects_initial_state( + self, n_seqs, num_accepted_dtype, algorithm + ): """num_accepted_tokens controls which state slot to read as initial.""" max_seqlen = 4 total_tokens = n_seqs * max_seqlen @@ -521,6 +532,7 @@ def test_num_accepted_selects_initial_state(self, n_seqs, num_accepted_dtype): num_accepted_tokens=num_accepted, cu_seqlens=cu_seqlens, cache_steps=max_seqlen, + algorithm=algorithm, ) _assert_match(out_ref, out_test, "output", self.ATOL, self.RTOL)