
Commit 72d75e9

Make autotuner take do_bench as a parameter

This makes the autotuner device-agnostic. Instead of having to know about the existence of e.g. do_bench_cudagraph, it can let the callers decide which backend-specific benchmarking function to use. See discussion in #4417.

1 parent 1402578 · commit 72d75e9

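As a quick illustration of the new interface (the kernel, the configs, and the choice of do_bench_cudagraph below are made up for the example, not taken from the commit), a caller can now hand its own benchmarking function to @triton.autotune:

import torch
import triton
import triton.language as tl
from triton.testing import do_bench_cudagraph

# The caller, not the autotuner, picks the benchmarking backend.
bench = lambda kernel_call, quantiles: do_bench_cudagraph(kernel_call, quantiles=quantiles)


@triton.autotune(
    configs=[triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})],
    key=['N'],
    do_bench=bench,  # new parameter added by this commit
)
@triton.jit
def _copy(dst, src, N, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    tl.store(dst + offsets, tl.load(src + offsets, mask=mask), mask=mask)


N = 1024
src = torch.randn(N, device="cuda")
dst = torch.empty_like(src)
_copy[lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )](dst, src, N)

The autotuner invokes the supplied callable as do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)), so anything with that call shape works.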

File tree: 9 files changed, 76 additions and 34 deletions

python/test/unit/hopper/test_flashattention.py

Lines changed: 2 additions & 4 deletions

@@ -435,8 +435,6 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
 @triton.testing.perf_report(configs)
 def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
     assert mode in ['fwd', 'bwd']
-    warmup = 25
-    rep = 100
     if provider == "triton":
         q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
         k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
@@ -447,7 +445,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
             o = fn()
             do = torch.randn_like(o)
             fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn)
         return ms
     if provider == "flash":
         lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device)
@@ -459,7 +457,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
             o = fn()
             do = torch.randn_like(o)
             fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn)
         return ms

python/test/unit/hopper/test_persistent_warp_specialized_fused-attention.py

Lines changed: 3 additions & 6 deletions

@@ -361,10 +361,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
 @triton.testing.perf_report(configs)
 def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
     assert mode in ['fwd', 'bwd']
-    # warmup = 25
-    # rep = 100
-    warmup = 0
-    rep = 1
+    do_bench = lambda kernel, quantiles: triton.testing.do_bench(kernel, warmup=0, rep=1, quantiles=quantiles)
     if provider == "triton":
         q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
         k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
@@ -375,7 +372,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
             o = fn()
             do = torch.randn_like(o)
             fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn, do_bench=do_bench)
         return ms
     if provider == "flash":
         lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device)
@@ -387,7 +384,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
             o = fn()
             do = torch.randn_like(o)
             fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn, do_bench=do_bench)
         return ms

python/test/unit/language/test_decorator.py

Lines changed: 3 additions & 1 deletion

@@ -33,7 +33,9 @@ def test_triton_heuristic(device):
     src = torch.empty(N, device=device)
     dst = torch.zeros(N, device=device)

-    @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], warmup=1, rep=1)
+    do_bench = lambda kernel, quantiles: triton.testing.do_bench(kernel, quantiles=quantiles, warmup=1, rep=1)
+
+    @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], do_bench=do_bench)
     @triton.heuristics({'EVEN_N': lambda nargs: nargs['N'] % 2 == 0})  # test kwargs
     @triton.heuristics({'EVEN_src': lambda nargs: nargs['src'].data_ptr() % 2 == 0})  # test args
     @triton.jit

python/test/unit/runtime/test_autotuner.py

Lines changed: 9 additions & 6 deletions

@@ -5,15 +5,18 @@
 import pytest


-@pytest.mark.parametrize('use_cuda_graph', [False, True])
-def test_kwargs(use_cuda_graph: bool, device: str):
+def do_bench(kernel_call, quantiles):
+    return triton.testing.do_bench(kernel_call, quantiles=quantiles, warmup=1, rep=1)
+
+
+def test_kwargs(device: str):
     N = 1024
     src = torch.randn(N, device=device)
     dst = torch.empty(N, device=device)

     configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]

-    @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph)
+    @triton.autotune(configs=configs, key=['N'], do_bench=do_bench)
     @triton.jit
     def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
         offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -31,7 +34,7 @@ def test_restore(device):

     configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]

-    @triton.autotune(configs=configs, key=['N'], restore_value=['src'], warmup=1, rep=1)
+    @triton.autotune(configs=configs, key=['N'], restore_value=['src'], do_bench=do_bench)
     @triton.jit
     def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
         offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -61,7 +64,7 @@ def _post_hook(*args, exception):
         values["has_exception"] = True
         assert values["counter"] == 0

-    @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, pre_hook=_pre_hook, post_hook=_post_hook)
+    @triton.autotune(configs=configs, key=['N'], do_bench=do_bench, pre_hook=_pre_hook, post_hook=_post_hook)
     @triton.heuristics({"N_STAGES": lambda nargs: 100 if nargs['N'] == 4096 else 4})
     @triton.jit
     def _kernel(src, N, N_STAGES: tl.constexpr, BLOCK_SIZE: tl.constexpr):
@@ -112,7 +115,7 @@ def perf_model(*args, **kwargs):
     else:
         prune_configs_by = {'early_config_prune': early_config_prune}

-    @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, warmup=1, rep=1)
+    @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, do_bench=do_bench)
     @triton.jit
     def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
         offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

python/triton/backends/driver.py

Lines changed: 14 additions & 0 deletions

@@ -1,4 +1,11 @@
 from abc import ABCMeta, abstractmethod, abstractclassmethod
+from typing import Callable, List, Protocol, Sequence
+
+
+class Benchmarker(Protocol):
+
+    def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]:
+        pass


 class DriverBase(metaclass=ABCMeta):
@@ -11,6 +18,13 @@ def is_active(self):
     def get_current_target(self):
         pass

+    @abstractmethod
+    def get_benchmarker(self) -> Benchmarker:
+        """
+        Return the benchmarking function that this backend should use by default.
+        """
+        raise NotImplementedError
+
     def __init__(self) -> None:
         pass

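For illustration only: a backend outside this repository could point the autotuner at a different timing routine by overriding get_benchmarker. The driver class below is hypothetical; it just sketches how the Benchmarker protocol introduced above might be satisfied with do_bench_cudagraph instead of do_bench.

# Hypothetical out-of-tree driver; not part of this commit.
from triton.backends.driver import Benchmarker, DriverBase


class CudaGraphDriver(DriverBase):
    # The other DriverBase methods (is_active, get_current_target, ...) are
    # omitted from this sketch.

    def get_benchmarker(self) -> Benchmarker:
        from triton.testing import do_bench_cudagraph

        # Adapt do_bench_cudagraph to the Benchmarker call shape:
        # (kernel_call, *, quantiles, **kwargs) -> sequence of timings.
        def benchmark(kernel_call, *, quantiles, **kwargs):
            return do_bench_cudagraph(kernel_call, quantiles=quantiles, **kwargs)

        return benchmark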

python/triton/runtime/autotuner.py

Lines changed: 35 additions & 13 deletions

@@ -6,9 +6,9 @@
 import inspect
 from typing import Dict

-from ..testing import do_bench, do_bench_cudagraph
 from .jit import KernelInterface
 from .errors import OutOfResources
+from .driver import driver


 class Autotuner(KernelInterface):
@@ -24,9 +24,10 @@ def __init__(
         pre_hook=None,
         post_hook=None,
         prune_configs_by: Dict = None,
-        warmup=25,
-        rep=100,
+        warmup=None,
+        rep=None,
         use_cuda_graph=False,
+        do_bench=None,
     ):
         """
         :param prune_configs_by: a dict of functions that are used to prune configs, fields:
@@ -88,10 +89,35 @@ def _post_hook(args, exception):
         self.base_fn = fn
         while not inspect.isfunction(self.base_fn):
             self.base_fn = self.base_fn.fn
-        self.num_warmups = warmup
-        self.num_reps = rep
-        import torch
-        self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
+
+        # If we got explicitly called via the old interface, raise a warning
+        # and proceed with the old behavior.
+        if warmup is not None or rep is not None or use_cuda_graph:
+            import warnings
+            warnings.warn("warmup, rep, and use_cuda_graph parameters are deprecated. See _ for details.",
+                          DeprecationWarning)
+            import torch
+            if use_cuda_graph and torch.cuda.is_available():
+                from ..testing import do_bench_cudagraph
+                self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
+                    kernel_call,
+                    rep=rep if rep is not None else 100,
+                    quantiles=quantiles,
+                )
+                return
+
+            import triton.testing
+            self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
+                kernel_call,
+                warmup=warmup if warmup is not None else 25,
+                rep=rep if rep is not None else 100,
+                quantiles=quantiles,
+            )
+
+        if do_bench is None:
+            self.do_bench = driver.active.get_benchmarker()
+        else:
+            self.do_bench = do_bench

     def _bench(self, *args, config, **meta):
         from ..compiler.errors import CompileTimeAssertionFailure
@@ -125,11 +151,7 @@ def kernel_call():
             self.post_hook(args, exception=None)

         try:
-            if self.use_cuda_graph:
-                import torch
-                with torch.cuda.stream(torch.cuda.Stream()):
-                    return do_bench_cudagraph(kernel_call, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
-            return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
+            return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
         except (OutOfResources, CompileTimeAssertionFailure):
             return [float("inf"), float("inf"), float("inf")]

@@ -262,7 +284,7 @@ def __str__(self):


 def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
-             warmup=25, rep=100, use_cuda_graph=False):
+             warmup=None, rep=None, use_cuda_graph=False, do_bench=None):
     """
     Decorator for auto-tuning a :code:`triton.jit`'d function.

python/tutorials/06-fused-attention.py

Lines changed: 2 additions & 4 deletions

@@ -599,8 +599,6 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
 @triton.testing.perf_report(configs)
 def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"):
     assert mode in ["fwd", "bwd"]
-    warmup = 25
-    rep = 100
     dtype = torch.float16
     if "triton" in provider:
         q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
@@ -618,15 +616,15 @@ def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, dev
             o = fn()
             do = torch.randn_like(o)
             fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn)
     if provider == "flash":
         qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
        fn = lambda: flash_attn_func(qkv, causal=causal)
        if mode == "bwd":
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn)
     flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM
     total_flops = 2 * flops_per_matmul
     if causal:

third_party/amd/backend/driver.py

Lines changed: 4 additions & 0 deletions

@@ -440,3 +440,7 @@ def get_current_target(self):
         arch = device_properties['arch']
         warp_size = device_properties['warpSize']
         return GPUTarget("hip", arch.split(':')[0], warp_size)
+
+    def get_benchmarker(self):
+        from triton.testing import do_bench
+        return do_bench

third_party/nvidia/backend/driver.py

Lines changed: 4 additions & 0 deletions

@@ -383,3 +383,7 @@ def get_current_target(self):
     def is_active():
         import torch
         return torch.cuda.is_available() and (torch.version.hip is None)
+
+    def get_benchmarker(self):
+        from triton.testing import do_bench
+        return do_bench
