[BUG] blackwell compile issue #1233

@createthis

What version of TileLang are you using?

0.1.6.post2+cu128.git9eaa708f

System information

/opt/conda/lib/python3.10/runpy.py:126: RuntimeWarning: 'torch.utils.collect_env' found in sys.modules after import of package 'torch.utils', but prior to execution of 'torch.utils.collect_env'; this may result in unpredictable behaviour
  warn(RuntimeWarning(msg))
Collecting environment information...
PyTorch version: 2.9.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: version 4.1.2
Libc version: glibc-2.39

Python version: 3.10.19 (main, Oct 21 2025, 16:43:05) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-87-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.8.61
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
Nvidia driver version: 570.124.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.7.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9355 32-Core Processor
CPU family: 26
Model: 2
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
Stepping: 1
Frequency boost: enabled
CPU(s) scaling MHz: 100%
CPU max MHz: 3550.0000
CPU min MHz: 1500.0000
BogoMIPS: 7099.85
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx_vnni avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc amd_ibpb_ret arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect movdiri movdir64b overflow_recov succor smca fsrm avx512_vp2intersect flush_l1d debug_swap
Virtualization: AMD-V
L1d cache: 3 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 64 MiB (64 instances)
L3 cache: 512 MiB (16 instances)
NUMA node(s): 8
NUMA node0 CPU(s): 0-7,64-71
NUMA node1 CPU(s): 8-15,72-79
NUMA node2 CPU(s): 16-23,80-87
NUMA node3 CPU(s): 24-31,88-95
NUMA node4 CPU(s): 32-39,96-103
NUMA node5 CPU(s): 40-47,104-111
NUMA node6 CPU(s): 48-55,112-119
NUMA node7 CPU(s): 56-63,120-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Vulnerability Vmscape: Not affected

Versions of relevant libraries:
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] torch==2.9.0
[pip3] triton==3.5.0
[conda] numpy 2.2.6 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
[conda] torch 2.9.0 pypi_0 pypi
[conda] triton 3.5.0 pypi_0 pypi

Problem description

I built TileLang from source like this:

cmake -S . -B build -DUSE_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=120 && cmake --build build -j
pip install -e . -v

Both commands were run inside the container built from docker/Dockerfile.cu128.
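
A quick way to double-check that the 120 passed to CMAKE_CUDA_ARCHITECTURES matches what the GPU reports is sketched below. The helper names (arch_from_capability, detected_arch) are made up for illustration; the only library call used is torch.cuda.get_device_capability. This only rules out an architecture mismatch and does not address the compile error itself.

```python
def arch_from_capability(major: int, minor: int) -> str:
    """Format a CUDA compute capability as a CMAKE_CUDA_ARCHITECTURES value."""
    return f"{major}{minor}"


def detected_arch():
    """Return the arch string for GPU 0, or None if torch or CUDA is unavailable."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    major, minor = torch.cuda.get_device_capability(0)
    return arch_from_capability(major, minor)


if __name__ == "__main__":
    built_for = "120"  # value passed to -DCMAKE_CUDA_ARCHITECTURES in the build command
    found = detected_arch()
    print(f"built for: {built_for}, detected: {found}")
    if found is not None and found != built_for:
        print("warning: build architecture does not match the detected GPU")
```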

I'm trying to run this custom script:

bench_indexer_tilelang.py

#!/usr/bin/env python3
import argparse
import torch

# TileLang example kernels for lightning indexer
from examples.deepseek_v32.fp8_lighting_indexer import (
    mqa_attn_return_logits,
    mqa_attn_return_logits_interface,
)

def bench_tl_indexer_wrapper(seq_len: int,
                             seq_len_kv: int,
                             heads: int = 4,
                             index_dim: int = 64,
                             iters: int = 50,
                             warmup: int = 5):
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")

    device = torch.device("cuda")

    # Inputs
    q = torch.randn(seq_len, heads, index_dim, device=device, dtype=torch.float32)
    kv = torch.randn(seq_len_kv, index_dim, device=device, dtype=torch.float32)
    # Precompute kv_scales similar to reference: sqrt(mean(k^2)) along dim=-1
    kv_scales = kv.pow(2).mean(dim=-1).sqrt()
    weights = torch.randn(seq_len, heads, device=device, dtype=torch.float32)
    cu_seqlen_ks = torch.zeros(seq_len, dtype=torch.int32, device=device)
    cu_seqlen_ke = torch.full((seq_len,), seq_len_kv, dtype=torch.int32, device=device)

    # Warmup
    for _ in range(warmup):
        _ = mqa_attn_return_logits_interface(q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke)
    torch.cuda.synchronize()

    # Timed
    times = []
    for _ in range(iters):
        t0 = torch.cuda.Event(enable_timing=True)
        t1 = torch.cuda.Event(enable_timing=True)
        t0.record()
        _ = mqa_attn_return_logits_interface(q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke)
        t1.record()
        t1.synchronize()
        times.append(t0.elapsed_time(t1))  # ms

    avg_ms = sum(times) / len(times) if times else float("nan")
    print(f"[TILELANG_INDEXER] WRAPPER  S={seq_len} SKV={seq_len_kv} H={heads} D={index_dim} avg_ms={avg_ms:.3f} over {iters}")


def bench_tl_indexer_impl(seq_len: int,
                          seq_len_kv: int,
                          heads: int = 4,
                          index_dim: int = 64,
                          iters: int = 50,
                          warmup: int = 5):
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")

    device = torch.device("cuda")

    # Compile kernel once
    kernel = mqa_attn_return_logits(heads=heads, index_dim=index_dim)

    # Inputs
    q = torch.randn(seq_len, heads, index_dim, device=device, dtype=torch.float32)
    kv = torch.randn(seq_len_kv, index_dim, device=device, dtype=torch.float32)
    kv_scales = kv.pow(2).mean(dim=-1).sqrt()
    weights = torch.randn(seq_len, heads, device=device, dtype=torch.float32)
    cu_seqlen_ks = torch.zeros(seq_len, dtype=torch.int32, device=device)
    cu_seqlen_ke = torch.full((seq_len,), seq_len_kv, dtype=torch.int32, device=device)
    logits = torch.empty(seq_len, seq_len_kv, device=device, dtype=torch.float32)

    # Warmup
    for _ in range(warmup):
        kernel(
            q.view(seq_len * heads, index_dim),
            kv,
            kv_scales,
            logits,
            weights,
            cu_seqlen_ks,
            cu_seqlen_ke,
        )
    torch.cuda.synchronize()

    # Timed
    times = []
    for _ in range(iters):
        t0 = torch.cuda.Event(enable_timing=True)
        t1 = torch.cuda.Event(enable_timing=True)
        t0.record()
        kernel(
            q.view(seq_len * heads, index_dim),
            kv,
            kv_scales,
            logits,
            weights,
            cu_seqlen_ks,
            cu_seqlen_ke,
        )
        t1.record()
        t1.synchronize()
        times.append(t0.elapsed_time(t1))  # ms

    avg_ms = sum(times) / len(times) if times else float("nan")
    print(f"[TILELANG_INDEXER]   IMPL  S={seq_len} SKV={seq_len_kv} H={heads} D={index_dim} avg_ms={avg_ms:.3f} over {iters}")


def parse_int_list(s: str):
    vals = []
    for part in s.split(','):
        part = part.strip()
        if not part:
            continue
        vals.append(int(part))
    return vals


def main():
    parser = argparse.ArgumentParser(description="Benchmark TileLang lightning indexer (DeepSeek V3.2)")
    parser.add_argument("--seq-lens", type=parse_int_list, default="4096,16384,163840",
                        help="Comma-separated sequence lengths S (default: 4096,16384,163840)")
    parser.add_argument("--kv-lens", type=parse_int_list, default=None,
                        help="Comma-separated KV lengths SKV; if omitted, uses seq-lens")
    parser.add_argument("--heads", type=int, default=4, help="Indexer heads H (default: 4)")
    parser.add_argument("--dim", type=int, default=64, help="Indexer dimension D (default: 64)")
    parser.add_argument("--iters", type=int, default=50, help="Timed iterations (default: 50)")
    parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations (default: 5)")
    parser.add_argument("--mode", choices=["both", "wrapper", "impl"], default="both",
                        help="Which path to benchmark: wrapper (interface), impl (kernel), or both (default)")
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")

    dev = torch.cuda.get_device_name(0)
    print(f"CUDA device: {dev}")

    seq_lens = args.seq_lens if isinstance(args.seq_lens, list) else parse_int_list(args.seq_lens)
    kv_lens = None
    if args.kv_lens is None:
        kv_lens = seq_lens
    else:
        kv_lens = args.kv_lens if isinstance(args.kv_lens, list) else parse_int_list(args.kv_lens)
        if len(kv_lens) != len(seq_lens):
            raise ValueError("--kv-lens must have the same number of elements as --seq-lens")

    for S, SKV in zip(seq_lens, kv_lens):
        if args.mode in ("both", "wrapper"):
            bench_tl_indexer_wrapper(S, SKV, heads=args.heads, index_dim=args.dim, iters=args.iters, warmup=args.warmup)
        if args.mode in ("both", "impl"):
            bench_tl_indexer_impl(S, SKV, heads=args.heads, index_dim=args.dim, iters=args.iters, warmup=args.warmup)


if __name__ == "__main__":
    main()

I'm getting this error:

2025-11-12 04:33:04  [TileLang:tilelang.env:WARNING]: Loading tilelang libs from dev root: /root/TileLang/build
CUDA device: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
2025-11-12 04:33:06  [TileLang:tilelang.language.v2.builder:WARNING]: Binding a for frame to variable may cause undefined behavior in tilelang.
Stack (most recent call last):
  File "/root/TileLang/bench_indexer_tilelang.py", line 157, in <module>
    main()
  File "/root/TileLang/bench_indexer_tilelang.py", line 151, in main
    bench_tl_indexer_wrapper(S, SKV, heads=args.heads, index_dim=args.dim, iters=args.iters, warmup=args.warmup)
  File "/root/TileLang/bench_indexer_tilelang.py", line 33, in bench_tl_indexer_wrapper
    _ = mqa_attn_return_logits_interface(q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke)
  File "/root/TileLang/examples/deepseek_v32/fp8_lighting_indexer.py", line 224, in mqa_attn_return_logits_interface
    clean_logits_kernel = clean_logits_()
  File "/root/TileLang/tilelang/jit/__init__.py", line 273, in __call__
    self._kernel_cache[key] = self.compile(*args, **kwargs, **tune_params)
  File "/root/TileLang/tilelang/jit/__init__.py", line 223, in compile
    func = self.get_tir(*args, **kwargs)
  File "/root/TileLang/tilelang/jit/__init__.py", line 192, in get_tir
    program_result = program_result_source(*args, **kwargs)
  File "/root/TileLang/examples/deepseek_v32/fp8_lighting_indexer.py", line 193, in clean_logits_
    def clean_logits_kernel(
  File "/root/TileLang/tilelang/language/v2/builder.py", line 720, in prim_func
    return impl(func) if func is not None else impl
  File "/root/TileLang/tilelang/language/v2/builder.py", line 707, in impl
    return prim_func_generator(**annot)
  File "/root/TileLang/tilelang/language/v2/builder.py", line 692, in prim_func_generator
    ir_gen.gen(builder)(*args, **kwargs)
  File "/root/TileLang/examples/deepseek_v32/fp8_lighting_indexer.py", line 199, in clean_logits_kernel
    tx = T.thread_binding(0, threads, thread="threadIdx.x")
  File "/root/TileLang/tilelang/language/v2/builder.py", line 329, in bind
    res = self.bind_immutable(name, value)
2025-11-12 04:33:06  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `mqa_attn_return_logits_kernel` with `out_idx=None`
[04:33:06] /root/TileLang/src/transform/legalize_negative_index.cc:55: Warning: LegalizeNegativeIndex: cannot prove non-negative index nbn_i * 256 + cu_k_s_min[0] for buffer IndexK (axis 0).
[04:33:06] /root/TileLang/src/transform/legalize_negative_index.cc:55: Warning: LegalizeNegativeIndex: cannot prove non-negative index nbn_i * 256 + cu_k_s_min[0] for buffer IndexKScale (axis 0).
/root/TileLang/3rdparty/../src/tl_templates/cuda/gemm_mma.h(289): error: incomplete type "cute::tl_mma::DispatchInstruction<cute::tl_mma::GemmTensorOp<256, 128, 64, 1, 16, false, true, true, 64, 64, 0, 0, fp8_e4_t, fp8_e4_t, float>::A_type, cute::tl_mma::GemmTensorOp<256, 128, 64, 1, 16, false, true, true, 64, 64, 0, 0, fp8_e4_t, fp8_e4_t, float>::B_type, cute::tl_mma::GemmTensorOp<256, 128, 64, 1, 16, false, true, true, 64, 64, 0, 0, fp8_e4_t, fp8_e4_t, float>::C_type, 1, 16, 128>" (aka "cute::tl_mma::DispatchInstruction<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float, 1, 16, 128>") is not allowed
    using TileMma = TiledMMA<typename Instruction::MMA,
                                      ^
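
The nvcc error reports that DispatchInstruction is an incomplete type for this argument combination, which in C++ normally means the template is declared but no matching specialization is defined. To make the failing combination easier to see, here is a small sketch that pulls the simplified "aka" type out of the error line and splits its template arguments. The function names are hypothetical and the parsing assumes the message format shown above; it is not part of TileLang.

```python
import re


def aka_type(nvcc_line: str):
    """Return the simplified type from an '(aka "...")' clause, or None."""
    match = re.search(r'\(aka "([^"]+)"\)', nvcc_line)
    return match.group(1) if match else None


def template_args(type_name: str):
    """Split the outermost template argument list, respecting nested <...>."""
    start = type_name.find("<")
    end = type_name.rfind(">")
    if start == -1 or end <= start:
        return []
    args, depth, current = [], 0, []
    for ch in type_name[start + 1:end]:
        if ch == "<":
            depth += 1
        elif ch == ">":
            depth -= 1
        if ch == "," and depth == 0:
            # A top-level comma ends the current argument.
            args.append("".join(current).strip())
            current = []
        else:
            current.append(ch)
    if current:
        args.append("".join(current).strip())
    return args


if __name__ == "__main__":
    # Shortened copy of the error line from the log above.
    line = (
        'gemm_mma.h(289): error: incomplete type "..." '
        '(aka "cute::tl_mma::DispatchInstruction<cutlass::float_e4m3_t, '
        'cutlass::float_e4m3_t, float, 1, 16, 128>") is not allowed'
    )
    print(template_args(aka_type(line)))
```

Searching gemm_mma.h for DispatchInstruction specializations and comparing them against these arguments should show which combination is missing for this build.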


Labels

bug: Something isn't working
