Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions python/test/unit/runtime/test_autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,26 @@ def _kernel(dst, src, stride_m: tl.constexpr, M, BLOCK_SIZE_N: tl.constexpr, BLO
assert len(_kernel.cache) == 2


def test_no_do_bench(device: str):
M, N = 1024, 16
src = torch.randn(M * N, device=device)
dst = torch.empty(M * N, device=device)

configs = [triton.Config(kwargs={'BLOCK_SIZE_M': 32}), triton.Config(kwargs={'BLOCK_SIZE_M': 128})]

@triton.autotune(configs=configs, key=["M"])
@triton.jit
def _kernel(dst, src, stride_m: tl.constexpr, M, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_M: tl.constexpr):
offsets_m = tl.program_id(0) * stride_m + tl.arange(0, BLOCK_SIZE_M)
offsets_n = tl.arange(0, BLOCK_SIZE_N)
x = tl.load(src + offsets_m[:, None] * BLOCK_SIZE_N + offsets_n[None, :])
tl.store(dst + offsets_m[:, None] * BLOCK_SIZE_N + offsets_n[None, :], x)

grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE_M']), )
_kernel[grid](dst, src, N, M, N)
assert len(_kernel.cache) == 1


@pytest.mark.parametrize('pass_kwargs_to_kernel', [False, True])
def test_restore(pass_kwargs_to_kernel, device):
N = 1024
Expand Down
15 changes: 9 additions & 6 deletions python/triton/runtime/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import inspect
import hashlib
import json
from functools import cached_property
from typing import Dict, Tuple, List, Optional

from .jit import KernelInterface
Expand Down Expand Up @@ -84,6 +85,7 @@ def _post_hook(kwargs, exception):
while not inspect.isfunction(self.base_fn):
self.base_fn = self.base_fn.fn

self._do_bench = do_bench
self.num_warmups = warmup
self.num_reps = rep
self.use_cuda_graph = use_cuda_graph
Expand All @@ -97,26 +99,27 @@ def _post_hook(kwargs, exception):
stacklevel=1)
if use_cuda_graph:
from ..testing import do_bench_cudagraph
self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
self._do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
kernel_call,
rep=rep if rep is not None else 100,
quantiles=quantiles,
)
return

import triton.testing
self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
self._do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
kernel_call,
warmup=warmup if warmup is not None else 25,
rep=rep if rep is not None else 100,
quantiles=quantiles,
)
return

if do_bench is None:
self.do_bench = driver.active.get_benchmarker()
else:
self.do_bench = do_bench
@cached_property
def do_bench(self):
if self._do_bench is None:
return driver.active.get_benchmarker()
return self._do_bench

def _bench(self, *args, config, **meta):
from ..compiler.errors import CompileTimeAssertionFailure
Expand Down
Loading