-
Notifications
You must be signed in to change notification settings - Fork 449
Refactor: Use centralized do_bench from tilelang.profiler #1670
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…mport - Removed local implementations of the `do_bench` function from multiple example files. - Updated imports to use the centralized `do_bench` function from `tilelang.profiler`, promoting code reuse and consistency across examples.
…nd __init__.py - Eliminated the unused `do_bench` import from `example_chunk_o_bwd.py` and `tilelang/profiler/__init__.py`, streamlining the code and improving clarity.
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
📝 WalkthroughWalkthroughThe PR centralizes benchmarking by removing local Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@examples/flash_decoding/example_gqa_decode_varlen_logits.py`:
- Line 8: do_bench is being called with positional arguments that it doesn't
accept; wrap the benchmark target and its positional args using
functools.partial (or a lambda) so do_bench receives a single callable and only
keyword args itself. Locate the do_bench calls that pass
flash_attn_with_attn_pool_decode_tilelang (and similarly any other flash_attn*
benchmarks) and change them to
do_bench(functools.partial(flash_attn_with_attn_pool_decode_tilelang, q_decode,
k_varlen, v_varlen, cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale,
sink, block_size, False, tl_kernel), warmup=..., rep=..., ...) or equivalent,
ensuring you import functools.partial and preserve the existing do_bench keyword
parameters.
In `@examples/gdn/example_chunk_delta_bwd.py`:
- Line 7: The do_bench import now points to the centralized profiler which
expects a callable plus timing params, but the current calls pass tensor args
directly (see do_bench calls with chunk_gated_delta_rule_bwd_dhu and kernel and
tensors Q, K, W, G, h0, dht, dO, dv, scale, chunk_size), causing the tensors to
be parsed as timing parameters; fix by wrapping the target function and its
arguments into a zero-arg callable (e.g., use a lambda or functools.partial) so
do_bench receives a single callable and pass chunk_size as a keyword inside that
wrapper or via partial, e.g. wrap chunk_gated_delta_rule_bwd_dhu with its
tensors and chunk_size and similarly wrap kernel before calling do_bench.
In `@examples/gdn/example_chunk_delta_h.py`:
- Line 7: The calls to do_bench must be adapted to the centralized signature
that invokes fn() with no args: wrap the target functions
(chunk_gated_delta_rule_fwd_h and kernel) into zero-argument callables (e.g.,
lambda or functools.partial) that capture K, W, U, G, initial_state and any
other inputs, and then call do_bench with explicit benchmarking parameters
(warmup, rep, _n_warmup, _n_repeat, quantiles, fast_flush, backend, return_mode)
rather than passing tensors as positional/keyword args; update the two sites
where do_bench is invoked so they pass a zero-arg wrapper and appropriate
numeric/flag values for the benchmark options.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
examples/flash_decoding/example_gqa_decode_varlen_logits.pyexamples/gdn/example_chunk_delta_bwd.pyexamples/gdn/example_chunk_delta_h.pyexamples/gdn/example_chunk_o_bwd.pyexamples/gemm_sp/example_custom_compress.pyexamples/gemm_sp/example_gemm_sp.pytilelang/profiler/__init__.py
💤 Files with no reviewable changes (1)
- examples/gdn/example_chunk_o_bwd.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
examples/gemm_sp/example_custom_compress.pyexamples/gemm_sp/example_gemm_sp.py
🧬 Code graph analysis (5)
examples/gdn/example_chunk_delta_bwd.py (1)
tilelang/profiler/__init__.py (1)
do_bench(193-233)
examples/flash_decoding/example_gqa_decode_varlen_logits.py (1)
tilelang/profiler/__init__.py (1)
do_bench(193-233)
examples/gemm_sp/example_custom_compress.py (1)
tilelang/profiler/__init__.py (1)
do_bench(193-233)
examples/gemm_sp/example_gemm_sp.py (1)
tilelang/profiler/__init__.py (1)
do_bench(193-233)
tilelang/profiler/__init__.py (1)
tilelang/profiler/bench.py (1)
do_bench(64-135)
🔇 Additional comments (3)
tilelang/profiler/__init__.py (1)
219-233: LGTM! Clean simplification of the benchmarking logic.The refactored
do_benchmethod correctly delegates to the centralizeddo_benchfunction fromtilelang.profiler.bench. The parameter mapping is correct, and the assertion ensures an adapter is provided when no explicit function is passed.examples/gemm_sp/example_gemm_sp.py (1)
9-9: LGTM! Import updated correctly.The import change from
triton.testingtotilelang.profileraligns with the PR objective. The usage at lines 115-116 correctly wraps kernel calls in lambdas, making them compatible with the centralizeddo_benchsignature.examples/gemm_sp/example_custom_compress.py (1)
10-10: LGTM! Import updated correctly.The import change from
triton.testingtotilelang.profileraligns with the PR objective. The usage at lines 314-315 correctly wraps kernel calls in lambdas, compatible with the centralizeddo_benchsignature.
✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.
| import argparse | ||
| import tilelang | ||
| import tilelang.language as T | ||
| from tilelang.profiler import do_bench |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check the signature of do_bench in tilelang/profiler
echo "=== Looking for do_bench function definition ==="
rg -n "def do_bench" --type py -A 20
echo ""
echo "=== Search for tilelang/profiler module ==="
find . -name "profiler.py" -o -name "profiler" -type d | head -20Repository: tile-ai/tilelang
Length of output: 3149
🏁 Script executed:
#!/bin/bash
# View the do_bench calls in the file
echo "=== Calls to do_bench around lines 685-699 ==="
sed -n '680,700p' examples/flash_decoding/example_gqa_decode_varlen_logits.py
echo ""
echo "=== Calls to do_bench around lines 704-716 ==="
sed -n '700,720p' examples/flash_decoding/example_gqa_decode_varlen_logits.pyRepository: tile-ai/tilelang
Length of output: 1305
Critical: do_bench calls will fail due to signature incompatibility.
The imported tilelang.profiler.do_bench function signature only accepts fn and keyword arguments (warmup, rep, _n_warmup, _n_repeat, quantiles, fast_flush, backend, return_mode). It does not support the *args calling convention used in this file.
The calls at lines 687-700 and 709-720 pass 11-12 positional arguments after the function:
do_bench(flash_attn_with_attn_pool_decode_tilelang, q_decode, k_varlen, v_varlen, cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale, sink, block_size, False, tl_kernel)This will raise TypeError: got unexpected positional arguments at runtime. Wrap each function call with functools.partial to bind the arguments, or refactor to use keyword arguments if the benchmarked function supports them.
🤖 Prompt for AI Agents
In `@examples/flash_decoding/example_gqa_decode_varlen_logits.py` at line 8,
do_bench is being called with positional arguments that it doesn't accept; wrap
the benchmark target and its positional args using functools.partial (or a
lambda) so do_bench receives a single callable and only keyword args itself.
Locate the do_bench calls that pass flash_attn_with_attn_pool_decode_tilelang
(and similarly any other flash_attn* benchmarks) and change them to
do_bench(functools.partial(flash_attn_with_attn_pool_decode_tilelang, q_decode,
k_varlen, v_varlen, cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale,
sink, block_size, False, tl_kernel), warmup=..., rep=..., ...) or equivalent,
ensuring you import functools.partial and preserve the existing do_bench keyword
parameters.
|
|
||
| import tilelang | ||
| import tilelang.language as T | ||
| from tilelang.profiler import do_bench |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Critical: do_bench call signatures are incompatible with the centralized function.
The import change introduces a breaking issue. At lines 480-481, do_bench is called with tensor arguments:
fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv)The centralized do_bench from tilelang.profiler.bench expects fn, warmup, rep, _n_warmup, _n_repeat, ... — not input tensors. The tensors will be incorrectly interpreted as timing parameters.
The calls should wrap the function and its arguments in a lambda or functools.partial:
🐛 Proposed fix
- fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
- tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv)
+ fla_time = do_bench(lambda: chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size))
+ tilelang_time = do_bench(lambda: kernel(Q, K, W, G, h0, dht, dO, dv))🤖 Prompt for AI Agents
In `@examples/gdn/example_chunk_delta_bwd.py` at line 7, The do_bench import now
points to the centralized profiler which expects a callable plus timing params,
but the current calls pass tensor args directly (see do_bench calls with
chunk_gated_delta_rule_bwd_dhu and kernel and tensors Q, K, W, G, h0, dht, dO,
dv, scale, chunk_size), causing the tensors to be parsed as timing parameters;
fix by wrapping the target function and its arguments into a zero-arg callable
(e.g., use a lambda or functools.partial) so do_bench receives a single callable
and pass chunk_size as a keyword inside that wrapper or via partial, e.g. wrap
chunk_gated_delta_rule_bwd_dhu with its tensors and chunk_size and similarly
wrap kernel before calling do_bench.
| import tilelang | ||
| import tilelang.language as T | ||
| from tilelang.autotuner import autotune | ||
| from tilelang.profiler import do_bench |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# First, find the profiler module and examine do_bench signature
find . -type f -name "bench.py" | grep -E "profiler|tilelang"Repository: tile-ai/tilelang
Length of output: 88
🏁 Script executed:
#!/bin/bash
# Explore the repository structure
ls -laRepository: tile-ai/tilelang
Length of output: 2689
🏁 Script executed:
#!/bin/bash
# Check if tilelang/profiler exists and what's in it
find . -type d -name "profiler" 2>/dev/null | head -5Repository: tile-ai/tilelang
Length of output: 105
🏁 Script executed:
#!/bin/bash
# Find all Python files in tilelang directory
find . -path "./tilelang/profiler*" -type f -name "*.py" 2>/dev/nullRepository: tile-ai/tilelang
Length of output: 120
🏁 Script executed:
#!/bin/bash
# Examine the example file to see the actual do_bench usage
if [ -f "examples/gdn/example_chunk_delta_h.py" ]; then
echo "=== Lines 300-320 of example_chunk_delta_h.py ==="
sed -n '300,320p' examples/gdn/example_chunk_delta_h.py
fiRepository: tile-ai/tilelang
Length of output: 830
🏁 Script executed:
#!/bin/bash
# Examine the do_bench function signature in bench.py
cat -n ./tilelang/profiler/bench.py | head -50Repository: tile-ai/tilelang
Length of output: 2029
🏁 Script executed:
#!/bin/bash
# Get the full function signature
ast-grep --pattern $'def do_bench($$$):
$$$'Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
#!/bin/bash
# Continue reading bench.py to find do_bench function
cat -n ./tilelang/profiler/bench.py | tail -n +50 | head -100Repository: tile-ai/tilelang
Length of output: 4530
🏁 Script executed:
#!/bin/bash
# Search for the do_bench function definition
rg "def do_bench" -A 10 ./tilelang/profiler/bench.pyRepository: tile-ai/tilelang
Length of output: 399
🏁 Script executed:
#!/bin/bash
# Check git history to see the previous version
git log --oneline -5 HEADRepository: tile-ai/tilelang
Length of output: 151
🏁 Script executed:
#!/bin/bash
# Get the previous commit and show the file from before
git show HEAD~1:examples/gdn/example_chunk_delta_h.py 2>/dev/null | head -320 | tail -50Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
#!/bin/bash
# Check the diff of the current file in HEAD
git diff HEAD~1 examples/gdn/example_chunk_delta_h.py 2>/dev/null | head -100Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
#!/bin/bash
# Check git log to understand the history better
git log --oneline examples/gdn/example_chunk_delta_h.py 2>/dev/null | head -10Repository: tile-ai/tilelang
Length of output: 151
🏁 Script executed:
#!/bin/bash
# Check what files were changed in the HEAD commit
git show --stat HEADRepository: tile-ai/tilelang
Length of output: 50372
🏁 Script executed:
#!/bin/bash
# Show the diff of the specific file
git show HEAD:examples/gdn/example_chunk_delta_h.py 2>/dev/null | sed -n '1,20p'Repository: tile-ai/tilelang
Length of output: 669
🏁 Script executed:
#!/bin/bash
# Check the previous commit for that file
git show HEAD~1:examples/gdn/example_chunk_delta_h.py 2>/dev/null | sed -n '1,320p' | tail -50Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
#!/bin/bash
# Check what the previous version of the file looked like (look further back in history)
git log --all --pretty=oneline examples/gdn/example_chunk_delta_h.py 2>/dev/null | head -20Repository: tile-ai/tilelang
Length of output: 1134
🏁 Script executed:
#!/bin/bash
# Look at the previous commit to see if there was a local do_bench definition
git show 0f980f15c575bf35db73a70fc04a8a53c005b2c8:examples/gdn/example_chunk_delta_h.py 2>/dev/null | sed -n '1,50p'Repository: tile-ai/tilelang
Length of output: 1466
🏁 Script executed:
#!/bin/bash
# Check if the previous version had a local do_bench function
git show 0f980f15c575bf35db73a70fc04a8a53c005b2c8:examples/gdn/example_chunk_delta_h.py 2>/dev/null | grep -A 30 "def do_bench"Repository: tile-ai/tilelang
Length of output: 778
Fix do_bench call signatures incompatible with the centralized function.
The import change is correct, but the calls at lines 302-313 are incompatible with the centralized do_bench signature.
The previous local implementation accepted do_bench(fn, *args, warmup=10, rep=10, **kwargs) and passed *args and **kwargs to the function. The centralized version from tilelang.profiler.bench has signature do_bench(fn, warmup, rep, _n_warmup, _n_repeat, quantiles, fast_flush, backend, return_mode) and calls fn() with no arguments.
Lines 302-312 call do_bench(chunk_gated_delta_rule_fwd_h, k=K, w=W, u=U, ...) with keyword arguments that don't exist in the new signature, causing TypeError: unexpected keyword arguments.
Line 313 calls do_bench(kernel, K, W, U, G, initial_state) with tensor values as positional arguments, which will be misinterpreted as warmup, rep, _n_warmup, _n_repeat, and quantiles parameters with type mismatches.
Both calls need to be refactored to wrap the function invocations appropriately for the new centralized do_bench interface.
🤖 Prompt for AI Agents
In `@examples/gdn/example_chunk_delta_h.py` at line 7, The calls to do_bench must
be adapted to the centralized signature that invokes fn() with no args: wrap the
target functions (chunk_gated_delta_rule_fwd_h and kernel) into zero-argument
callables (e.g., lambda or functools.partial) that capture K, W, U, G,
initial_state and any other inputs, and then call do_bench with explicit
benchmarking parameters (warmup, rep, _n_warmup, _n_repeat, quantiles,
fast_flush, backend, return_mode) rather than passing tensors as
positional/keyword args; update the two sites where do_bench is invoked so they
pass a zero-arg wrapper and appropriate numeric/flag values for the benchmark
options.
* refactor: replace local do_bench function with centralized profiler import - Removed local implementations of the `do_bench` function from multiple example files. - Updated imports to use the centralized `do_bench` function from `tilelang.profiler`, promoting code reuse and consistency across examples. * refactor: remove unused do_bench import from example_chunk_o_bwd.py and __init__.py - Eliminated the unused `do_bench` import from `example_chunk_o_bwd.py` and `tilelang/profiler/__init__.py`, streamlining the code and improving clarity. * lint fix
Summary
do_benchfunction definitions with centralizedfrom tilelang.profiler import do_benchimportfrom triton.testing import do_benchwithfrom tilelang.profiler import do_benchin examplesChanged Files
Import replacement (triton → tilelang.profiler)
examples/gemm_sp/example_gemm_sp.pyexamples/gemm_sp/example_custom_compress.pyRemove local do_bench definition
examples/flash_decoding/example_gqa_decode_varlen_logits.pyexamples/gdn/example_chunk_delta_h.pyexamples/gdn/example_chunk_o_bwd.pyexamples/gdn/example_chunk_delta_bwd.pyTest plan
do_bench🤖 Generated with Claude Code
Summary by CodeRabbit
Refactor
Breaking Changes
✏️ Tip: You can customize this high-level summary in your review settings.