
Commit 9a8d2a1

minjang authored and liuyunqi20 committed
[AUTOTUNER] A quick follow-up for more device-independent do_bench (#4974)
This is a quick follow-up to the recent autotuner/testing changes in triton-lang/triton#4496. This PR moves the empty-cache creation into the driver code to make do_bench more device-independent.
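Conceptually, do_bench now asks the active driver for its scratch buffer instead of allocating one with a hard-coded device='cuda'. A minimal sketch of the resulting dispatch, assuming only what the diff below shows (the helper function name is illustrative):

from triton import runtime

def flush_l2_before_timing():
    # The active driver decides how large the buffer is and which device it
    # lives on; the NVIDIA and AMD backends both return a 256 MB int32 tensor.
    cache = runtime.driver.active.get_empty_cache_for_benchmark()
    # Writing to the buffer evicts earlier inputs from the L2 cache, so each
    # timed kernel launch starts from a cold cache.
    cache.zero_()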
1 parent e61ccff commit 9a8d2a1

File tree

3 files changed: +18 -6 lines changed


python/triton/testing.py

Lines changed: 2 additions & 6 deletions
@@ -95,7 +95,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
     return _summarize_statistics(torch.tensor(ret), quantiles, return_mode)
 
 
-def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device_type="cuda"):
+def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
@@ -120,11 +120,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
     fn()
     di.synchronize()
 
-    # We maintain a buffer of 256 MB that we clear
-    # before each kernel call to make sure that the L2 cache
-    # doesn't contain any input data before the run
-    cache_size = 256 * 1024 * 1024
-    cache = torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
+    cache = runtime.driver.active.get_empty_cache_for_benchmark()
 
     # Estimate the runtime of the function
     start_event = di.Event(enable_timing=True)
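For callers, the only visible change is that the device_type parameter is gone; the rest of the signature is unchanged. A small usage sketch (the matmul and its sizes are arbitrary):

import torch
from triton.testing import do_bench

a = torch.randn(1024, 1024, device='cuda')
b = torch.randn(1024, 1024, device='cuda')

# Mean runtime in milliseconds by default (return_mode="mean");
# pass quantiles=[0.2, 0.8] for the 20th/80th percentiles instead.
ms = do_bench(lambda: torch.matmul(a, b), warmup=25, rep=100)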

third_party/amd/backend/driver.py

Lines changed: 7 additions & 0 deletions
@@ -503,3 +503,10 @@ def get_current_target(self):
     def get_benchmarker(self):
         from triton.testing import do_bench
         return do_bench
+
+    def get_empty_cache_for_benchmark(self):
+        import torch
+
+        # It's the same as the Nvidia backend.
+        cache_size = 256 * 1024 * 1024
+        return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
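Because the allocation now lives behind a driver method, a new backend only has to return a buffer on its own device; nothing in testing.py changes. A hypothetical sketch for some other backend (the class name and device string are made up):

class HypotheticalXPUDriver:

    def get_empty_cache_for_benchmark(self):
        import torch

        # Same 256 MB scratch buffer, but allocated on this backend's
        # device instead of a hard-coded 'cuda'.
        cache_size = 256 * 1024 * 1024
        return torch.empty(int(cache_size // 4), dtype=torch.int, device='xpu')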

third_party/nvidia/backend/driver.py

Lines changed: 9 additions & 0 deletions
@@ -456,3 +456,12 @@ def is_active():
     def get_benchmarker(self):
         from triton.testing import do_bench
         return do_bench
+
+    def get_empty_cache_for_benchmark(self):
+        import torch
+
+        # We maintain a buffer of 256 MB that we clear
+        # before each kernel call to make sure that the L2 cache
+        # doesn't contain any input data before the run
+        cache_size = 256 * 1024 * 1024
+        return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
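One detail worth spelling out: cache_size is a byte count, and torch.int is 4 bytes per element, so cache_size // 4 elements span exactly 256 MB. A quick CPU-side check of that arithmetic:

import torch

cache_size = 256 * 1024 * 1024                          # target size in bytes
cache = torch.empty(cache_size // 4, dtype=torch.int)   # torch.int == int32, 4 bytes each
assert cache.element_size() * cache.numel() == cache_size  # exactly 256 MB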
