6 changes: 2 additions & 4 deletions python/test/unit/hopper/test_flashattention.py
@@ -435,8 +435,6 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
assert mode in ['fwd', 'bwd']
-    warmup = 25
-    rep = 100
if provider == "triton":
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
@@ -447,7 +445,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn)
return ms
if provider == "flash":
lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device)
@@ -459,7 +457,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn)
return ms


4 changes: 3 additions & 1 deletion python/test/unit/language/test_decorator.py
@@ -33,7 +33,9 @@ def test_triton_heuristic(device):
src = torch.empty(N, device=device)
dst = torch.zeros(N, device=device)

-    @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], warmup=1, rep=1)
+    do_bench = lambda kernel, quantiles: triton.testing.do_bench(kernel, quantiles=quantiles, warmup=1, rep=1)
+
+    @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], do_bench=do_bench)
@triton.heuristics({'EVEN_N': lambda nargs: nargs['N'] % 2 == 0}) # test kwargs
@triton.heuristics({'EVEN_src': lambda nargs: nargs['src'].data_ptr() % 2 == 0}) # test args
@triton.jit
15 changes: 9 additions & 6 deletions python/test/unit/runtime/test_autotuner.py
@@ -5,15 +5,18 @@
import pytest


-@pytest.mark.parametrize('use_cuda_graph', [False, True])
-def test_kwargs(use_cuda_graph: bool, device: str):
+def do_bench(kernel_call, quantiles):
+    return triton.testing.do_bench(kernel_call, quantiles=quantiles, warmup=1, rep=1)
+
+
+def test_kwargs(device: str):
N = 1024
src = torch.randn(N, device=device)
dst = torch.empty(N, device=device)

configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]

-    @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph)
+    @triton.autotune(configs=configs, key=['N'], do_bench=do_bench)
@triton.jit
def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -31,7 +34,7 @@ def test_restore(device):

configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]

-    @triton.autotune(configs=configs, key=['N'], restore_value=['src'], warmup=1, rep=1)
+    @triton.autotune(configs=configs, key=['N'], restore_value=['src'], do_bench=do_bench)
@triton.jit
def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -61,7 +64,7 @@ def _post_hook(*args, exception):
values["has_exception"] = True
assert values["counter"] == 0

-    @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, pre_hook=_pre_hook, post_hook=_post_hook)
+    @triton.autotune(configs=configs, key=['N'], do_bench=do_bench, pre_hook=_pre_hook, post_hook=_post_hook)
@triton.heuristics({"N_STAGES": lambda nargs: 100 if nargs['N'] == 4096 else 4})
@triton.jit
def _kernel(src, N, N_STAGES: tl.constexpr, BLOCK_SIZE: tl.constexpr):
@@ -112,7 +115,7 @@ def perf_model(*args, **kwargs):
else:
prune_configs_by = {'early_config_prune': early_config_prune}

-    @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, warmup=1, rep=1)
+    @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, do_bench=do_bench)
@triton.jit
def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
14 changes: 14 additions & 0 deletions python/triton/backends/driver.py
@@ -1,4 +1,11 @@
from abc import ABCMeta, abstractmethod, abstractclassmethod
+from typing import Callable, List, Protocol, Sequence
+
+
+class Benchmarker(Protocol):
+
+    def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]:
+        pass


class DriverBase(metaclass=ABCMeta):
@@ -11,6 +18,13 @@ def is_active(self):
def get_current_target(self):
pass

+    @abstractmethod
+    def get_benchmarker(self) -> Benchmarker:
+        """
+        Return the benchmarking function that this backend should use by default.
+        """
+        raise NotImplementedError

def __init__(self) -> None:
pass

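The sketch below shows how an out-of-tree backend could satisfy this new contract. Only the `get_benchmarker()` method name and the `Benchmarker` call signature come from the diff; the `MyDriver` class, the wall-clock timing loop, and the fixed repetition count are illustrative assumptions. In-tree backends simply return `triton.testing.do_bench`, as the AMD and NVIDIA drivers further down do.

```python
import time
from typing import Callable, List, Sequence


def _wallclock_bench(kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]:
    # Crude host-side timing; a real backend would synchronize the device and
    # discard warmup iterations, as triton.testing.do_bench does.
    times = []
    for _ in range(10):
        start = time.perf_counter()
        kernel_call()
        times.append((time.perf_counter() - start) * 1e3)  # milliseconds
    times.sort()
    # Nearest-rank quantiles, returned in the order the caller requested them.
    return [times[min(int(q * len(times)), len(times) - 1)] for q in quantiles]


class MyDriver:  # hypothetical backend driver deriving from DriverBase
    def get_benchmarker(self) -> Callable:
        return _wallclock_bench
```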
46 changes: 35 additions & 11 deletions python/triton/runtime/autotuner.py
@@ -6,9 +6,9 @@
import inspect
from typing import Dict

-from ..testing import do_bench, do_bench_cudagraph
from .jit import KernelInterface
from .errors import OutOfResources
+from .driver import driver


class Autotuner(KernelInterface):
@@ -24,9 +24,10 @@ def __init__(
pre_hook=None,
post_hook=None,
prune_configs_by: Dict = None,
-        warmup=25,
-        rep=100,
+        warmup=None,
+        rep=None,
use_cuda_graph=False,
+        do_bench=None,
):
"""
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
@@ -88,10 +89,35 @@ def _post_hook(args, exception):
self.base_fn = fn
while not inspect.isfunction(self.base_fn):
self.base_fn = self.base_fn.fn
-        self.num_warmups = warmup
-        self.num_reps = rep
-        import torch
-        self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
Comment on lines -91 to -94 (Contributor):

Hi @Jokeren, @int3,

Fields self.num_warmups, self.num_reps and self.use_cuda_graph are used by PyTorch to find out what parameters the autotuner was called with:

https://github.com/pytorch/pytorch/blame/5141ade8e30c64e873e14dcc8de233da45d15025/torch/_higher_order_ops/triton_kernel_wrap.py#L829

Can they be left until the corresponding parameters are removed from the __init__ signature?

Reply (Contributor):

@int3 is driving the effort. It's up to him. I'm OK either way.


+        # If we got explicitly called via the old interface, raise a warning
+        # and proceed with the old behavior.
+        if warmup is not None or rep is not None or use_cuda_graph:
+            import warnings
+            warnings.warn("warmup, rep, and use_cuda_graph parameters are deprecated. See _ for details.",
+                          DeprecationWarning)
+            import torch
+            if use_cuda_graph and torch.cuda.is_available():
+                from ..testing import do_bench_cudagraph
+                self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
+                    kernel_call,
+                    rep=rep if rep is not None else 100,
+                    quantiles=quantiles,
+                )
+                return
+
+            import triton.testing
+            self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
+                kernel_call,
+                warmup=warmup if warmup is not None else 25,
+                rep=rep if rep is not None else 100,
+                quantiles=quantiles,
+            )
+
+        if do_bench is None:
+            self.do_bench = driver.active.get_benchmarker()
+        else:
+            self.do_bench = do_bench

def _bench(self, *args, config, **meta):
from ..compiler.errors import CompileTimeAssertionFailure
@@ -125,9 +151,7 @@ def kernel_call():
self.post_hook(args, exception=None)

try:
-            if self.use_cuda_graph:
-                return do_bench_cudagraph(kernel_call, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
-            return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
+            return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
except (OutOfResources, CompileTimeAssertionFailure):
return [float("inf"), float("inf"), float("inf")]

@@ -260,7 +284,7 @@ def __str__(self):


def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
-             warmup=25, rep=100, use_cuda_graph=False):
+             warmup=None, rep=None, use_cuda_graph=False, do_bench=None):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.

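Put together with the test updates above, the new user-facing pattern looks roughly like this; the `_copy` kernel and the `fast_bench` helper are illustrative, while the `do_bench=` keyword and the fallback to `driver.active.get_benchmarker()` are what this diff adds. Passing the old `warmup=`, `rep=`, or `use_cuda_graph=` arguments still works but now emits a `DeprecationWarning`.

```python
import triton
import triton.language as tl
import triton.testing


def fast_bench(kernel_call, quantiles):
    # Benchmarker-compatible callable: replaces the deprecated warmup=/rep= knobs.
    return triton.testing.do_bench(kernel_call, quantiles=quantiles, warmup=1, rep=1)


@triton.autotune(
    configs=[triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})],
    key=['N'],
    do_bench=fast_bench,  # omit to fall back to driver.active.get_benchmarker()
)
@triton.jit
def _copy(dst, src, N, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(src + offsets, mask=offsets < N)
    tl.store(dst + offsets, x, mask=offsets < N)
```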
6 changes: 2 additions & 4 deletions python/tutorials/06-fused-attention.py
@@ -599,8 +599,6 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"):
assert mode in ["fwd", "bwd"]
-    warmup = 25
-    rep = 100
dtype = torch.float16
if "triton" in provider:
q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
@@ -618,15 +616,15 @@ def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"):
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn)
if provider == "flash":
qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
fn = lambda: flash_attn_func(qkv, causal=causal)
if mode == "bwd":
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn)
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM
total_flops = 2 * flops_per_matmul
if causal:
4 changes: 4 additions & 0 deletions third_party/amd/backend/driver.py
@@ -495,3 +495,7 @@ def get_current_target(self):
arch = device_properties['arch']
warp_size = device_properties['warpSize']
return GPUTarget("hip", arch.split(':')[0], warp_size)

+    def get_benchmarker(self):
+        from triton.testing import do_bench
+        return do_bench
4 changes: 4 additions & 0 deletions third_party/nvidia/backend/driver.py
@@ -444,3 +444,7 @@ def get_current_target(self):
def is_active():
import torch
return torch.cuda.is_available() and (torch.version.hip is None)

+    def get_benchmarker(self):
+        from triton.testing import do_bench
+        return do_bench
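For completeness, a usage sketch of calling the default benchmarker directly; it assumes the post-PR tree where the active driver exposes `get_benchmarker()`, and the tensor and lambda are illustrative. The call matches the `Benchmarker` protocol above, and the quantile results come back in the order requested.

```python
import torch
from triton.runtime import driver

bench = driver.active.get_benchmarker()  # triton.testing.do_bench on the NVIDIA and AMD backends
x = torch.randn(1 << 20, device="cuda")
p50, p20, p80 = bench(lambda: x * 2, quantiles=[0.5, 0.2, 0.8])
print(f"median: {p50:.4f} ms (p20={p20:.4f} ms, p80={p80:.4f} ms)")
```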