diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py
index fbf65d9e908f..847201d61bee 100644
--- a/python/triton/backends/__init__.py
+++ b/python/triton/backends/__init__.py
@@ -28,6 +28,7 @@ def _find_concrete_subclasses(module, base_class):
 
 @dataclass(frozen=True)
 class Backend:
+    name: str = ""
     compiler: BaseBackend = None
     driver: DriverBase = None
 
@@ -42,7 +43,7 @@ def _discover_backends():
             continue
         compiler = _load_module(name, os.path.join(root, name, 'compiler.py'))
         driver = _load_module(name, os.path.join(root, name, 'driver.py'))
-        backends[name] = Backend(_find_concrete_subclasses(compiler, BaseBackend),
+        backends[name] = Backend(name, _find_concrete_subclasses(compiler, BaseBackend),
                                  _find_concrete_subclasses(driver, DriverBase))
     return backends
 
diff --git a/python/triton/backends/driver.py b/python/triton/backends/driver.py
index 4f62a97942d4..e66442943b03 100644
--- a/python/triton/backends/driver.py
+++ b/python/triton/backends/driver.py
@@ -32,14 +32,3 @@ def __init__(self):
     # TODO: remove once TMA is cleaned up
     def assemble_tensormap_to_arg(self, tensormaps_info, args):
         return args
-
-
-class CPUDriverBase(DriverBase):
-
-    def __init__(self):
-        # Right now, we just provide dummy functions.
-        # TODO: Consider better engineering the code only intended for GPU in jit.py.
-        self.get_device_capability = lambda idx: (0, 0)
-        self.get_current_stream = lambda idx: 0
-        self.get_current_device = lambda: 0
-        self.set_current_device = lambda idx: None
diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py
index 4cf1aea8e494..ed3c16978bd2 100644
--- a/python/triton/runtime/driver.py
+++ b/python/triton/runtime/driver.py
@@ -66,5 +66,22 @@ def set_active(self, driver: DriverBase):
     def reset_active(self):
         self.active = self.default
 
+    def set_active_to_cpu(self):
+        if "cpu" not in backends:
+            raise RuntimeError("CPU backend is unavailable")
+        self.active = backends["cpu"].driver()
+
+    def set_active_to_gpu(self):
+        active_gpus = [(name, backend.driver)
+                       for name, backend in backends.items()
+                       if backend.driver.is_active() and name != "cpu"]
+        if len(active_gpus) != 1:
+            raise RuntimeError(f"{len(active_gpus)} active GPU drivers ({active_gpus}). There should only be one GPU.")
+        self.active = active_gpus[0][1]()
+        return active_gpus[0][0]
+
+    def get_active_gpus(self):
+        return [name for name, backend in backends.items() if backend.driver.is_active() and name != "cpu"]
+
 
 driver = DriverConfig()
diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py
index a12b1d235b7c..0a1b08601b44 100644
--- a/python/triton/runtime/jit.py
+++ b/python/triton/runtime/jit.py
@@ -606,6 +606,7 @@ def run(self, *args, grid, warmup, **kwargs):
         # parse options
         device = driver.active.get_current_device()
         stream = driver.active.get_current_stream(device)
+        target = driver.active.get_current_target()
         kwargs["debug"] = self.debug
 
         # Execute pre run hooks with args and kwargs
@@ -618,12 +619,12 @@ def run(self, *args, grid, warmup, **kwargs):
         bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs)
 
         # compute cache key
+        device_key = f"{target.backend}:{device}"
         key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs))
-        kernel = self.cache[device].get(key, None)
+        kernel = self.cache[device_key].get(key, None)
 
         if kernel is None:
             # Kernel is not cached; we have to compile.
-            target = driver.active.get_current_target()
             backend = self.make_backend(target)
             options = backend.parse_options(kwargs)
 
@@ -664,7 +665,7 @@ def run(self, *args, grid, warmup, **kwargs):
                 target=target,
                 options=options.__dict__,
             )
-            self.cache[device][key] = kernel
+            self.cache[device_key][key] = kernel
 
         # Check that used global values have not changed.
         not_present = object()
diff --git a/python/triton/testing.py b/python/triton/testing.py
index 0b228871c9f5..c03c1ae57811 100644
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -2,11 +2,34 @@
 import os
 import subprocess
 import sys
+import time
 from contextlib import contextmanager
 from typing import Any, Dict, List
 
 from . import language as tl
 
 
+class Event:
+
+    def __init__(self, is_cpu):
+        self.time = 0
+        self.is_cpu = is_cpu
+        if not is_cpu:
+            import torch
+            self.cuda_event = torch.cuda.Event(enable_timing=True)
+
+    def elapsed_time(self, end_event) -> float:
+        if self.is_cpu:
+            return (end_event.time - self.time) * 1000
+        else:
+            return self.cuda_event.elapsed_time(end_event.cuda_event)
+
+    def record(self):
+        if self.is_cpu:
+            self.time = time.perf_counter()
+        else:
+            self.cuda_event.record()
+
+
 def nvsmi(attrs):
     attrs = ','.join(attrs)
     cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
@@ -79,7 +102,8 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, return_mode="mean"):
     return getattr(torch, return_mode)(times).item()
 
 
-def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
+def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
+             is_cpu=False):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
@@ -101,32 +125,42 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
     import torch
 
     fn()
-    torch.cuda.synchronize()
+    if not is_cpu:
+        torch.cuda.synchronize()
+
+    if not is_cpu:
+        cache_size = 256e6
+        device = 'cuda'
+    else:
+        # Currently, a typical L3 cache size for a high-end server CPU is ~400 MB.
+        cache_size = 512e6
+        device = 'cpu'
 
     # We maintain a buffer of 256 MB that we clear
     # before each kernel call to make sure that the L2
     # doesn't contain any input data before the run
     if fast_flush:
-        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
+        cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
     else:
-        cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
+        cache = torch.empty(int(cache_size), dtype=torch.int8, device=device)
 
     # Estimate the runtime of the function
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
+    start_event = Event(is_cpu)
+    end_event = Event(is_cpu)
     start_event.record()
     for _ in range(5):
         cache.zero_()
         fn()
     end_event.record()
-    torch.cuda.synchronize()
+    if not is_cpu:
+        torch.cuda.synchronize()
     estimate_ms = start_event.elapsed_time(end_event) / 5
 
     # compute number of warmup and repeat
     n_warmup = max(1, int(warmup / estimate_ms))
     n_repeat = max(1, int(rep / estimate_ms))
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
+    start_event = [Event(is_cpu) for i in range(n_repeat)]
+    end_event = [Event(is_cpu) for i in range(n_repeat)]
     # Warm-up
     for _ in range(n_warmup):
         fn()
@@ -145,7 +179,8 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
         fn()
         end_event[i].record()
     # Record clocks
-    torch.cuda.synchronize()
+    if not is_cpu:
+        torch.cuda.synchronize()
     times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
     if quantiles is not None:
         ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py
index 1c1900a07481..57e8b0996a6b 100644
--- a/python/tutorials/01-vector-add.py
+++ b/python/tutorials/01-vector-add.py
@@ -23,6 +23,8 @@
 import triton
 import triton.language as tl
 
+BLOCK_SIZE = 1024
+
 
 @triton.jit
 def add_kernel(x_ptr,  # *Pointer* to first input vector.
@@ -57,10 +59,10 @@ def add_kernel(x_ptr,  # *Pointer* to first input vector.
 # and (2) enqueue the above kernel with appropriate grid/block sizes:
 
 
-def add(x: torch.Tensor, y: torch.Tensor):
+def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
     # We need to preallocate the output.
     output = torch.empty_like(x)
-    assert x.is_cuda and y.is_cuda and output.is_cuda
+    assert x.is_cpu == is_cpu and y.is_cpu == is_cpu and output.is_cpu == is_cpu
     n_elements = output.numel()
     # The SPMD launch grid denotes the number of kernel instances that run in parallel.
     # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
@@ -78,17 +80,37 @@ def add(x: torch.Tensor, y: torch.Tensor):
 
 # %%
 # We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:
-
 torch.manual_seed(0)
 size = 98432
-x = torch.rand(size, device='cuda')
-y = torch.rand(size, device='cuda')
-output_torch = x + y
-output_triton = add(x, y)
-print(output_torch)
-print(output_triton)
-print(f'The maximum difference between torch and triton is '
-      f'{torch.max(torch.abs(output_torch - output_triton))}')
+
+triton.runtime.driver.set_active_to_cpu()
+x = torch.rand(size, device='cpu')
+y = torch.rand(size, device='cpu')
+output_torch_cpu = x + y
+output_triton_cpu = add(x, y, is_cpu=True)
+print(output_torch_cpu)
+print(output_triton_cpu)
+print(f'The maximum difference between torch-cpu and triton-cpu is '
+      f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}')
+
+LINE_VALS = ['triton-cpu', 'torch-cpu']
+LINE_NAMES = ['TritonCPU', 'TorchCPU']
+LINE_STYLES = [('blue', '-'), ('green', '-')]
+
+if triton.runtime.driver.get_active_gpus():
+    triton.runtime.driver.set_active_to_gpu()
+    x = x.to('cuda')
+    y = y.to('cuda')
+    output_torch_gpu = x + y
+    output_triton_gpu = add(x, y, is_cpu=False)
+    print(output_torch_gpu)
+    print(output_triton_gpu)
+    print(f'The maximum difference between torch-gpu and triton-gpu is '
+          f'{torch.max(torch.abs(output_torch_gpu - output_triton_gpu))}')
+
+    LINE_VALS += ['triton-gpu', 'torch-gpu']
+    LINE_NAMES += ['TritonGPU', 'TorchGPU']
+    LINE_STYLES += [('yellow', '-'), ('red', '-')]
 
 # %%
 # Seems like we're good to go!
@@ -108,21 +130,34 @@ def add(x: torch.Tensor, y: torch.Tensor):
         x_vals=[2**i for i in range(12, 28, 1)],  # Different possible values for `x_name`.
         x_log=True,  # x axis is logarithmic.
         line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
-        line_vals=['triton', 'torch'],  # Possible values for `line_arg`.
-        line_names=['Triton', 'Torch'],  # Label name for the lines.
-        styles=[('blue', '-'), ('green', '-')],  # Line styles.
+        line_vals=LINE_VALS,  # Possible values for `line_arg`.
+        line_names=LINE_NAMES,  # Label name for the lines.
+        styles=LINE_STYLES,  # Line styles.
         ylabel='GB/s',  # Label name for the y-axis.
-        plot_name='vector-add-performance',  # Name for the plot. Used also as a file name for saving the plot.
+        plot_name=
+        # Name for the plot. Used also as a file name for saving the plot.
+        f'vector-add-performance (BLOCK_SIZE={BLOCK_SIZE})',
         args={},  # Values for function arguments not in `x_names` and `y_name`.
     ))
 def benchmark(size, provider):
-    x = torch.rand(size, device='cuda', dtype=torch.float32)
-    y = torch.rand(size, device='cuda', dtype=torch.float32)
+    device = 'cpu' if 'cpu' in provider else 'cuda'
+    x = torch.rand(size, device=device, dtype=torch.float32)
+    y = torch.rand(size, device=device, dtype=torch.float32)
+
+    if device == 'cpu':
+        triton.runtime.driver.set_active_to_cpu()
+    else:
+        triton.runtime.driver.set_active_to_gpu()
+
     quantiles = [0.5, 0.2, 0.8]
-    if provider == 'torch':
+    if provider == 'torch-gpu':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
-    if provider == 'triton':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)
+    elif provider == 'triton-gpu':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, False), quantiles=quantiles)
+    elif provider == 'torch-cpu':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, is_cpu=True)
+    elif provider == 'triton-cpu':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, True), quantiles=quantiles, is_cpu=True)
     gbps = lambda ms: 12 * size / ms * 1e-6
     return gbps(ms), gbps(max_ms), gbps(min_ms)
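
Below is a minimal usage sketch (not part of the patch) showing how the pieces added above fit together. It assumes the patch is applied, a "cpu" backend is discoverable under triton/backends, and it reuses the `add` function from the modified 01-vector-add.py tutorial:

    import torch
    import triton

    # Route kernel launches through the CPU driver selected by this patch.
    triton.runtime.driver.set_active_to_cpu()

    x = torch.rand(98432, device='cpu')
    y = torch.rand(98432, device='cpu')
    out = add(x, y, is_cpu=True)  # `add` from the patched 01-vector-add.py

    # Benchmark on CPU: with is_cpu=True, do_bench times with time.perf_counter()
    # via the new Event class instead of CUDA events and skips torch.cuda.synchronize().
    ms = triton.testing.do_bench(lambda: add(x, y, True), is_cpu=True)

    # Switch back to the (single) active GPU driver, if one is present.
    if triton.runtime.driver.get_active_gpus():
        triton.runtime.driver.set_active_to_gpu()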