diff --git a/python/test/unit/hopper/test_flashattention.py b/python/test/unit/hopper/test_flashattention.py
index fc8db664c9f0..5053cfc4b245 100644
--- a/python/test/unit/hopper/test_flashattention.py
+++ b/python/test/unit/hopper/test_flashattention.py
@@ -435,8 +435,6 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
 @triton.testing.perf_report(configs)
 def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
     assert mode in ['fwd', 'bwd']
-    warmup = 25
-    rep = 100
     if provider == "triton":
         q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
         k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
@@ -447,7 +445,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
             o = fn()
             do = torch.randn_like(o)
             fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn)
         return ms
     if provider == "flash":
         lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device)
@@ -459,7 +457,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
             o = fn()
             do = torch.randn_like(o)
             fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn)
         return ms


diff --git a/python/test/unit/language/test_decorator.py b/python/test/unit/language/test_decorator.py
index 66371ba6003a..fbbfb7144680 100644
--- a/python/test/unit/language/test_decorator.py
+++ b/python/test/unit/language/test_decorator.py
@@ -33,7 +33,9 @@ def test_triton_heuristic(device):
     src = torch.empty(N, device=device)
     dst = torch.zeros(N, device=device)

-    @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], warmup=1, rep=1)
+    do_bench = lambda kernel, quantiles: triton.testing.do_bench(kernel, quantiles=quantiles, warmup=1, rep=1)
+
+    @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], do_bench=do_bench)
     @triton.heuristics({'EVEN_N': lambda nargs: nargs['N'] % 2 == 0})  # test kwargs
     @triton.heuristics({'EVEN_src': lambda nargs: nargs['src'].data_ptr() % 2 == 0})  # test args
     @triton.jit
diff --git a/python/test/unit/runtime/test_autotuner.py b/python/test/unit/runtime/test_autotuner.py
index 7d7867a2cad6..456ebf113792 100644
--- a/python/test/unit/runtime/test_autotuner.py
+++ b/python/test/unit/runtime/test_autotuner.py
@@ -5,6 +5,10 @@
 import pytest


+def do_bench(kernel_call, quantiles):
+    return triton.testing.do_bench(kernel_call, quantiles=quantiles, warmup=1, rep=1)
+
+
 @pytest.mark.parametrize('use_cuda_graph', [False, True])
 def test_kwargs(use_cuda_graph: bool, device: str):
     M, N = 1024, 16
@@ -13,7 +17,7 @@ def test_kwargs(use_cuda_graph: bool, device: str):

     configs = [triton.Config(kwargs={'BLOCK_SIZE_M': 32}), triton.Config(kwargs={'BLOCK_SIZE_M': 128})]

-    @triton.autotune(configs=configs, key=['M'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph)
+    @triton.autotune(configs=configs, key=['M'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph, do_bench=do_bench)
     @triton.jit
     def _kernel(dst, src, stride_m: tl.constexpr, M, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_M: tl.constexpr):
         offsets_m = tl.program_id(0) * stride_m + tl.arange(0, BLOCK_SIZE_M)
@@ -34,7 +38,7 @@ def test_restore(device):

     configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]

-    @triton.autotune(configs=configs, key=['N'], restore_value=['src'], warmup=1, rep=1)
+    @triton.autotune(configs=configs, key=['N'], restore_value=['src'], do_bench=do_bench)
     @triton.jit
     def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
         offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -64,7 +68,7 @@ def _post_hook(*args, exception):
             values["has_exception"] = True
         assert values["counter"] == 0

-    @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, pre_hook=_pre_hook, post_hook=_post_hook)
+    @triton.autotune(configs=configs, key=['N'], do_bench=do_bench, pre_hook=_pre_hook, post_hook=_post_hook)
     @triton.heuristics({"N_STAGES": lambda nargs: 100 if nargs['N'] == 4096 else 4})
     @triton.jit
     def _kernel(src, N, N_STAGES: tl.constexpr, BLOCK_SIZE: tl.constexpr):
@@ -115,7 +119,7 @@ def perf_model(*args, **kwargs):
     else:
         prune_configs_by = {'early_config_prune': early_config_prune}

-    @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, warmup=1, rep=1)
+    @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, do_bench=do_bench)
     @triton.jit
     def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
         offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
diff --git a/python/triton/backends/driver.py b/python/triton/backends/driver.py
index e66442943b03..202ae15686d4 100644
--- a/python/triton/backends/driver.py
+++ b/python/triton/backends/driver.py
@@ -1,4 +1,11 @@
 from abc import ABCMeta, abstractmethod, abstractclassmethod
+from typing import Callable, List, Protocol, Sequence
+
+
+class Benchmarker(Protocol):
+
+    def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]:
+        pass


 class DriverBase(metaclass=ABCMeta):
@@ -11,6 +18,13 @@ def is_active(self):
     def get_current_target(self):
         pass

+    @abstractmethod
+    def get_benchmarker(self) -> Benchmarker:
+        """
+        Return the benchmarking function that this backend should use by default.
+        """
+        raise NotImplementedError
+
     def __init__(self) -> None:
         pass

diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py
index 2b6a7ba32cb8..5f846de17017 100644
--- a/python/triton/runtime/autotuner.py
+++ b/python/triton/runtime/autotuner.py
@@ -6,9 +6,9 @@
 import inspect
 from typing import Dict

-from ..testing import do_bench, do_bench_cudagraph
 from .jit import KernelInterface
 from .errors import OutOfResources
+from .driver import driver


 class Autotuner(KernelInterface):
@@ -24,9 +24,10 @@ def __init__(
         pre_hook=None,
         post_hook=None,
         prune_configs_by: Dict = None,
-        warmup=25,
-        rep=100,
+        warmup=None,
+        rep=None,
         use_cuda_graph=False,
+        do_bench=None,
     ):
         """
         :param prune_configs_by: a dict of functions that are used to prune configs, fields:
@@ -88,10 +89,36 @@ def _post_hook(args, exception):
         self.base_fn = fn
         while not inspect.isfunction(self.base_fn):
             self.base_fn = self.base_fn.fn
-        self.num_warmups = warmup
-        self.num_reps = rep
-        import torch
-        self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
+
+        # If we got explicitly called via the old interface, raise a warning
+        # and proceed with the old behavior.
+        if warmup is not None or rep is not None or use_cuda_graph:
+            import warnings
+            warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "
+                           "https://github.com/triton-lang/triton/pull/4496 for details."), DeprecationWarning,
+                          stacklevel=1)
+            if use_cuda_graph:
+                from ..testing import do_bench_cudagraph
+                self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
+                    kernel_call,
+                    rep=rep if rep is not None else 100,
+                    quantiles=quantiles,
+                )
+                return
+
+            import triton.testing
+            self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
+                kernel_call,
+                warmup=warmup if warmup is not None else 25,
+                rep=rep if rep is not None else 100,
+                quantiles=quantiles,
+            )
+            return
+
+        if do_bench is None:
+            self.do_bench = driver.active.get_benchmarker()
+        else:
+            self.do_bench = do_bench

     def _bench(self, *args, config, **meta):
         from ..compiler.errors import CompileTimeAssertionFailure
@@ -125,9 +152,7 @@ def kernel_call():
             self.post_hook(args, exception=None)

         try:
-            if self.use_cuda_graph:
-                return do_bench_cudagraph(kernel_call, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
-            return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
+            return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
         except (OutOfResources, CompileTimeAssertionFailure):
             return [float("inf"), float("inf"), float("inf")]

@@ -257,7 +282,7 @@ def __str__(self):


 def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
-             warmup=25, rep=100, use_cuda_graph=False):
+             warmup=None, rep=None, use_cuda_graph=False, do_bench=None):
     """
     Decorator for auto-tuning a :code:`triton.jit`'d function.

@@ -305,10 +330,12 @@ def kernel(x_ptr, x_size, **META):
                    'args': a list of arguments passed to the kernel.
                    'exception': the exception raised by the kernel in case of a compilation or runtime error.
     :type post_hook: lambda args, exception
-    :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
+    :param warmup: warmup time (in ms) to pass to benchmarking (deprecated).
     :type warmup: int
-    :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
+    :param rep: repetition time (in ms) to pass to benchmarking (deprecated).
    :type rep: int
+    :param do_bench: a benchmark function to measure the time of each run.
+    :type do_bench: lambda fn, quantiles
     """

     def decorator(fn):
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 45ecc11f4b63..09efc06de4b9 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -601,8 +601,6 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
 @triton.testing.perf_report(configs)
 def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"):
     assert mode in ["fwd", "bwd"]
-    warmup = 25
-    rep = 100
     dtype = torch.float16
     if "triton" in provider:
         q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
@@ -620,7 +618,7 @@ def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, dev
             o = fn()
             do = torch.randn_like(o)
             fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn)
     if provider == "flash":
         qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
         fn = lambda: flash_attn_func(qkv, causal=causal)
@@ -628,7 +626,7 @@ def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, dev
             o = fn()
             do = torch.randn_like(o)
             fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn)
     flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM
     total_flops = 2 * flops_per_matmul
     if causal:
diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py
index 80ba6d2a5ccc..86c9dd4339f6 100644
--- a/third_party/amd/backend/driver.py
+++ b/third_party/amd/backend/driver.py
@@ -499,3 +499,7 @@ def get_current_target(self):
         arch = device_properties['arch']
         warp_size = device_properties['warpSize']
         return GPUTarget("hip", arch.split(':')[0], warp_size)
+
+    def get_benchmarker(self):
+        from triton.testing import do_bench
+        return do_bench
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
index 57c8844e1f36..286f8cb52a5b 100644
--- a/third_party/nvidia/backend/driver.py
+++ b/third_party/nvidia/backend/driver.py
@@ -448,3 +448,7 @@ def get_device_interface(self):
     def is_active():
         import torch
         return torch.cuda.is_available() and (torch.version.hip is None)
+
+    def get_benchmarker(self):
+        from triton.testing import do_bench
+        return do_bench
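
Usage sketch (illustrative, not part of the patch): with this change a custom benchmarker is passed to @triton.autotune through the new do_bench argument instead of the deprecated warmup/rep/use_cuda_graph arguments; omitting do_bench falls back to driver.active.get_benchmarker(), which the NVIDIA and AMD backends map to triton.testing.do_bench. The helper below mirrors the updated tests by forwarding to triton.testing.do_bench; the kernel, config values, and warmup/rep numbers are hypothetical.

import torch

import triton
import triton.language as tl
import triton.testing


# Custom benchmarker with the (kernel_call, quantiles) call shape the autotuner expects.
def short_bench(kernel_call, quantiles):
    return triton.testing.do_bench(kernel_call, quantiles=quantiles, warmup=5, rep=20)


@triton.autotune(
    configs=[triton.Config(kwargs={'BLOCK_SIZE': 64}), triton.Config(kwargs={'BLOCK_SIZE': 128})],
    key=['N'],
    do_bench=short_bench,  # replaces the deprecated warmup=/rep= keyword arguments
)
@triton.jit
def _copy_kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(src + offsets, mask=offsets < N)
    tl.store(dst + offsets, x, mask=offsets < N)


N = 4096
src = torch.randn(N, device='cuda')
dst = torch.empty_like(src)
grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
_copy_kernel[grid](dst, src, N=N)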