From fca4aeb649699b6fec35ed086fae6b66ebee1bd2 Mon Sep 17 00:00:00 2001
From: Artemiy Bulavin
Date: Wed, 24 Jul 2024 14:07:32 +0000
Subject: [PATCH] Use torch._dynamo.device_interface in do_bench to make it
 device agnostic

---
 python/triton/testing.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/python/triton/testing.py b/python/triton/testing.py
index be9f8eac707f..a4e4500226a4 100644
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -90,7 +90,8 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
     return _summarize_statistics(torch.tensor(ret), quantiles, return_mode)
 
 
-def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
+def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
+             device_type="cuda"):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
@@ -113,34 +114,36 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
     assert return_mode in ["min", "max", "mean", "median"]
     import torch
 
+    di = torch._dynamo.device_interface.get_interface_for_device(device_type)
+
     fn()
-    torch.cuda.synchronize()
+    di.synchronize()
 
     # We maintain a buffer of 256 MB that we clear
     # before each kernel call to make sure that the L2 cache
     # doesn't contain any input data before the run
     cache_size = 256 * 1024 * 1024
     if fast_flush:
-        cache = torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
+        cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device_type)
     else:
-        cache = torch.empty(int(cache_size), dtype=torch.int8, device='cuda')
+        cache = torch.empty(int(cache_size), dtype=torch.int8, device=device_type)
 
     # Estimate the runtime of the function
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
+    start_event = di.Event(enable_timing=True)
+    end_event = di.Event(enable_timing=True)
     start_event.record()
     for _ in range(5):
         cache.zero_()
         fn()
     end_event.record()
-    torch.cuda.synchronize()
+    di.synchronize()
     estimate_ms = start_event.elapsed_time(end_event) / 5
 
     # compute number of warmup and repeat
     n_warmup = max(1, int(warmup / estimate_ms))
     n_repeat = max(1, int(rep / estimate_ms))
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
+    start_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
+    end_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
     # Warm-up
     for _ in range(n_warmup):
         fn()
@@ -159,7 +162,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
         fn()
         end_event[i].record()
     # Record clocks
-    torch.cuda.synchronize()
+    di.synchronize()
     times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
     return _summarize_statistics(times, quantiles, return_mode)
 
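
Not part of the patch: a minimal usage sketch of the patched do_bench with the
new device_type argument. The add_kernel and bench_add names below are
illustrative assumptions, not code from this change; any device string that
torch._dynamo.device_interface knows about (e.g. the default "cuda") should be
accepted, since events and synchronization are now taken from that interface.

import torch
import triton
import triton.language as tl
from triton.testing import do_bench


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Element-wise add used only as a benchmark target.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def bench_add(device_type="cuda"):
    # device_type is forwarded to do_bench so the timing events and the
    # synchronize calls come from the matching torch._dynamo device interface.
    n = 1 << 20
    x = torch.rand(n, device=device_type)
    y = torch.rand(n, device=device_type)
    out = torch.empty_like(x)
    grid = (triton.cdiv(n, 1024), )
    fn = lambda: add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    return do_bench(fn, warmup=25, rep=100, device_type=device_type)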