Refactor: Use centralized do_bench from tilelang.profiler #1670
```diff
@@ -4,6 +4,7 @@
 import tilelang
 import tilelang.language as T
+from tilelang.profiler import do_bench
```
Contributor

Critical: The import change introduces a breaking issue. At lines 480-481:

```python
fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv)
```

The centralized `do_bench` accepts only the function itself plus its own keyword options; it does not forward extra positional arguments to the benchmarked function. The calls should wrap the function and its arguments in a lambda or `functools.partial`.

🐛 Proposed fix

```diff
-fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
-tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv)
+fla_time = do_bench(lambda: chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size))
+tilelang_time = do_bench(lambda: kernel(Q, K, W, G, h0, dht, dO, dv))
```
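Where many arguments are involved, `functools.partial` achieves the same binding as a lambda. The following is a minimal illustrative sketch with a stand-in `do_bench` (the names and signature here are assumptions for demonstration, not the real tilelang profiler):

```python
from functools import partial

def do_bench(fn, warmup=10, rep=10):
    # Stand-in for the centralized profiler: it invokes fn() with no
    # arguments, so callers must pre-bind everything the kernel needs.
    result = None
    for _ in range(warmup + rep):
        result = fn()
    return result

def kernel(q, k, scale=1.0):
    # Hypothetical benchmarked function.
    return (q + k) * scale

# Pre-bind positional and keyword arguments with functools.partial,
# producing a zero-argument callable for do_bench.
bound = partial(kernel, 2, 3, scale=2.0)
print(do_bench(bound))  # prints 10.0
```

`partial` avoids the late-binding pitfalls of lambdas defined in loops, which is one reason reviewers often suggest it for benchmarking wrappers.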
```diff
 print(tilelang.__file__, flush=True)
```
```diff
@@ -544,31 +545,6 @@ def run_test(
     assert_similar(dv2_ref_torch, dv2_tilelang, 1e-5, "torch-tilelang", data="dv2")


-def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
-    """
-    Do benchmark for a function.
-    """
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    for _ in range(warmup):
-        fn(*args, **kwargs)
-
-    torch.cuda.synchronize()
-    for i in range(rep):
-        start_event[i].record()
-        fn(*args, **kwargs)
-        end_event[i].record()
-    torch.cuda.synchronize()
-
-    # Record clocks
-    times = torch.tensor(
-        [s.elapsed_time(e) for s, e in zip(start_event, end_event)],
-        dtype=torch.float,
-    )
-
-    return times.mean().item()
-
-
 def main():
     DK = 128
     run_test(
```
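For reference, the shape of the deleted helper (warmup passes, per-iteration timing, mean of the samples) can be sketched with CPU wall-clock timers in place of the `torch.cuda.Event` pairs; this is an illustrative stand-in, not part of the tilelang API:

```python
import time

def bench_cpu(fn, *args, warmup=10, rep=10, **kwargs):
    # Same structure as the removed helper: warm up first, then time
    # each repetition and return the mean latency in milliseconds.
    for _ in range(warmup):
        fn(*args, **kwargs)
    times = []
    for _ in range(rep):
        start = time.perf_counter()
        fn(*args, **kwargs)
        times.append((time.perf_counter() - start) * 1000.0)
    return sum(times) / len(times)

mean_ms = bench_cpu(sorted, list(range(1000)))
print(f"{mean_ms:.4f} ms")
```

Unlike this CPU version, the original needed `torch.cuda.synchronize()` because CUDA kernel launches are asynchronous; the event pairs measure on-device time rather than host time.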
examples/gdn/example_chunk_delta_h.py

```diff
@@ -4,6 +4,7 @@
 import tilelang
 import tilelang.language as T
 from tilelang.autotuner import autotune
+from tilelang.profiler import do_bench
```
Contributor

🧩 Analysis chain (repository: tile-ai/tilelang)

🏁 Scripts executed:

```shell
#!/bin/bash
# First, find the profiler module and examine do_bench signature
find . -type f -name "bench.py" | grep -E "profiler|tilelang"

# Explore the repository structure
ls -la

# Check if tilelang/profiler exists and what's in it
find . -type d -name "profiler" 2>/dev/null | head -5

# Find all Python files in the tilelang/profiler directory
find . -path "./tilelang/profiler*" -type f -name "*.py" 2>/dev/null

# Examine the example file to see the actual do_bench usage
if [ -f "examples/gdn/example_chunk_delta_h.py" ]; then
    echo "=== Lines 300-320 of example_chunk_delta_h.py ==="
    sed -n '300,320p' examples/gdn/example_chunk_delta_h.py
fi

# Examine the do_bench function signature in bench.py
cat -n ./tilelang/profiler/bench.py | head -50
cat -n ./tilelang/profiler/bench.py | tail -n +50 | head -100

# Search for the do_bench function definition
rg "def do_bench" -A 10 ./tilelang/profiler/bench.py

# Check git history for the file's previous versions
git log --oneline -5 HEAD
git log --oneline examples/gdn/example_chunk_delta_h.py 2>/dev/null | head -10
git show --stat HEAD
git show HEAD:examples/gdn/example_chunk_delta_h.py 2>/dev/null | sed -n '1,20p'
git log --all --pretty=oneline examples/gdn/example_chunk_delta_h.py 2>/dev/null | head -20

# Check whether the previous version had a local do_bench definition
git show 0f980f15c575bf35db73a70fc04a8a53c005b2c8:examples/gdn/example_chunk_delta_h.py 2>/dev/null | sed -n '1,50p'
git show 0f980f15c575bf35db73a70fc04a8a53c005b2c8:examples/gdn/example_chunk_delta_h.py 2>/dev/null | grep -A 30 "def do_bench"
```

Fix: The import change is correct, but the calls at lines 302-313 are incompatible with the centralized `do_bench` signature.

The previous local implementation accepted `*args` and `**kwargs` and forwarded them to the benchmarked function; the centralized version does not. Lines 302-313 still pass the benchmarked functions' arguments positionally through `do_bench`. Both calls need to be refactored to wrap the function invocations appropriately for the new centralized signature.
```diff
 # Add your fla repository path to sys.path
 # Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
```
```diff
@@ -224,31 +225,6 @@ def kernel(
     return kernel


-def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
-    """
-    Do benchmark for a function.
-    """
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    for _ in range(warmup):
-        fn(*args, **kwargs)
-
-    torch.cuda.synchronize()
-    for i in range(rep):
-        start_event[i].record()
-        fn(*args, **kwargs)
-        end_event[i].record()
-    torch.cuda.synchronize()
-
-    # Record clocks
-    times = torch.tensor(
-        [s.elapsed_time(e) for s, e in zip(start_event, end_event)],
-        dtype=torch.float,
-    )
-
-    return times.mean().item()
-
-
 def run_test(
     B,
     S,
```
Critical: `do_bench` calls will fail due to signature incompatibility.

The imported `tilelang.profiler.do_bench` function signature only accepts `fn` and keyword arguments (`warmup`, `rep`, `_n_warmup`, `_n_repeat`, `quantiles`, `fast_flush`, `backend`, `return_mode`). It does not support the `*args` calling convention used in this file.

The calls at lines 687-700 and 709-720 pass 11-12 positional arguments after the function. This will raise `TypeError: got unexpected positional arguments` at runtime. Wrap each function call with `functools.partial` to bind the arguments, or refactor to use keyword arguments if the benchmarked function supports them.
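The failure mode described above can be reproduced with a keyword-only stand-in (illustrative names and signature, not the real profiler):

```python
def do_bench(fn, *, warmup=10, rep=10):
    # Keyword-only stand-in mirroring the centralized signature:
    # there is no *args, so positional passthrough is rejected.
    out = None
    for _ in range(warmup + rep):
        out = fn()
    return out

def kernel(q, k, v, scale):
    # Hypothetical benchmarked function.
    return (q + k + v) * scale

try:
    do_bench(kernel, 1, 2, 3, 0.5)  # old positional calling convention
except TypeError as exc:
    print("rejected:", exc)

print(do_bench(lambda: kernel(1, 2, 3, 0.5)))  # prints 3.0
```

The first call fails because Python rejects the extra positional arguments at the `do_bench` boundary, before the benchmarked function is ever invoked; the lambda (or an equivalent `functools.partial`) defers the argument binding into a zero-argument callable.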