[Perf] add cute dsl kernel for gdn decode#36111

Closed
ZJY0516 wants to merge 10 commits into vllm-project:main from ZJY0516:gdn_decode_cute

Conversation

@ZJY0516
Member

@ZJY0516 ZJY0516 commented Mar 5, 2026

Purpose

Add a CuTe DSL kernel for GDN decode.

cc @ywang96 @vadiklyutiy

Test Plan

vllm serve Qwen/Qwen3.5-397B-A17B --language-model-only -tp 8
VLLM_GDN_DECODE_BACKEND=cutedsl vllm serve Qwen/Qwen3.5-397B-A17B --language-model-only -tp 8
vllm bench serve --model Qwen/Qwen3.5-397B-A17B --endpoint /v1/completions --dataset-name random --max-concurrency 32 --random-output-len 1024 --num-prompts 256

Test Result

H20

main

============ Serving Benchmark Result ============
Successful requests:                     256       
Failed requests:                         0         
Maximum request concurrency:             32        
Benchmark duration (s):                  183.30    
Total input tokens:                      262144    
Total generated tokens:                  262144    
Request throughput (req/s):              1.40      
Output token throughput (tok/s):         1430.10   
Peak output token throughput (tok/s):    1696.00   
Peak concurrent requests:                64.00     
Total token throughput (tok/s):          2860.20   
---------------Time to First Token----------------
Mean TTFT (ms):                          962.33    
Median TTFT (ms):                        1069.46   
P99 TTFT (ms):                           1712.90   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          21.45     
Median TPOT (ms):                        21.45     
P99 TPOT (ms):                           22.82     
---------------Inter-token Latency----------------
Mean ITL (ms):                           21.45     
Median ITL (ms):                         20.81     
P99 ITL (ms):                            22.34     
==================================================

cutedsl kernel

============ Serving Benchmark Result ============
Successful requests:                     256       
Failed requests:                         0         
Maximum request concurrency:             32        
Benchmark duration (s):                  180.24    
Total input tokens:                      262144    
Total generated tokens:                  262144    
Request throughput (req/s):              1.42      
Output token throughput (tok/s):         1454.39   
Peak output token throughput (tok/s):    1728.00   
Peak concurrent requests:                64.00     
Total token throughput (tok/s):          2908.78   
---------------Time to First Token----------------
Mean TTFT (ms):                          948.02    
Median TTFT (ms):                        1065.34   
P99 TTFT (ms):                           1785.63   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          21.09     
Median TPOT (ms):                        21.08     
P99 TPOT (ms):                           22.14     
---------------Inter-token Latency----------------
Mean ITL (ms):                           21.09     
Median ITL (ms):                         20.33     
P99 ITL (ms):                            21.61     
==================================================
Micro benchmark for GDN decode ops
"""Micro benchmark for GDN decode ops.

Compare:
1) Triton baseline: fused_sigmoid_gating_delta_rule_update
2) CuTe DSL transpose op: cutedsl_transpose_fused_sigmoid_gated_delta_rule_update

The benchmark follows qwen3_next decode call shape:
- q/k/v: [1, T, H, K] / [1, T, HV, V]
- a/b:   [T, HV]
- state: [POOL, HV, V, K] (k-last storage in vLLM)
- indices: [T], cu_seqlens: [T+1] (all decode tokens are length-1 sequences)
"""

from __future__ import annotations

import argparse
import statistics
from dataclasses import dataclass

import torch

from vllm.model_executor.layers.fla.ops.cutedsl_gdn_transpose import (
    cutedsl_transpose_fused_sigmoid_gated_delta_rule_update,
    is_cutedsl_transpose_gdn_available,
)
from vllm.model_executor.layers.fla.ops.fused_sigmoid_gating import (
    fused_sigmoid_gating_delta_rule_update,
)


@dataclass
class Row:
    t: int
    triton_ms: float
    cutedsl_ms: float
    speedup: float
    out_max_diff: float
    out_mean_diff: float
    state_max_diff: float
    state_mean_diff: float


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Benchmark cutedsl transpose vs triton.")
    parser.add_argument(
        "--t-list",
        type=str,
        default="1,2,4,8,16,32,64,128,256,512,700",
        help="Comma-separated decode token counts per step.",
    )
    parser.add_argument("--num-k-heads", type=int, default=16, help="H for q/k.")
    parser.add_argument("--num-v-heads", type=int, default=64, help="HV for v/state.")
    parser.add_argument("--head-k-dim", type=int, default=128, help="K.")
    parser.add_argument("--head-v-dim", type=int, default=128, help="V.")
    parser.add_argument("--pool-size", type=int, default=2048, help="State pool size (max 2047 for cutedsl).")
    parser.add_argument(
        "--use-cuda-graph",
        action="store_true",
        default=True,
        help="Use CUDA Graph for more accurate timing.",
    )
    parser.add_argument("--dtype", choices=["bf16", "fp16"], default="bf16")
    parser.add_argument(
        "--state-dtype",
        choices=["fp32", "fp16", "bf16"],
        default="bf16",
    )
    parser.add_argument("--warmup", type=int, default=30)
    parser.add_argument("--iters", type=int, default=100)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--disable-qk-l2norm",
        action="store_true",
        help="Disable qk L2Norm in both kernels.",
    )
    return parser.parse_args()


def _dtype(name: str) -> torch.dtype:
    if name == "bf16":
        return torch.bfloat16
    if name == "fp16":
        return torch.float16
    if name == "fp32":
        return torch.float32
    raise ValueError(f"Unsupported dtype: {name}")


def _time_cuda_ms(fn, warmup: int, iters: int, use_cuda_graph: bool = False) -> float:
    if use_cuda_graph:
        return _time_cuda_graph(fn, warmup, iters)
    
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    for i in range(iters):
        starts[i].record()
        fn()
        ends[i].record()
    torch.cuda.synchronize()
    lat = [starts[i].elapsed_time(ends[i]) for i in range(iters)]
    return statistics.mean(lat)


def _time_cuda_graph(fn, warmup: int, iters: int) -> float:
    """Time a function using CUDA Graph replay."""
    # Warmup
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    # Capture graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        fn()
    torch.cuda.synchronize()

    # Time graph replay
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    
    start.record()
    for _ in range(iters):
        graph.replay()
    end.record()
    torch.cuda.synchronize()
    
    return start.elapsed_time(end) / iters


def _run_triton(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_log: torch.Tensor,
    dt_bias: torch.Tensor,
    state: torch.Tensor,
    indices: torch.Tensor,
    cu_seqlens: torch.Tensor,
    use_qk_l2norm: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    out, final_state = fused_sigmoid_gating_delta_rule_update(
        A_log=a_log,
        a=a,
        b=b,
        dt_bias=dt_bias,
        q=q,
        k=k,
        v=v,
        initial_state=state,
        inplace_final_state=True,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=indices,
        use_qk_l2norm_in_kernel=use_qk_l2norm,
    )
    return out.squeeze(0), final_state


def _run_cutedsl(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_log: torch.Tensor,
    dt_bias: torch.Tensor,
    state: torch.Tensor,
    indices: torch.Tensor,
    cu_seqlens: torch.Tensor,
    use_qk_l2norm: bool,
) -> torch.Tensor:
    return cutedsl_transpose_fused_sigmoid_gated_delta_rule_update(
        A_log=a_log,
        a=a,
        dt_bias=dt_bias,
        softplus_beta=1.0,
        softplus_threshold=20.0,
        q=q,
        k=k,
        v=v,
        b=b,
        initial_state_source=state.transpose(-2, -1),
        initial_state_indices=indices,
        use_qk_l2norm_in_kernel=use_qk_l2norm,
        cu_seqlens=cu_seqlens,
    )


def bench_one_t(
    t: int,
    *,
    h: int,
    hv: int,
    kdim: int,
    vdim: int,
    pool_size: int,
    act_dtype: torch.dtype,
    state_dtype: torch.dtype,
    warmup: int,
    iters: int,
    use_qk_l2norm: bool,
    args: argparse.Namespace,
) -> Row:
    device = "cuda"
    q = torch.randn(1, t, h, kdim, device=device, dtype=act_dtype)
    k = torch.randn(1, t, h, kdim, device=device, dtype=act_dtype)
    v = torch.randn(1, t, hv, vdim, device=device, dtype=act_dtype)
    a = torch.randn(t, hv, device=device, dtype=act_dtype)
    b = torch.randn(t, hv, device=device, dtype=act_dtype)
    a_log = torch.randn(hv, device=device, dtype=torch.float32)
    dt_bias = torch.randn(hv, device=device, dtype=torch.float32)

    # Use unique state slots (use int64 indices to avoid CuTe DSL 2048 limitation)
    if pool_size < t:
        raise ValueError(f"pool_size ({pool_size}) must be >= T ({t})")
    indices = torch.randperm(pool_size, device=device, dtype=torch.int64)[:t]
    cu_seqlens = torch.arange(0, t + 1, device=device, dtype=torch.int32)

    state0 = torch.randn(pool_size, hv, vdim, kdim, device=device, dtype=state_dtype)

    # Correctness snapshot
    state_tri_ref = state0.clone()
    state_cut_ref = state0.clone()
    out_tri_ref, _ = _run_triton(
        q,
        k,
        v,
        a,
        b,
        a_log,
        dt_bias,
        state_tri_ref,
        indices,
        cu_seqlens,
        use_qk_l2norm,
    )
    out_cut_ref = _run_cutedsl(
        q,
        k,
        v,
        a,
        b,
        a_log,
        dt_bias,
        state_cut_ref,
        indices,
        cu_seqlens,
        use_qk_l2norm,
    )
    out_diff = (out_tri_ref.float() - out_cut_ref.float()).abs()
    state_diff = (state_tri_ref.float() - state_cut_ref.float()).abs()

    # Timing
    state_tri = state0.clone()
    state_cut = state0.clone()

    triton_ms = _time_cuda_ms(
        lambda: _run_triton(
            q,
            k,
            v,
            a,
            b,
            a_log,
            dt_bias,
            state_tri,
            indices,
            cu_seqlens,
            use_qk_l2norm,
        ),
        warmup=warmup,
        iters=iters,
        use_cuda_graph=args.use_cuda_graph,
    )
    cutedsl_ms = _time_cuda_ms(
        lambda: _run_cutedsl(
            q,
            k,
            v,
            a,
            b,
            a_log,
            dt_bias,
            state_cut,
            indices,
            cu_seqlens,
            use_qk_l2norm,
        ),
        warmup=warmup,
        iters=iters,
        use_cuda_graph=args.use_cuda_graph,
    )

    return Row(
        t=t,
        triton_ms=triton_ms,
        cutedsl_ms=cutedsl_ms,
        speedup=(triton_ms / cutedsl_ms) if cutedsl_ms > 0 else float("inf"),
        out_max_diff=out_diff.max().item(),
        out_mean_diff=out_diff.mean().item(),
        state_max_diff=state_diff.max().item(),
        state_mean_diff=state_diff.mean().item(),
    )


def main() -> None:
    args = parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    if not is_cutedsl_transpose_gdn_available():
        raise RuntimeError("CuTe DSL transpose kernel is unavailable in current env.")
    if args.head_k_dim != args.head_v_dim or args.head_k_dim not in (128, 256):
        raise ValueError(
            "cutedsl_transpose benchmark requires head_k_dim == head_v_dim in {128, 256}."
        )

    t_list = [int(x) for x in args.t_list.split(",") if x.strip()]
    if not t_list:
        raise ValueError("--t-list cannot be empty")

    act_dtype = _dtype(args.dtype)
    state_dtype = _dtype(args.state_dtype)
    use_qk_l2norm = not args.disable_qk_l2norm

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    
    # Clear compile cache
    cutedsl_transpose_fused_sigmoid_gated_delta_rule_update.compile_cache.clear()

    graph_mode = "CUDA Graph" if args.use_cuda_graph else "Eager"
    print(
        f"Running benchmark ({graph_mode}) with "
        f"T={t_list}, H={args.num_k_heads}, HV={args.num_v_heads}, "
        f"K={args.head_k_dim}, V={args.head_v_dim}, "
        f"dtype={act_dtype}, state_dtype={state_dtype}, "
        f"warmup={args.warmup}, iters={args.iters}, "
        f"use_qk_l2norm={use_qk_l2norm}"
    )
    print(
        "\n"
        "      T | triton(ms) | cutedsl(ms) | speedup | "
        "out_max_diff | out_mean_diff | state_max_diff | state_mean_diff"
    )
    print("-" * 118)

    rows: list[Row] = []
    for t in t_list:
        row = bench_one_t(
            t,
            h=args.num_k_heads,
            hv=args.num_v_heads,
            kdim=args.head_k_dim,
            vdim=args.head_v_dim,
            pool_size=args.pool_size,
            act_dtype=act_dtype,
            state_dtype=state_dtype,
            warmup=args.warmup,
            iters=args.iters,
            use_qk_l2norm=use_qk_l2norm,
            args=args,
        )
        rows.append(row)
        print(
            f"{row.t:7d} | "
            f"{row.triton_ms:10.4f} | "
            f"{row.cutedsl_ms:11.4f} | "
            f"{row.speedup:7.3f} | "
            f"{row.out_max_diff:12.5e} | "
            f"{row.out_mean_diff:13.5e} | "
            f"{row.state_max_diff:14.5e} | "
            f"{row.state_mean_diff:15.5e}"
        )


if __name__ == "__main__":
    main()

      T | triton(ms) | cutedsl(ms) | speedup | out_max_diff | out_mean_diff | state_max_diff | state_mean_diff
----------------------------------------------------------------------------------------------------------------------
      1 |     0.0045 |      0.0044 |   1.034 |  1.90735e-06 |   2.32831e-10 |    3.90625e-03 |     5.26248e-12
      2 |     0.0067 |      0.0054 |   1.233 |  1.52588e-05 |   9.61336e-10 |    3.90625e-03 |     8.32570e-12
      4 |     0.0098 |      0.0081 |   1.216 |  2.44141e-04 |   7.46149e-09 |    3.90625e-03 |     7.80790e-12
      8 |     0.0170 |      0.0131 |   1.297 |  1.22070e-04 |   1.92406e-09 |    7.81250e-03 |     2.14891e-11
     16 |     0.0341 |      0.0240 |   1.422 |  4.88281e-04 |   6.01012e-09 |    7.81250e-03 |     6.32187e-11
     32 |     0.0819 |      0.0496 |   1.650 |  4.88281e-04 |   5.73225e-09 |    7.81250e-03 |     1.43033e-10
     64 |     0.1599 |      0.0982 |   1.628 |  4.88281e-04 |   4.51996e-09 |    7.81250e-03 |     2.44097e-10
    128 |     0.3160 |      0.1937 |   1.631 |  4.88281e-04 |   3.38923e-09 |    1.56250e-02 |     5.52435e-10
    256 |     0.6272 |      0.3824 |   1.640 |  9.76562e-04 |   4.46810e-09 |    1.56250e-02 |     9.70961e-10
    512 |     1.2489 |      0.7773 |   1.607 |  9.76562e-04 |   3.34901e-09 |    1.56250e-02 |     2.00526e-09
    700 |     1.7023 |      1.0761 |   1.582 |  4.88281e-04 |   3.15132e-09 |    1.56250e-02 |     2.55948e-09

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@mergify mergify bot added the qwen Related to Qwen models label Mar 5, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new CuTe DSL kernel for GDN decode in Qwen3Next models, aimed at improving performance. The changes are gated by a new environment variable VLLM_GDN_DECODE_BACKEND. My review focuses on the implementation and integration of this new kernel. I've identified a potential performance issue on Hopper (SM90) GPUs where a slower, scalar code path might be taken instead of the optimized intrinsic-based path. This could undermine the performance goals of this PR for a key architecture.

@ZJY0516
Member Author

ZJY0516 commented Mar 5, 2026

We'd better wait for #35777 to be merged.

@vadiklyutiy vadiklyutiy moved this to In progress in Qwen3.5 Mar 5, 2026
ZJY0516 added 4 commits March 6, 2026 10:29
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@ZJY0516 ZJY0516 marked this pull request as ready for review March 9, 2026 10:42
@ZJY0516 ZJY0516 requested a review from sighingnow as a code owner March 9, 2026 10:42
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@ZJY0516
Member Author

ZJY0516 commented Mar 9, 2026

E2E measurement is much lower than the micro benchmark. Did you measure the same FlashInfer version?

Because the GDN decode kernel accounts for a very small portion of the end-to-end runtime.

FlashInfer does not accept non-contiguous states (vLLM uses non-contiguous states); after modifying it locally, it's faster than this kernel.
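For reference, a minimal sketch of why the state ends up non-contiguous: vLLM stores the recurrent state k-last as [POOL, HV, V, K] (per the benchmark docstring below), so a kernel that wants a k-major [POOL, HV, K, V] layout gets it via a transpose, which only swaps strides and yields a non-contiguous view. The shapes here are illustrative, not the real model config.

```python
import torch

# Illustrative shapes; vLLM's k-last state storage is [POOL, HV, V, K].
pool, hv, vdim, kdim = 8, 4, 128, 128
state = torch.randn(pool, hv, vdim, kdim)

# A k-major view [POOL, HV, K, V] is obtained by transposing the last two
# dims. This swaps strides only, so the result is non-contiguous.
state_k_major = state.transpose(-2, -1)
assert state_k_major.shape == (pool, hv, kdim, vdim)
assert not state_k_major.is_contiguous()

# A kernel that insists on contiguous input would force an explicit copy.
assert state_k_major.contiguous().is_contiguous()
```

A kernel that accepts arbitrary strides (as the CuTe DSL op does via `initial_state_source=state.transpose(-2, -1)`) avoids that copy.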

@vadiklyutiy
Collaborator

FlashInfer has an optimized version of GDN. The last PR in FI landed several days ago, and the fully optimized version will be available in v0.6.6.

@vadiklyutiy
Copy link
Copy Markdown
Collaborator

vadiklyutiy commented Mar 9, 2026

Because the GDN decode kernel accounts for a very small portion of the end-to-end runtime.

When we come to the decode phase, GDN consumes around 35% for Qwen3.5-397B... Not sure why it is so small for 35B...

@ZJY0516
Member Author

ZJY0516 commented Mar 9, 2026

Because the GDN decode kernel accounts for a very small portion of the end-to-end runtime.

When we come to the decode phase, GDN consumes around 35% for Qwen3.5-397B... Not sure why it is so small for 35B...

Let me test Qwen3.5-397B.

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@ZJY0516
Member Author

ZJY0516 commented Mar 9, 2026

@vadiklyutiy I have updated the perf data for Qwen3.5-397B.

@ywang96 ywang96 moved this from In progress to In review in Qwen3.5 Mar 9, 2026
@ZJY0516
Member Author

ZJY0516 commented Mar 10, 2026

Because the GDN decode kernel accounts for a very small portion of the end-to-end runtime.

When we come to the decode phase, GDN consumes around 35% for Qwen3.5-397B... Not sure why it is so small for 35B...

Which scenario? I cannot reproduce this.

@vadiklyutiy
Collaborator

Because the GDN decode kernel accounts for a very small portion of the end-to-end runtime.

When we come to the decode phase, GDN consumes around 35% for Qwen3.5-397B... Not sure why it is so small for 35B...

Which scenario? I cannot reproduce this.

B200
-dp 8
max-concurrency=2048, ISL/OSL=2/500

@vadiklyutiy
Collaborator

FlashInfer does not accept non-contiguous states (vLLM uses non-contiguous states),

Could you clarify what dim isn't contiguous?

after modifying it locally, it's faster than this kernel.

Did you modify FlashInfer or vLLM? If FlashInfer is faster, maybe it is worth using FlashInfer?

@ZJY0516
Copy link
Copy Markdown
Member Author

ZJY0516 commented Mar 11, 2026

Comparison with FlashInfer:

Micro benchmark for GDN decode ops
#!/usr/bin/env python3
"""Benchmark CuTeDSL transpose GDN op vs FlashInfer GDN op in CUDA Graph.

Compared ops:
1) vLLM CuTeDSL op:
   cutedsl_transpose_fused_sigmoid_gated_delta_rule_update
2) FlashInfer op:
   flashinfer.gdn_decode.gated_delta_rule_decode_pretranspose

The benchmark uses CUDA Graph capture/replay for both ops.
"""

from __future__ import annotations

import argparse
import statistics
from dataclasses import dataclass
from typing import Callable

import torch

from vllm.model_executor.layers.fla.ops.cutedsl_gdn_transpose import (
    cutedsl_transpose_fused_sigmoid_gated_delta_rule_update,
)

try:
    from flashinfer.gdn_decode import gated_delta_rule_decode_pretranspose
except Exception as e:  # pragma: no cover - runtime env dependent
    gated_delta_rule_decode_pretranspose = None
    _FLASHINFER_IMPORT_ERROR = e
else:
    _FLASHINFER_IMPORT_ERROR = None


@dataclass
class BenchRow:
    batch_size: int
    cutedsl_ms: float
    flashinfer_ms: float
    cutedsl_tps: float
    flashinfer_tps: float
    speedup_flashinfer_vs_cutedsl: float
    out_max_diff: float
    out_mean_diff: float
    state_max_diff: float
    state_mean_diff: float


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Benchmark cutedsl transpose GDN vs FlashInfer GDN under CUDA Graph"
    )
    parser.add_argument(
        "--batch-sizes",
        type=str,
        default="1,4,8,16,32,64,128,256,512",
        help="Comma-separated batch sizes (decode tokens per step).",
    )
    parser.add_argument("--num-k-heads", type=int, default=16, help="H for q/k")
    parser.add_argument("--num-v-heads", type=int, default=64, help="HV for v/state")
    parser.add_argument("--head-k-dim", type=int, default=128, help="K for q/k")
    parser.add_argument("--head-v-dim", type=int, default=128, help="V for v/state")
    parser.add_argument("--pool-size", type=int, default=4096, help="state pool size")
    parser.add_argument("--dtype", choices=["bf16", "fp16"], default="bf16")
    parser.add_argument(
        "--state-dtype",
        choices=["bf16", "fp32"],
        default="bf16",
        help="State dtype used by both kernels.",
    )
    parser.add_argument("--warmup", type=int, default=30)
    parser.add_argument("--iters", type=int, default=200)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--disable-qk-l2norm",
        action="store_true",
        help="Disable qk L2Norm in both kernels.",
    )
    parser.add_argument(
        "--skip-correctness",
        action="store_true",
        help="Skip correctness check (output/state diffs).",
    )
    return parser.parse_args()


def _dtype(name: str) -> torch.dtype:
    if name == "bf16":
        return torch.bfloat16
    if name == "fp16":
        return torch.float16
    if name == "fp32":
        return torch.float32
    raise ValueError(f"Unsupported dtype: {name}")


def _capture_graph(fn: Callable[[], None]) -> torch.cuda.CUDAGraph:
    graph = torch.cuda.CUDAGraph()
    torch.cuda.synchronize()
    with torch.cuda.graph(graph):
        fn()
    torch.cuda.synchronize()
    return graph


def _time_graph_ms(graph: torch.cuda.CUDAGraph, warmup: int, iters: int) -> float:
    for _ in range(warmup):
        graph.replay()
    torch.cuda.synchronize()

    starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    for i in range(iters):
        starts[i].record()
        graph.replay()
        ends[i].record()
    torch.cuda.synchronize()

    lat = [starts[i].elapsed_time(ends[i]) for i in range(iters)]
    return statistics.mean(lat)


def _make_inputs(
    batch_size: int,
    h: int,
    hv: int,
    kdim: int,
    vdim: int,
    pool_size: int,
    act_dtype: torch.dtype,
    state_dtype: torch.dtype,
):
    device = "cuda"
    q_cut = torch.randn(1, batch_size, h, kdim, device=device, dtype=act_dtype)
    k_cut = torch.randn(1, batch_size, h, kdim, device=device, dtype=act_dtype)
    v_cut = torch.randn(1, batch_size, hv, vdim, device=device, dtype=act_dtype)

    # cutedsl accepts [T, HV] or [1, T, HV]; use 2D to match qwen3_next usage.
    a_cut = torch.randn(batch_size, hv, device=device, dtype=act_dtype)
    b_cut = torch.randn(batch_size, hv, device=device, dtype=act_dtype)

    a_log = torch.randn(hv, device=device, dtype=torch.float32)
    dt_bias = torch.randn(hv, device=device, dtype=torch.float32)

    if pool_size < batch_size:
        raise ValueError(f"pool_size ({pool_size}) must be >= batch_size ({batch_size})")

    # Use unique slots for direct state-diff check.
    indices = torch.randperm(pool_size, device=device, dtype=torch.int64)[:batch_size]
    cu_seqlens = torch.arange(0, batch_size + 1, device=device, dtype=torch.int32)

    state_init = torch.randn(pool_size, hv, vdim, kdim, device=device, dtype=state_dtype)

    return {
        "q_cut": q_cut,
        "k_cut": k_cut,
        "v_cut": v_cut,
        "a_cut": a_cut,
        "b_cut": b_cut,
        "a_log": a_log,
        "dt_bias": dt_bias,
        "indices": indices,
        "cu_seqlens": cu_seqlens,
        "state_init": state_init,
    }


def _run_cutedsl(
    *,
    q_cut: torch.Tensor,
    k_cut: torch.Tensor,
    v_cut: torch.Tensor,
    a_cut: torch.Tensor,
    b_cut: torch.Tensor,
    a_log: torch.Tensor,
    dt_bias: torch.Tensor,
    state_pool: torch.Tensor,
    indices: torch.Tensor,
    cu_seqlens: torch.Tensor,
    use_qk_l2norm: bool,
) -> torch.Tensor:
    # Kernel expects K-contiguous view, i.e. [N, HV, K, V].
    state_k_major = state_pool.transpose(-2, -1)
    return cutedsl_transpose_fused_sigmoid_gated_delta_rule_update(
        A_log=a_log,
        a=a_cut,
        dt_bias=dt_bias,
        softplus_beta=1.0,
        softplus_threshold=20.0,
        q=q_cut,
        k=k_cut,
        v=v_cut,
        b=b_cut,
        initial_state_source=state_k_major,
        initial_state_indices=indices,
        use_qk_l2norm_in_kernel=use_qk_l2norm,
        cu_seqlens=cu_seqlens,
    )


def _run_flashinfer(
    *,
    q_fi: torch.Tensor,
    k_fi: torch.Tensor,
    v_fi: torch.Tensor,
    a_fi: torch.Tensor,
    b_fi: torch.Tensor,
    a_log: torch.Tensor,
    dt_bias: torch.Tensor,
    state_pool: torch.Tensor,
    indices: torch.Tensor,
    use_qk_l2norm: bool,
    output_fi: torch.Tensor,
) -> torch.Tensor:
    out, _ = gated_delta_rule_decode_pretranspose(
        q=q_fi,
        k=k_fi,
        v=v_fi,
        state=None,
        A_log=a_log,
        a=a_fi,
        dt_bias=dt_bias,
        b=b_fi,
        output=output_fi,
        use_qk_l2norm=use_qk_l2norm,
        initial_state=state_pool,
        initial_state_indices=indices,
    )
    return out.squeeze(1)


def _check_flashinfer_import() -> None:
    if gated_delta_rule_decode_pretranspose is None:
        raise RuntimeError(f"flashinfer import failed: {_FLASHINFER_IMPORT_ERROR!r}")


def bench_one_batch(
    *,
    batch_size: int,
    h: int,
    hv: int,
    kdim: int,
    vdim: int,
    pool_size: int,
    act_dtype: torch.dtype,
    state_dtype: torch.dtype,
    warmup: int,
    iters: int,
    use_qk_l2norm: bool,
    run_correctness: bool,
) -> BenchRow:
    data = _make_inputs(
        batch_size=batch_size,
        h=h,
        hv=hv,
        kdim=kdim,
        vdim=vdim,
        pool_size=pool_size,
        act_dtype=act_dtype,
        state_dtype=state_dtype,
    )

    # Shared input tensors for FlashInfer shape.
    q_fi = data["q_cut"].transpose(0, 1).contiguous()
    k_fi = data["k_cut"].transpose(0, 1).contiguous()
    v_fi = data["v_cut"].transpose(0, 1).contiguous()
    a_fi = data["a_cut"].unsqueeze(1).contiguous()
    b_fi = data["b_cut"].unsqueeze(1).contiguous()

    out_max_diff = float("nan")
    out_mean_diff = float("nan")
    state_max_diff = float("nan")
    state_mean_diff = float("nan")

    if run_correctness:
        state_cut_chk = data["state_init"].clone()
        state_fi_chk = data["state_init"].clone()
        output_fi_chk = torch.empty(
            (batch_size, 1, hv, vdim),
            device="cuda",
            dtype=act_dtype,
        )

        out_cut = _run_cutedsl(
            q_cut=data["q_cut"],
            k_cut=data["k_cut"],
            v_cut=data["v_cut"],
            a_cut=data["a_cut"],
            b_cut=data["b_cut"],
            a_log=data["a_log"],
            dt_bias=data["dt_bias"],
            state_pool=state_cut_chk,
            indices=data["indices"],
            cu_seqlens=data["cu_seqlens"],
            use_qk_l2norm=use_qk_l2norm,
        )
        out_fi = _run_flashinfer(
            q_fi=q_fi,
            k_fi=k_fi,
            v_fi=v_fi,
            a_fi=a_fi,
            b_fi=b_fi,
            a_log=data["a_log"],
            dt_bias=data["dt_bias"],
            state_pool=state_fi_chk,
            indices=data["indices"],
            use_qk_l2norm=use_qk_l2norm,
            output_fi=output_fi_chk,
        )

        out_diff = (out_cut.float() - out_fi.float()).abs()
        state_diff = (state_cut_chk.float() - state_fi_chk.float()).abs()
        out_max_diff = out_diff.max().item()
        out_mean_diff = out_diff.mean().item()
        state_max_diff = state_diff.max().item()
        state_mean_diff = state_diff.mean().item()

    # Warm up for JIT/compile cache.
    state_cut = data["state_init"].clone()
    state_fi = data["state_init"].clone()
    output_fi = torch.empty((batch_size, 1, hv, vdim), device="cuda", dtype=act_dtype)

    def cutedsl_call() -> None:
        _ = _run_cutedsl(
            q_cut=data["q_cut"],
            k_cut=data["k_cut"],
            v_cut=data["v_cut"],
            a_cut=data["a_cut"],
            b_cut=data["b_cut"],
            a_log=data["a_log"],
            dt_bias=data["dt_bias"],
            state_pool=state_cut,
            indices=data["indices"],
            cu_seqlens=data["cu_seqlens"],
            use_qk_l2norm=use_qk_l2norm,
        )

    def flashinfer_call() -> None:
        _ = _run_flashinfer(
            q_fi=q_fi,
            k_fi=k_fi,
            v_fi=v_fi,
            a_fi=a_fi,
            b_fi=b_fi,
            a_log=data["a_log"],
            dt_bias=data["dt_bias"],
            state_pool=state_fi,
            indices=data["indices"],
            use_qk_l2norm=use_qk_l2norm,
            output_fi=output_fi,
        )

    # Eager warmup to trigger compile/JIT before capture.
    prewarm = max(3, min(10, warmup))
    for _ in range(prewarm):
        cutedsl_call()
    for _ in range(prewarm):
        flashinfer_call()
    torch.cuda.synchronize()

    cutedsl_graph = _capture_graph(cutedsl_call)
    flashinfer_graph = _capture_graph(flashinfer_call)

    cutedsl_ms = _time_graph_ms(cutedsl_graph, warmup=warmup, iters=iters)
    flashinfer_ms = _time_graph_ms(flashinfer_graph, warmup=warmup, iters=iters)

    cutedsl_tps = batch_size * 1000.0 / cutedsl_ms
    flashinfer_tps = batch_size * 1000.0 / flashinfer_ms

    return BenchRow(
        batch_size=batch_size,
        cutedsl_ms=cutedsl_ms,
        flashinfer_ms=flashinfer_ms,
        cutedsl_tps=cutedsl_tps,
        flashinfer_tps=flashinfer_tps,
        speedup_flashinfer_vs_cutedsl=cutedsl_ms / flashinfer_ms,
        out_max_diff=out_max_diff,
        out_mean_diff=out_mean_diff,
        state_max_diff=state_max_diff,
        state_mean_diff=state_mean_diff,
    )


def main() -> None:
    args = parse_args()
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")

    _check_flashinfer_import()

    if args.head_k_dim != args.head_v_dim or args.head_k_dim not in (128, 256):
        raise ValueError("Current cutedsl op requires K == V and K in {128, 256}.")

    batch_sizes = [int(x) for x in args.batch_sizes.split(",") if x.strip()]
    if not batch_sizes:
        raise ValueError("--batch-sizes cannot be empty")

    if max(batch_sizes) > args.pool_size:
        raise ValueError("pool-size must be >= max(batch-sizes)")

    act_dtype = _dtype(args.dtype)
    state_dtype = _dtype(args.state_dtype)
    use_qk_l2norm = not args.disable_qk_l2norm

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Reduce one-time noise from first CUDA use.
    _ = torch.empty(1, device="cuda")
    torch.cuda.synchronize()

    # Ensure fair first compile per shape.
    cutedsl_transpose_fused_sigmoid_gated_delta_rule_update.compile_cache.clear()

    print("=== CuTeDSL vs FlashInfer (CUDA Graph) ===")
    print(
        f"batch_sizes={batch_sizes} H={args.num_k_heads} HV={args.num_v_heads} "
        f"K={args.head_k_dim} V={args.head_v_dim} "
        f"dtype={act_dtype} state_dtype={state_dtype} "
        f"warmup={args.warmup} iters={args.iters} "
        f"use_qk_l2norm={use_qk_l2norm}"
    )
    print(
        "\n"
        "   B | cutedsl(ms) | flashinfer(ms) | fi/cute speedup | "
        "cutedsl tok/s | flashinfer tok/s | out_max_diff | state_max_diff"
    )
    print("-" * 126)

    rows: list[BenchRow] = []
    with torch.inference_mode():
        for batch_size in batch_sizes:
            row = bench_one_batch(
                batch_size=batch_size,
                h=args.num_k_heads,
                hv=args.num_v_heads,
                kdim=args.head_k_dim,
                vdim=args.head_v_dim,
                pool_size=args.pool_size,
                act_dtype=act_dtype,
                state_dtype=state_dtype,
                warmup=args.warmup,
                iters=args.iters,
                use_qk_l2norm=use_qk_l2norm,
                run_correctness=not args.skip_correctness,
            )
            rows.append(row)
            print(
                f"{row.batch_size:4d} | "
                f"{row.cutedsl_ms:11.4f} | "
                f"{row.flashinfer_ms:13.4f} | "
                f"{row.speedup_flashinfer_vs_cutedsl:14.3f}x | "
                f"{row.cutedsl_tps:13.1f} | "
                f"{row.flashinfer_tps:15.1f} | "
                f"{row.out_max_diff:11.4e} | "
                f"{row.state_max_diff:13.4e}"
            )

    fi_speedups = [r.speedup_flashinfer_vs_cutedsl for r in rows]
    print("\nSummary:")
    print(
        f"  mean(fi/cute speedup)={statistics.mean(fi_speedups):.3f}x, "
        f"best={max(fi_speedups):.3f}x, worst={min(fi_speedups):.3f}x"
    )


if __name__ == "__main__":
    main()

H20

python bench_cutedsl_transpose_vs_flashinfer_cudagraph.py \
--batch-sizes 32 \
--num-k-heads 16 --num-v-heads 64 \
--head-k-dim 128 --head-v-dim 128 \
--pool-size 256 \
--warmup 2 --iters 5000
=== CuTeDSL vs FlashInfer (CUDA Graph) ===
batch_sizes=[32] H=16 HV=64 K=128 V=128 dtype=torch.bfloat16 state_dtype=torch.bfloat16 warmup=2 iters=5000 use_qk_l2norm=True

   B | cutedsl(ms) | flashinfer(ms) | fi/cute speedup | cutedsl tok/s | flashinfer tok/s | out_max_diff | state_max_diff
------------------------------------------------------------------------------------------------------------------------------
  32 |      0.0513 |        0.0520 |          0.986x |      623666.8 |        614818.4 |  2.4414e-04 |    7.8125e-03

Summary:
  mean(fi/cute speedup)=0.986x, best=0.986x, worst=0.986x

H200

=== CuTeDSL vs FlashInfer (CUDA Graph) ===
batch_sizes=[32] H=16 HV=64 K=128 V=128 dtype=torch.bfloat16 state_dtype=torch.bfloat16 warmup=2 iters=5000 use_qk_l2norm=True

   B | cutedsl(ms) | flashinfer(ms) | fi/cute speedup | cutedsl tok/s | flashinfer tok/s | out_max_diff | state_max_diff
------------------------------------------------------------------------------------------------------------------------------
  32 |      0.0557 |        0.0619 |          0.900x |      574844.5 |        517324.9 |  4.8828e-04 |    1.5625e-02

@vadiklyutiy
Collaborator

What about other batch sizes?

@ZJY0516
Member Author

ZJY0516 commented Mar 11, 2026

What about other batch sizes?

   B | cutedsl(ms) | flashinfer(ms) | fi/cute speedup | cutedsl tok/s | flashinfer tok/s | out_max_diff | state_max_diff
------------------------------------------------------------------------------------------------------------------------------
   1 |      0.0100 |        0.0109 |          0.921x |       99993.6 |         92123.1 |  3.8147e-06 |    7.8125e-03
   4 |      0.0115 |        0.0150 |          0.768x |      347048.7 |        266479.1 |  6.1035e-05 |    3.9062e-03
   8 |      0.0162 |        0.0187 |          0.865x |      493719.9 |        427058.4 |  1.2207e-04 |    7.8125e-03
  16 |      0.0275 |        0.0280 |          0.980x |      582004.4 |        570567.8 |  1.2207e-04 |    7.8125e-03
  32 |      0.0516 |        0.0530 |          0.975x |      619947.4 |        604295.3 |  2.4414e-04 |    7.8125e-03
  64 |      0.0977 |        0.0929 |          1.052x |      654806.0 |        688947.2 |  4.8828e-04 |    1.5625e-02
 128 |      0.1915 |        0.1722 |          1.112x |      668529.6 |        743290.0 |  9.7656e-04 |    1.5625e-02
 256 |      0.3810 |        0.3344 |          1.139x |      671906.4 |        765588.3 |  4.8828e-04 |    1.5625e-02
 512 |      0.7655 |        0.6720 |          1.139x |      668828.7 |        761961.4 |  4.8828e-04 |    1.5625e-02
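For reference, the tok/s columns in these tables follow the script's own formula, `batch_size * 1000.0 / latency_ms` (one decode token per request per step). A quick sanity check against two rows; the small deviations from the printed tok/s values come from the latency column being rounded to four decimals:

```python
# Reproduce the tok/s column of the benchmark tables from the latency column.
# tokens/s = batch_size * 1000.0 / latency_ms (one token per request per step).
def tokens_per_second(batch_size: int, latency_ms: float) -> float:
    return batch_size * 1000.0 / latency_ms

# B=1 row prints 0.0100 ms / 99993.6 tok/s; B=32 prints 0.0516 ms / 619947.4 tok/s.
# The rounded latencies reproduce the throughputs to within rounding error.
print(round(tokens_per_second(1, 0.0100), 1))   # 100000.0
print(round(tokens_per_second(32, 0.0516), 1))  # 620155.0
```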

@vadiklyutiy
Collaborator

Regarding E2E perf

I collected profile

nsys profile --trace=cuda,nvtx --cuda-graph-trace=node --trace-fork-before-exec=true --capture-range=cudaProfilerApi vllm serve Qwen/Qwen3.5-35B-A3B-FP8 --language-model-only --reasoning-parser qwen3 --kv-cache-dtype fp8 --stream-interval=100 --profiler-config '{"profiler":"cuda","max_iterations":32,"delay_iterations":200}'
vllm bench serve --backend vllm --model Qwen/Qwen3.5-35B-A3B-FP8 --endpoint /v1/completions --dataset-name random --random-input 2 --random-output 500 --max-concurrency 256 --num-prompt 512 --ignore-eos --profile --temperature 0.0

and see that gdn takes 40%

[nsys profile screenshot]

@ZJY0516
Member Author

ZJY0516 commented Mar 11, 2026

Not so much on my machine

                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
 -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                execute_context_0(0)_generation_256(256)         0.00%       0.000us         0.00%       0.000us       0.000us      105.975s       185.43%      105.975s      53.469ms          1982  
                                        fused_moe_kernel         0.00%       0.000us         0.00%       0.000us       0.000us       29.467s        51.56%       29.467s     365.775us         80560  
           fused_sigmoid_gating_delta_rule_update_kernel         0.00%       0.000us         0.00%       0.000us       0.000us       12.108s        21.19%       12.108s     403.999us         29970  
               nvjet_tst_320x128_64x3_1x2_h_bz_coopB_TNT         0.00%       0.000us         0.00%       0.000us       0.000us        4.735s         8.29%        4.735s     153.820us         30785  
                                                aten::mm         0.09%      49.640ms         0.26%     148.208ms     147.178us        1.852s         3.24%        1.853s       1.840ms          1007  
 void cutlass::device_kernel<flash::enable_sm90_or_la...         0.00%       0.000us         0.00%       0.000us       0.000us        1.672s         2.93%        1.672s     167.355us          9990  

12.108 s / 105.975 s ≈ 11.4%

@ZJY0516
Member Author

ZJY0516 commented Mar 11, 2026

using cutedsl kernel

                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
 -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                execute_context_0(0)_generation_256(256)         0.00%       0.000us         0.00%       0.000us       0.000us       96.991s       179.39%       96.991s      50.202ms          1932  
                                        fused_moe_kernel         0.00%       0.000us         0.00%       0.000us       0.000us       29.561s        54.68%       29.561s     359.096us         82320  
 kernel_cutlass_fused_recurrent_sigmoid_update_kernel...         0.00%       0.000us         0.00%       0.000us       0.000us        8.617s        15.94%        8.617s     288.981us         29820  
               nvjet_tst_320x128_64x3_1x2_h_bz_coopB_TNT         0.00%       0.000us         0.00%       0.000us       0.000us        4.757s         8.80%        4.757s     154.204us         30847  
                                                aten::mm         0.10%      52.473ms         0.43%     230.104ms     223.619us        1.861s         3.44%        1.862s       1.809ms          1029  
 void cutlass::device_kernel<flash::enable_sm90_or_la...         0.00%       0.000us         0.00%       0.000us       0.000us        1.678s         3.10%        1.678s     164.534us         10200  

From 12.108 s down to 8.617 s, which is in line with the microbenchmark results.
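The end-to-end impact implied by these profiles can be cross-checked with a quick Amdahl-style calculation using the Self CUDA totals from the two tables above (the remaining gap to the measured 96.991 s total is run-to-run noise and other differences between the runs):

```python
# Implied end-to-end effect of swapping only the GDN decode kernel,
# using the Self CUDA totals from the two profiler tables above.
baseline_total_s = 105.975   # execute_context total, baseline run
kernel_before_s = 12.108     # fused_sigmoid_gating_delta_rule_update_kernel
kernel_after_s = 8.617       # cutedsl replacement kernel

kernel_speedup = kernel_before_s / kernel_after_s
saved_s = kernel_before_s - kernel_after_s
e2e_speedup = baseline_total_s / (baseline_total_s - saved_s)

print(f"kernel speedup: {kernel_speedup:.3f}x")    # 1.405x
print(f"time saved: {saved_s:.3f} s")              # 3.491 s
print(f"implied e2e speedup: {e2e_speedup:.3f}x")  # 1.034x
```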

@mergify
Contributor

mergify bot commented Mar 12, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ZJY0516.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 12, 2026
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@vadiklyutiy
Collaborator

I ran on B200

CUDA_LAUNCH_BLOCKING=1 VLLM_GDN_DECODE_BACKEND=cutedsl vllm serve Qwen/Qwen3.5-35B-A3B-FP8 --language-model-only --reasoning-parser qwen3 --kv-cache-dtype fp8 --enforce-eager

and after several times of

vllm bench serve --backend vllm --model Qwen/Qwen3.5-35B-A3B-FP8 --endpoint /v1/completions --dataset-name random --random-input 2 --random-output 500 --max-concurrency 256 --num-prompt 512 --ignore-eos --temperature=0.0

got

(EngineCore_DP0 pid=72333) ERROR 03-12 18:41:25 [core.py:1113]     return executor.run_compiled_program(exe_args)
(EngineCore_DP0 pid=72333) ERROR 03-12 18:41:25 [core.py:1113]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=72333) ERROR 03-12 18:41:25 [core.py:1113]   File "/home/scratch.vgimpelson_ent/venv_b/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/jit_executor.py", line 801, in run_compiled_program
(EngineCore_DP0 pid=72333) ERROR 03-12 18:41:25 [core.py:1113]     raise e
(EngineCore_DP0 pid=72333) ERROR 03-12 18:41:25 [core.py:1113]   File "/home/scratch.vgimpelson_ent/venv_b/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/jit_executor.py", line 799, in run_compiled_program
(EngineCore_DP0 pid=72333) ERROR 03-12 18:41:25 [core.py:1113]     raise DSLCudaRuntimeError(error_code, error_name)
(EngineCore_DP0 pid=72333) ERROR 03-12 18:41:25 [core.py:1113] cutlass.base_dsl.common.DSLCudaRuntimeError: DSLCudaRuntimeError: CUDA_ERROR_ILLEGAL_ADDRESS (error code: 700) 
(EngineCore_DP0 pid=72333) ERROR 03-12 18:41:25 [core.py:1113]  
(EngineCore_DP0 pid=72333) ERROR 03-12 18:41:25 [core.py:1113] 
(EngineCore_DP0 pid=72333) ERROR 03-12 18:41:25 [core.py:1113]
(EngineCore_DP0 pid=72333) ERROR 03-12 18:41:25 [core.py:1113] Error Code: 700
(EngineCore_DP0 pid=72333) ERROR 03-12 18:41:25 [core.py:1113]
(EngineCore_DP0 pid=72333) ERROR 03-12 18:41:25 [core.py:1113] 🔍 Additional Context:
(EngineCore_DP0 pid=72333) ERROR 03-12 18:41:25 [core.py:1113] - Error name: CUDA_ERROR_ILLEGAL_ADDRESS

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@mergify
Contributor

mergify bot commented Mar 12, 2026

Hi @ZJY0516, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@ameynaik-hub

@ZJY0516: #36111 (comment)
Can you please clarify whether you are running with state (h) precision of bf16 or fp32? The comment here says it is bf16.

batch_sizes=[32] H=16 HV=64 K=128 V=128 dtype=torch.bfloat16 state_dtype=torch.bfloat16 warmup=2 iters=5000 use_qk_l2norm=True

Also, I am assuming this is STP (single-token prediction)?

@ameynaik-hub

fp32 state T=1 gdn decode benchmark (B200, Qwen 3.5)

bench cmd:
python benchmarks/bench_gdn_decode.py \
    --num-q-heads 16 --num-k-heads 16 --num-v-heads 64 --head-size 128 \
    --version pretranspose \
    --warmup 10 --iters 1000 \
    --batch-size 1 2 4 8 16 32 64 128 256 512

FI-PreTr (FlashInfer Pretranspose, FP32 State) — Memory Bandwidth SOL

  BS | Latency (us) | Bytes Accessed | BW (TB/s) |  SOL%
-----+--------------+----------------+-----------+------
   1 |         4.29 |      8,430,208 |     1.965 | 24.6%
   2 |         5.41 |     16,860,032 |     3.116 | 39.0%
   4 |         7.81 |     33,719,680 |     4.318 | 54.0%
   8 |        13.15 |     67,438,976 |     5.128 | 64.1%
  16 |        23.07 |    134,877,568 |     5.846 | 73.1%
  32 |        42.72 |    269,754,752 |     6.314 | 78.9%
  64 |        81.50 |    539,509,120 |     6.620 | 82.7%
 128 |       158.94 |  1,079,017,856 |     6.789 | 84.9%
 256 |       313.93 |  2,158,035,328 |     6.874 | 85.9%
 512 |       630.19 |  4,316,070,272 |     6.849 | 85.6%

Here are the results on B200.

I assume the numbers you measured in #36111 (comment) are also from a B200?
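The BW and SOL% columns above can be reproduced from the latency and bytes-accessed columns. A minimal sketch, assuming the ~8 TB/s peak HBM bandwidth that the table's own BW/SOL% ratio implies for B200 (an inferred value, not a measured one):

```python
# Reproduce the BW (TB/s) and SOL% columns from latency and bytes accessed.
# PEAK_TBPS ~ 8.0 is inferred from the table itself (BW / SOL%), roughly
# B200 HBM3e peak bandwidth; it is an assumption, not a measured value.
PEAK_TBPS = 8.0

def bw_and_sol(latency_us: float, bytes_accessed: int) -> tuple[float, float]:
    bw_tbps = bytes_accessed / (latency_us * 1e-6) / 1e12
    return bw_tbps, 100.0 * bw_tbps / PEAK_TBPS

# BS=1 row: 4.29 us, 8,430,208 bytes -> ~1.965 TB/s, ~24.6% SOL
bw, sol = bw_and_sol(4.29, 8_430_208)
print(f"{bw:.3f} TB/s, {sol:.1f}% SOL")  # 1.965 TB/s, 24.6% SOL
```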

@mergify mergify bot removed the needs-rebase label Mar 12, 2026
@mergify
Contributor

mergify bot commented Mar 13, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ZJY0516.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 13, 2026
@ywang96
Member

ywang96 commented Mar 19, 2026

Closing as superseded by #36596

@ywang96 ywang96 closed this Mar 19, 2026
@github-project-automation github-project-automation bot moved this from In review to Done in Qwen3.5 Mar 19, 2026
@vadiklyutiy
Collaborator

Closing as superseded by #36596

I think this PR contains somewhat different changes...

@ZJY0516
Member Author

ZJY0516 commented Mar 19, 2026

Closing as superseded by #36596

I think this PR contains somewhat different changes...

But I found that #36596 is already so fast that it's hard to see the speedup from this kernel.


Labels

needs-rebase qwen Related to Qwen models

Projects

Status: Done


4 participants