Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions python/triton/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
return _summarize_statistics(torch.tensor(ret), quantiles, return_mode)


def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
device_type="cuda"):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
Expand All @@ -113,34 +114,36 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
assert return_mode in ["min", "max", "mean", "median"]
import torch

di = torch._dynamo.device_interface.get_interface_for_device(device_type)

fn()
torch.cuda.synchronize()
di.synchronize()

# We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2 cache
# doesn't contain any input data before the run
cache_size = 256 * 1024 * 1024
if fast_flush:
cache = torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device_type)
else:
cache = torch.empty(int(cache_size), dtype=torch.int8, device='cuda')
cache = torch.empty(int(cache_size), dtype=torch.int8, device=device_type)

# Estimate the runtime of the function
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event = di.Event(enable_timing=True)
end_event = di.Event(enable_timing=True)
start_event.record()
for _ in range(5):
cache.zero_()
fn()
end_event.record()
torch.cuda.synchronize()
di.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5

# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
start_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
# Warm-up
for _ in range(n_warmup):
fn()
Expand All @@ -159,7 +162,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
fn()
end_event[i].record()
# Record clocks
torch.cuda.synchronize()
di.synchronize()
times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
return _summarize_statistics(times, quantiles, return_mode)

Expand Down