Description
Required prerequisites
- I have read the documentation https://tilelang.com.
- I have searched the Issue Tracker and confirmed this hasn't already been reported. (Comment there if it has.)
What version of TileLang are you using?
0.1.6.post2+cu128.git9eaa708f
System information
OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: version 4.1.2
Libc version: glibc-2.39
Python version: 3.10.19 (main, Oct 21 2025, 16:43:05) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-87-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.8.61
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
Nvidia driver version: 570.124.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.7.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9355 32-Core Processor
CPU family: 26
Model: 2
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
Stepping: 1
Frequency boost: enabled
CPU(s) scaling MHz: 100%
CPU max MHz: 3550.0000
CPU min MHz: 1500.0000
BogoMIPS: 7099.85
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx_vnni avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc amd_ibpb_ret arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect movdiri movdir64b overflow_recov succor smca fsrm avx512_vp2intersect flush_l1d debug_swap
Virtualization: AMD-V
L1d cache: 3 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 64 MiB (64 instances)
L3 cache: 512 MiB (16 instances)
NUMA node(s): 8
NUMA node0 CPU(s): 0-7,64-71
NUMA node1 CPU(s): 8-15,72-79
NUMA node2 CPU(s): 16-23,80-87
NUMA node3 CPU(s): 24-31,88-95
NUMA node4 CPU(s): 32-39,96-103
NUMA node5 CPU(s): 40-47,104-111
NUMA node6 CPU(s): 48-55,112-119
NUMA node7 CPU(s): 56-63,120-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Vulnerability Vmscape: Not affected
Versions of relevant libraries:
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] torch==2.9.0
[pip3] triton==3.5.0
[conda] numpy 2.2.6 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
[conda] torch 2.9.0 pypi_0 pypi
[conda] triton 3.5.0 pypi_0 pypi
Problem description
I built tilelang like this:
cmake -S . -B build -DUSE_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=120 && cmake --build build -j
pip install -e . -v
inside the container built from docker/Dockerfile.cu128.
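For reference (not part of the original build steps), a minimal sanity check that the runtime GPU matches the CMAKE_CUDA_ARCHITECTURES=120 flag; the expected values in the comments are assumptions based on the system information above:

# Sanity-check sketch: confirm the device the build flag targets.
import torch

print(torch.version.cuda)                    # CUDA the torch wheel targets (12.8 here)
print(torch.cuda.get_device_name(0))         # NVIDIA RTX PRO 6000 Blackwell Workstation Edition
print(torch.cuda.get_device_capability(0))   # should correspond to CMAKE_CUDA_ARCHITECTURES=120, i.e. (12, 0)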
I'm trying to run this custom script:
bench_indexer_tilelang.py
#!/usr/bin/env python3
import argparse

import torch

# TileLang example kernels for lightning indexer
from examples.deepseek_v32.fp8_lighting_indexer import (
    mqa_attn_return_logits,
    mqa_attn_return_logits_interface,
)


def bench_tl_indexer_wrapper(seq_len: int,
                             seq_len_kv: int,
                             heads: int = 4,
                             index_dim: int = 64,
                             iters: int = 50,
                             warmup: int = 5):
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    device = torch.device("cuda")

    # Inputs
    q = torch.randn(seq_len, heads, index_dim, device=device, dtype=torch.float32)
    kv = torch.randn(seq_len_kv, index_dim, device=device, dtype=torch.float32)
    # Precompute kv_scales similar to reference: sqrt(mean(k^2)) along dim=-1
    kv_scales = kv.pow(2).mean(dim=-1).sqrt()
    weights = torch.randn(seq_len, heads, device=device, dtype=torch.float32)
    cu_seqlen_ks = torch.zeros(seq_len, dtype=torch.int32, device=device)
    cu_seqlen_ke = torch.full((seq_len,), seq_len_kv, dtype=torch.int32, device=device)

    # Warmup
    for _ in range(warmup):
        _ = mqa_attn_return_logits_interface(q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke)
    torch.cuda.synchronize()

    # Timed
    times = []
    for _ in range(iters):
        t0 = torch.cuda.Event(enable_timing=True)
        t1 = torch.cuda.Event(enable_timing=True)
        t0.record()
        _ = mqa_attn_return_logits_interface(q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke)
        t1.record()
        t1.synchronize()
        times.append(t0.elapsed_time(t1))  # ms
    avg_ms = sum(times) / len(times) if times else float("nan")
    print(f"[TILELANG_INDEXER] WRAPPER S={seq_len} SKV={seq_len_kv} H={heads} D={index_dim} avg_ms={avg_ms:.3f} over {iters}")


def bench_tl_indexer_impl(seq_len: int,
                          seq_len_kv: int,
                          heads: int = 4,
                          index_dim: int = 64,
                          iters: int = 50,
                          warmup: int = 5):
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    device = torch.device("cuda")

    # Compile kernel once
    kernel = mqa_attn_return_logits(heads=heads, index_dim=index_dim)

    # Inputs
    q = torch.randn(seq_len, heads, index_dim, device=device, dtype=torch.float32)
    kv = torch.randn(seq_len_kv, index_dim, device=device, dtype=torch.float32)
    kv_scales = kv.pow(2).mean(dim=-1).sqrt()
    weights = torch.randn(seq_len, heads, device=device, dtype=torch.float32)
    cu_seqlen_ks = torch.zeros(seq_len, dtype=torch.int32, device=device)
    cu_seqlen_ke = torch.full((seq_len,), seq_len_kv, dtype=torch.int32, device=device)
    logits = torch.empty(seq_len, seq_len_kv, device=device, dtype=torch.float32)

    # Warmup
    for _ in range(warmup):
        kernel(
            q.view(seq_len * heads, index_dim),
            kv,
            kv_scales,
            logits,
            weights,
            cu_seqlen_ks,
            cu_seqlen_ke,
        )
    torch.cuda.synchronize()

    # Timed
    times = []
    for _ in range(iters):
        t0 = torch.cuda.Event(enable_timing=True)
        t1 = torch.cuda.Event(enable_timing=True)
        t0.record()
        kernel(
            q.view(seq_len * heads, index_dim),
            kv,
            kv_scales,
            logits,
            weights,
            cu_seqlen_ks,
            cu_seqlen_ke,
        )
        t1.record()
        t1.synchronize()
        times.append(t0.elapsed_time(t1))  # ms
    avg_ms = sum(times) / len(times) if times else float("nan")
    print(f"[TILELANG_INDEXER] IMPL S={seq_len} SKV={seq_len_kv} H={heads} D={index_dim} avg_ms={avg_ms:.3f} over {iters}")


def parse_int_list(s: str):
    vals = []
    for part in s.split(','):
        part = part.strip()
        if not part:
            continue
        vals.append(int(part))
    return vals


def main():
    parser = argparse.ArgumentParser(description="Benchmark TileLang lightning indexer (DeepSeek V3.2)")
    parser.add_argument("--seq-lens", type=parse_int_list, default="4096,16384,163840",
                        help="Comma-separated sequence lengths S (default: 4096,16384,163840)")
    parser.add_argument("--kv-lens", type=parse_int_list, default=None,
                        help="Comma-separated KV lengths SKV; if omitted, uses seq-lens")
    parser.add_argument("--heads", type=int, default=4, help="Indexer heads H (default: 4)")
    parser.add_argument("--dim", type=int, default=64, help="Indexer dimension D (default: 64)")
    parser.add_argument("--iters", type=int, default=50, help="Timed iterations (default: 50)")
    parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations (default: 5)")
    parser.add_argument("--mode", choices=["both", "wrapper", "impl"], default="both",
                        help="Which path to benchmark: wrapper (interface), impl (kernel), or both (default)")
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    dev = torch.cuda.get_device_name(0)
    print(f"CUDA device: {dev}")

    seq_lens = args.seq_lens if isinstance(args.seq_lens, list) else parse_int_list(args.seq_lens)
    kv_lens = None
    if args.kv_lens is None:
        kv_lens = seq_lens
    else:
        kv_lens = args.kv_lens if isinstance(args.kv_lens, list) else parse_int_list(args.kv_lens)
    if len(kv_lens) != len(seq_lens):
        raise ValueError("--kv-lens must have the same number of elements as --seq-lens")

    for S, SKV in zip(seq_lens, kv_lens):
        if args.mode in ("both", "wrapper"):
            bench_tl_indexer_wrapper(S, SKV, heads=args.heads, index_dim=args.dim, iters=args.iters, warmup=args.warmup)
        if args.mode in ("both", "impl"):
            bench_tl_indexer_impl(S, SKV, heads=args.heads, index_dim=args.dim, iters=args.iters, warmup=args.warmup)


if __name__ == "__main__":
    main()

I'm getting this error:
2025-11-12 04:33:04 [TileLang:tilelang.env:WARNING]: Loading tilelang libs from dev root: /root/TileLang/build
CUDA device: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
2025-11-12 04:33:06 [TileLang:tilelang.language.v2.builder:WARNING]: Binding a for frame to variable may cause undefined behavior in tilelang.
Stack (most recent call last):
File "/root/TileLang/bench_indexer_tilelang.py", line 157, in <module>
main()
File "/root/TileLang/bench_indexer_tilelang.py", line 151, in main
bench_tl_indexer_wrapper(S, SKV, heads=args.heads, index_dim=args.dim, iters=args.iters, warmup=args.warmup)
File "/root/TileLang/bench_indexer_tilelang.py", line 33, in bench_tl_indexer_wrapper
_ = mqa_attn_return_logits_interface(q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke)
File "/root/TileLang/examples/deepseek_v32/fp8_lighting_indexer.py", line 224, in mqa_attn_return_logits_interface
clean_logits_kernel = clean_logits_()
File "/root/TileLang/tilelang/jit/__init__.py", line 273, in __call__
self._kernel_cache[key] = self.compile(*args, **kwargs, **tune_params)
File "/root/TileLang/tilelang/jit/__init__.py", line 223, in compile
func = self.get_tir(*args, **kwargs)
File "/root/TileLang/tilelang/jit/__init__.py", line 192, in get_tir
program_result = program_result_source(*args, **kwargs)
File "/root/TileLang/examples/deepseek_v32/fp8_lighting_indexer.py", line 193, in clean_logits_
def clean_logits_kernel(
File "/root/TileLang/tilelang/language/v2/builder.py", line 720, in prim_func
return impl(func) if func is not None else impl
File "/root/TileLang/tilelang/language/v2/builder.py", line 707, in impl
return prim_func_generator(**annot)
File "/root/TileLang/tilelang/language/v2/builder.py", line 692, in prim_func_generator
ir_gen.gen(builder)(*args, **kwargs)
File "/root/TileLang/examples/deepseek_v32/fp8_lighting_indexer.py", line 199, in clean_logits_kernel
tx = T.thread_binding(0, threads, thread="threadIdx.x")
File "/root/TileLang/tilelang/language/v2/builder.py", line 329, in bind
res = self.bind_immutable(name, value)
2025-11-12 04:33:06 [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `mqa_attn_return_logits_kernel` with `out_idx=None`
[04:33:06] /root/TileLang/src/transform/legalize_negative_index.cc:55: Warning: LegalizeNegativeIndex: cannot prove non-negative index nbn_i * 256 + cu_k_s_min[0] for buffer IndexK (axis 0).
[04:33:06] /root/TileLang/src/transform/legalize_negative_index.cc:55: Warning: LegalizeNegativeIndex: cannot prove non-negative index nbn_i * 256 + cu_k_s_min[0] for buffer IndexKScale (axis 0).
/root/TileLang/3rdparty/../src/tl_templates/cuda/gemm_mma.h(289): error: incomplete type "cute::tl_mma::DispatchInstruction<cute::tl_mma::GemmTensorOp<256, 128, 64, 1, 16, false, true, true, 64, 64, 0, 0, fp8_e4_t, fp8_e4_t, float>::A_type, cute::tl_mma::GemmTensorOp<256, 128, 64, 1, 16, false, true, true, 64, 64, 0, 0, fp8_e4_t, fp8_e4_t, float>::B_type, cute::tl_mma::GemmTensorOp<256, 128, 64, 1, 16, false, true, true, 64, 64, 0, 0, fp8_e4_t, fp8_e4_t, float>::C_type, 1, 16, 128>" (aka "cute::tl_mma::DispatchInstruction<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float, 1, 16, 128>") is not allowed
using TileMma = TiledMMA<typename Instruction::MMA,
^
Reproducible example code
The Python snippet and the full traceback are included in the problem description above.
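For a quicker repro than the full benchmark, a stripped-down sketch (shapes are illustrative; it reuses the same mqa_attn_return_logits_interface call as the script above) that hits the same failing JIT compile:

# Minimal sketch, not the original script: triggers only the failing call.
import torch
from examples.deepseek_v32.fp8_lighting_indexer import mqa_attn_return_logits_interface

S, SKV, H, D = 4096, 4096, 4, 64  # illustrative shapes
q = torch.randn(S, H, D, device="cuda")
kv = torch.randn(SKV, D, device="cuda")
kv_scales = kv.pow(2).mean(dim=-1).sqrt()
weights = torch.randn(S, H, device="cuda")
ks = torch.zeros(S, dtype=torch.int32, device="cuda")
ke = torch.full((S,), SKV, dtype=torch.int32, device="cuda")

# Fails while JIT-compiling the fp8 GEMM kernel (see traceback above).
mqa_attn_return_logits_interface(q, kv, kv_scales, weights, ks, ke)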
Expected behavior
No response
Additional context
No response