26 changes: 1 addition & 25 deletions examples/flash_decoding/example_gqa_decode_varlen_logits.py
@@ -5,6 +5,7 @@
import argparse
import tilelang
import tilelang.language as T
+from tilelang.profiler import do_bench
⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check the signature of do_bench in tilelang/profiler

echo "=== Looking for do_bench function definition ==="
rg -n "def do_bench" --type py -A 20

echo ""
echo "=== Search for tilelang/profiler module ==="
find . -name "profiler.py" -o -name "profiler" -type d | head -20

Repository: tile-ai/tilelang

Length of output: 3149


🏁 Script executed:

#!/bin/bash
# View the do_bench calls in the file

echo "=== Calls to do_bench around lines 685-699 ==="
sed -n '680,700p' examples/flash_decoding/example_gqa_decode_varlen_logits.py

echo ""
echo "=== Calls to do_bench around lines 704-716 ==="
sed -n '700,720p' examples/flash_decoding/example_gqa_decode_varlen_logits.py

Repository: tile-ai/tilelang

Length of output: 1305


Critical: do_bench calls will fail due to signature incompatibility.

The imported tilelang.profiler.do_bench accepts only fn plus keyword arguments (warmup, rep, _n_warmup, _n_repeat, quantiles, fast_flush, backend, return_mode); it does not support the *args calling convention used in this file.

The calls at lines 687-700 and 709-720 pass 11-12 positional arguments after the function:

do_bench(flash_attn_with_attn_pool_decode_tilelang, q_decode, k_varlen, v_varlen, cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale, sink, block_size, False, tl_kernel)

This will raise a TypeError at runtime, since do_bench accepts only one positional argument and invokes it with no arguments. Wrap each call with functools.partial (or a lambda) so do_bench receives a zero-argument callable.
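
A minimal sketch of the suggested wrapping, assuming the variables from the quoted call are in scope; the warmup/rep values are illustrative placeholders:

import functools
from tilelang.profiler import do_bench

# Bind every benchmark input up front so do_bench receives a zero-argument callable.
bench_fn = functools.partial(
    flash_attn_with_attn_pool_decode_tilelang,
    q_decode, k_varlen, v_varlen, cu_seqlens_k, max_seqlen_k,
    args.k_seqlen, 1, softmax_scale, sink, block_size, False, tl_kernel,
)
# Timing keywords are illustrative; keep whatever values the file already uses.
latency_ms = do_bench(bench_fn, warmup=10, rep=10)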

🤖 Prompt for AI Agents
In `@examples/flash_decoding/example_gqa_decode_varlen_logits.py` at line 8,
do_bench is being called with positional arguments that it doesn't accept; wrap
the benchmark target and its positional args using functools.partial (or a
lambda) so do_bench receives a single callable and only keyword args itself.
Locate the do_bench calls that pass flash_attn_with_attn_pool_decode_tilelang
(and similarly any other flash_attn* benchmarks) and change them to
do_bench(functools.partial(flash_attn_with_attn_pool_decode_tilelang, q_decode,
k_varlen, v_varlen, cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale,
sink, block_size, False, tl_kernel), warmup=..., rep=..., ...) or equivalent,
ensuring you import functools.partial and preserve the existing do_bench keyword
parameters.


torch.manual_seed(0)

@@ -617,31 +618,6 @@ def test_varlen_decode_main(args):
print("✅ All tests passed!")


-def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
-    """
-    Do benchmark for a function.
-    """
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    for _ in range(warmup):
-        fn(*args, **kwargs)
-
-    torch.cuda.synchronize()
-    for i in range(rep):
-        start_event[i].record()
-        fn(*args, **kwargs)
-        end_event[i].record()
-    torch.cuda.synchronize()
-
-    # Record clocks
-    times = torch.tensor(
-        [s.elapsed_time(e) for s, e in zip(start_event, end_event)],
-        dtype=torch.float,
-    )
-
-    return times.mean().item()


def speed_benchmark_decode_comparison(args):
    """Speed benchmark for decode kernel"""
    batch_size = args.batch_size
26 changes: 1 addition & 25 deletions examples/gdn/example_chunk_delta_bwd.py
@@ -4,6 +4,7 @@

import tilelang
import tilelang.language as T
+from tilelang.profiler import do_bench
⚠️ Potential issue | 🔴 Critical

Critical: do_bench call signatures are incompatible with the centralized function.

The import change introduces a breaking issue. At lines 480-481, do_bench is called with tensor arguments:

fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv)

The centralized do_bench from tilelang.profiler.bench expects fn followed by timing keywords (warmup, rep, _n_warmup, _n_repeat, ...), not input tensors. The tensors would be misinterpreted as timing parameters.

The calls should wrap the function and its arguments in a lambda or functools.partial:

🐛 Proposed fix
-    fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
-    tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv)
+    fla_time = do_bench(lambda: chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size))
+    tilelang_time = do_bench(lambda: kernel(Q, K, W, G, h0, dht, dO, dv))
🤖 Prompt for AI Agents
In `@examples/gdn/example_chunk_delta_bwd.py` at line 7, the do_bench import now
points to the centralized profiler which expects a callable plus timing params,
but the current calls pass tensor args directly (see do_bench calls with
chunk_gated_delta_rule_bwd_dhu and kernel and tensors Q, K, W, G, h0, dht, dO,
dv, scale, chunk_size), causing the tensors to be parsed as timing parameters;
fix by wrapping the target function and its arguments into a zero-arg callable
(e.g., use a lambda or functools.partial) so do_bench receives a single callable
and pass chunk_size as a keyword inside that wrapper or via partial, e.g. wrap
chunk_gated_delta_rule_bwd_dhu with its tensors and chunk_size and similarly
wrap kernel before calling do_bench.


print(tilelang.__file__, flush=True)

@@ -544,31 +545,6 @@ def run_test(
    assert_similar(dv2_ref_torch, dv2_tilelang, 1e-5, "torch-tilelang", data="dv2")


-def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
-    """
-    Do benchmark for a function.
-    """
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    for _ in range(warmup):
-        fn(*args, **kwargs)
-
-    torch.cuda.synchronize()
-    for i in range(rep):
-        start_event[i].record()
-        fn(*args, **kwargs)
-        end_event[i].record()
-    torch.cuda.synchronize()
-
-    # Record clocks
-    times = torch.tensor(
-        [s.elapsed_time(e) for s, e in zip(start_event, end_event)],
-        dtype=torch.float,
-    )
-
-    return times.mean().item()


def main():
    DK = 128
    run_test(
26 changes: 1 addition & 25 deletions examples/gdn/example_chunk_delta_h.py
@@ -4,6 +4,7 @@
import tilelang
import tilelang.language as T
from tilelang.autotuner import autotune
+from tilelang.profiler import do_bench
⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, find the profiler module and examine do_bench signature
find . -type f -name "bench.py" | grep -E "profiler|tilelang"

Repository: tile-ai/tilelang

Length of output: 88


🏁 Script executed:

#!/bin/bash
# Explore the repository structure
ls -la

Repository: tile-ai/tilelang

Length of output: 2689


🏁 Script executed:

#!/bin/bash
# Check if tilelang/profiler exists and what's in it
find . -type d -name "profiler" 2>/dev/null | head -5

Repository: tile-ai/tilelang

Length of output: 105


🏁 Script executed:

#!/bin/bash
# Find all Python files in tilelang directory
find . -path "./tilelang/profiler*" -type f -name "*.py" 2>/dev/null

Repository: tile-ai/tilelang

Length of output: 120


🏁 Script executed:

#!/bin/bash
# Examine the example file to see the actual do_bench usage
if [ -f "examples/gdn/example_chunk_delta_h.py" ]; then
  echo "=== Lines 300-320 of example_chunk_delta_h.py ==="
  sed -n '300,320p' examples/gdn/example_chunk_delta_h.py
fi

Repository: tile-ai/tilelang

Length of output: 830


🏁 Script executed:

#!/bin/bash
# Examine the do_bench function signature in bench.py
cat -n ./tilelang/profiler/bench.py | head -50

Repository: tile-ai/tilelang

Length of output: 2029


🏁 Script executed:

#!/bin/bash
# Get the full function signature
ast-grep --pattern $'def do_bench($$$):
    $$$'

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

#!/bin/bash
# Continue reading bench.py to find do_bench function
cat -n ./tilelang/profiler/bench.py | tail -n +50 | head -100

Repository: tile-ai/tilelang

Length of output: 4530


🏁 Script executed:

#!/bin/bash
# Search for the do_bench function definition
rg "def do_bench" -A 10 ./tilelang/profiler/bench.py

Repository: tile-ai/tilelang

Length of output: 399


🏁 Script executed:

#!/bin/bash
# Check git history to see the previous version
git log --oneline -5 HEAD

Repository: tile-ai/tilelang

Length of output: 151


🏁 Script executed:

#!/bin/bash
# Get the previous commit and show the file from before
git show HEAD~1:examples/gdn/example_chunk_delta_h.py 2>/dev/null | head -320 | tail -50

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

#!/bin/bash
# Check the diff of the current file in HEAD
git diff HEAD~1 examples/gdn/example_chunk_delta_h.py 2>/dev/null | head -100

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

#!/bin/bash
# Check git log to understand the history better
git log --oneline examples/gdn/example_chunk_delta_h.py 2>/dev/null | head -10

Repository: tile-ai/tilelang

Length of output: 151


🏁 Script executed:

#!/bin/bash
# Check what files were changed in the HEAD commit
git show --stat HEAD

Repository: tile-ai/tilelang

Length of output: 50372


🏁 Script executed:

#!/bin/bash
# Show the diff of the specific file
git show HEAD:examples/gdn/example_chunk_delta_h.py 2>/dev/null | sed -n '1,20p'

Repository: tile-ai/tilelang

Length of output: 669


🏁 Script executed:

#!/bin/bash
# Check the previous commit for that file
git show HEAD~1:examples/gdn/example_chunk_delta_h.py 2>/dev/null | sed -n '1,320p' | tail -50

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

#!/bin/bash
# Check what the previous version of the file looked like (look further back in history)
git log --all --pretty=oneline examples/gdn/example_chunk_delta_h.py 2>/dev/null | head -20

Repository: tile-ai/tilelang

Length of output: 1134


🏁 Script executed:

#!/bin/bash
# Look at the previous commit to see if there was a local do_bench definition
git show 0f980f15c575bf35db73a70fc04a8a53c005b2c8:examples/gdn/example_chunk_delta_h.py 2>/dev/null | sed -n '1,50p'

Repository: tile-ai/tilelang

Length of output: 1466


🏁 Script executed:

#!/bin/bash
# Check if the previous version had a local do_bench function
git show 0f980f15c575bf35db73a70fc04a8a53c005b2c8:examples/gdn/example_chunk_delta_h.py 2>/dev/null | grep -A 30 "def do_bench"

Repository: tile-ai/tilelang

Length of output: 778


Fix do_bench call sites that are incompatible with the centralized function.

The import change is correct, but the calls at lines 302-313 are incompatible with the centralized do_bench signature.

The previous local implementation accepted do_bench(fn, *args, warmup=10, rep=10, **kwargs) and passed *args and **kwargs to the function. The centralized version from tilelang.profiler.bench has signature do_bench(fn, warmup, rep, _n_warmup, _n_repeat, quantiles, fast_flush, backend, return_mode) and calls fn() with no arguments.

Lines 302-312 call do_bench(chunk_gated_delta_rule_fwd_h, k=K, w=W, u=U, ...) with keyword arguments that don't exist in the new signature, causing a TypeError for unexpected keyword arguments.

Line 313 calls do_bench(kernel, K, W, U, G, initial_state) with tensor values as positional arguments, which will be misinterpreted as warmup, rep, _n_warmup, _n_repeat, and quantiles parameters with type mismatches.

Both call sites need to wrap their invocations in zero-argument callables for the centralized do_bench interface, as sketched below.
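
This sketch assumes the tensors from the call sites above (K, W, U, G, initial_state) are in scope; the kwarg names beyond k/w/u and the warmup/rep values are assumptions:

from functools import partial
from tilelang.profiler import do_bench

# Keyword-style call site: capture the kwargs inside the wrapper, not in do_bench.
# Kwarg names beyond k/w/u are assumed from context.
ref_time = do_bench(
    lambda: chunk_gated_delta_rule_fwd_h(k=K, w=W, u=U, g=G, initial_state=initial_state),
    warmup=10,
    rep=10,
)

# Positional-style call site: partial binds the tensors so fn() needs no arguments.
tilelang_time = do_bench(partial(kernel, K, W, U, G, initial_state), warmup=10, rep=10)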

🤖 Prompt for AI Agents
In `@examples/gdn/example_chunk_delta_h.py` at line 7, the calls to do_bench must
be adapted to the centralized signature that invokes fn() with no args: wrap the
target functions (chunk_gated_delta_rule_fwd_h and kernel) into zero-argument
callables (e.g., lambda or functools.partial) that capture K, W, U, G,
initial_state and any other inputs, and then call do_bench with explicit
benchmarking parameters (warmup, rep, _n_warmup, _n_repeat, quantiles,
fast_flush, backend, return_mode) rather than passing tensors as
positional/keyword args; update the two sites where do_bench is invoked so they
pass a zero-arg wrapper and appropriate numeric/flag values for the benchmark
options.


# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
@@ -224,31 +225,6 @@ def kernel(
    return kernel


-def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
-    """
-    Do benchmark for a function.
-    """
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    for _ in range(warmup):
-        fn(*args, **kwargs)
-
-    torch.cuda.synchronize()
-    for i in range(rep):
-        start_event[i].record()
-        fn(*args, **kwargs)
-        end_event[i].record()
-    torch.cuda.synchronize()
-
-    # Record clocks
-    times = torch.tensor(
-        [s.elapsed_time(e) for s, e in zip(start_event, end_event)],
-        dtype=torch.float,
-    )
-
-    return times.mean().item()


def run_test(
    B,
    S,
25 changes: 0 additions & 25 deletions examples/gdn/example_chunk_o_bwd.py
@@ -359,31 +359,6 @@ def kernel(
    return kernel


-def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
-    """
-    Do benchmark for a function.
-    """
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    for _ in range(warmup):
-        fn(*args, **kwargs)
-
-    torch.cuda.synchronize()
-    for i in range(rep):
-        start_event[i].record()
-        fn(*args, **kwargs)
-        end_event[i].record()
-    torch.cuda.synchronize()
-
-    # Record clocks
-    times = torch.tensor(
-        [s.elapsed_time(e) for s, e in zip(start_event, end_event)],
-        dtype=torch.float,
-    )
-
-    return times.mean().item()


def run_test(
    B,
    S,
2 changes: 1 addition & 1 deletion examples/gemm_sp/example_custom_compress.py
@@ -7,7 +7,7 @@
from tilelang.utils.sparse import randn_semi_sparse
from tilelang.utils.tensor import torch_assert_close

-from triton.testing import do_bench
+from tilelang.profiler import do_bench

import torch

2 changes: 1 addition & 1 deletion examples/gemm_sp/example_gemm_sp.py
@@ -6,7 +6,7 @@
from tilelang.layout import make_cutlass_metadata_layout
from tilelang.utils.sparse import compress, randn_semi_sparse
from tilelang.contrib import nvcc
-from triton.testing import do_bench
+from tilelang.profiler import do_bench

import torch
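
The two gemm_sp files, by contrast, appear to need no call-site changes: triton.testing.do_bench and the centralized tilelang.profiler.do_bench share the fn-plus-keywords calling shape, so the import swap is drop-in wherever the call already passes a zero-argument callable. A sketch of the shared pattern, with illustrative kernel and tensor names:

from tilelang.profiler import do_bench

# The same shape triton.testing.do_bench accepts: a zero-arg callable plus keywords.
ms = do_bench(lambda: matmul_sp_kernel(a_sparse, e_metadata, b), warmup=25, rep=100)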

67 changes: 15 additions & 52 deletions tilelang/profiler/__init__.py
@@ -4,9 +4,7 @@
from typing import Callable, Any, Literal
from functools import partial
import torch
-from contextlib import suppress
from dataclasses import dataclass
-import tvm
from tilelang.utils.tensor import (
    get_tensor_supply,
    TensorSupplyType,
@@ -191,21 +189,6 @@ def run_once(self, func: Callable | None = None):
        func = self.__call__
        return func(*ins)

-    def determine_profiler(self, func: Callable | None = None):
-        """Determines which profiler backend to use based on function type.
-
-        Args:
-            func: Function to be profiled
-            profiler: Explicitly specified profiler type or "auto" for automatic detection
-
-        Returns:
-            str: The determined profiler type ("torch" or "tvm")
-        """
-        if isinstance(func, tvm.runtime.Module):
-            return "tvm"
-        else:
-            return "torch"

    def do_bench(
        self,
        func: Callable | None = None,
@@ -232,41 +215,21 @@ def do_bench(
        Returns:
            float: Average execution time in milliseconds
        """
-        profiler = self.determine_profiler(func)
-        if profiler == "torch":
-            if func is None:
-                assert self.adapter is not None, "benchmarking function should be provided"
-                func = self.adapter
-            ins = self._get_inputs() if input_tensors is None else input_tensors
-            bench_func = partial(func, *ins)
-            return do_bench(
-                bench_func,
-                warmup=warmup,
-                rep=rep,
-                _n_warmup=n_warmup,
-                _n_repeat=n_repeat,
-                quantiles=quantiles,
-                backend=backend,
-                return_mode=return_mode,
-            )
-        elif profiler == "tvm":
-            assert func is not None, "func should not be None"
-            assert isinstance(func, tvm.runtime.Module), f"func should be a TVM module, but got {type(func)}"
-
-            ins = self._get_inputs(with_output=True) if input_tensors is None else input_tensors
-            target = "cuda"
-
-            with suppress(Exception):
-                target = self.mod.imported_modules[0].type_key
-
-            assert target in ["cuda", "hip"], f"Unknown target: {target}"
-
-            device = tvm.cuda(0) if target == "cuda" else tvm.rocm(0)
-            time_evaluator = self.mod.time_evaluator(self.mod.entry_name, device, number=rep, repeat=n_repeat)
-            # Transform Latency to ms
-            return time_evaluator(*ins).mean * 1e3
-        else:
-            raise ValueError(f"Unknown profiler: {profiler}")
+        if func is None:
+            assert self.adapter is not None, "benchmarking function should be provided"
+            func = self.adapter
+        ins = self._get_inputs() if input_tensors is None else input_tensors
+        bench_func = partial(func, *ins)
+        return do_bench(
+            bench_func,
+            warmup=warmup,
+            rep=rep,
+            _n_warmup=n_warmup,
+            _n_repeat=n_repeat,
+            quantiles=quantiles,
+            backend=backend,
+            return_mode=return_mode,
+        )

    @property
    def func(self):
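
For context, a hedged sketch of how the now torch-only Profiler.do_bench path is typically exercised; the get_profiler() accessor and the numbers are illustrative assumptions, not code from this PR:

# Illustrative usage; assumes a compiled tilelang kernel exposing get_profiler().
profiler = jit_kernel.get_profiler()
latency_ms = profiler.do_bench(warmup=25, rep=100)  # func=None falls back to self.adapter
print(f"avg latency: {latency_ms:.3f} ms")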