Merged
3 changes: 2 additions & 1 deletion python/triton/backends/__init__.py
@@ -28,6 +28,7 @@ def _find_concrete_subclasses(module, base_class):

@dataclass(frozen=True)
class Backend:
name: str = ""
compiler: BaseBackend = None
driver: DriverBase = None

@@ -42,7 +43,7 @@ def _discover_backends():
continue
compiler = _load_module(name, os.path.join(root, name, 'compiler.py'))
driver = _load_module(name, os.path.join(root, name, 'driver.py'))
backends[name] = Backend(_find_concrete_subclasses(compiler, BaseBackend),
backends[name] = Backend(name, _find_concrete_subclasses(compiler, BaseBackend),
_find_concrete_subclasses(driver, DriverBase))
return backends

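For orientation, a minimal sketch of how the new name field can be consumed after discovery. It assumes the backends mapping exported by triton.backends (the same mapping the runtime driver changes below iterate over); the printout itself is purely illustrative:

from triton.backends import backends

# Each discovered Backend now records the directory name it was loaded from,
# alongside its compiler and driver classes.
for name, backend in backends.items():
    print(f"{name}: driver active = {backend.driver.is_active()}")
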
11 changes: 0 additions & 11 deletions python/triton/backends/driver.py
@@ -32,14 +32,3 @@ def __init__(self):
# TODO: remove once TMA is cleaned up
def assemble_tensormap_to_arg(self, tensormaps_info, args):
return args

Collaborator author comment:
This was from my scratch. Unused. Removing.

class CPUDriverBase(DriverBase):

def __init__(self):
# Right now, we just provide dummy functions.
# TODO: Consider better engineering the code only intended for GPU in jit.py.
self.get_device_capability = lambda idx: (0, 0)
self.get_current_stream = lambda idx: 0
self.get_current_device = lambda: 0
self.set_current_device = lambda idx: None
17 changes: 17 additions & 0 deletions python/triton/runtime/driver.py
@@ -66,5 +66,22 @@ def set_active(self, driver: DriverBase):
def reset_active(self):
self.active = self.default

def set_active_to_cpu(self):
if "cpu" not in backends:
raise RuntimeError("CPU backend is unavailable")
self.active = backends["cpu"].driver()

def set_active_to_gpu(self):
active_gpus = [(name, backend.driver)
for name, backend in backends.items()
if backend.driver.is_active() and name != "cpu"]
if len(active_gpus) != 1:
raise RuntimeError(f"{len(active_gpus)} active GPU drivers ({active_gpus}). There should only be one GPU.")
self.active = active_gpus[0][1]()
return active_gpus[0][0]

def get_active_gpus(self):
return [name for name, backend in backends.items() if backend.driver.is_active() and name != "cpu"]


driver = DriverConfig()
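A hedged usage sketch of the new switching helpers on the module-level driver object defined just above. It assumes the CPU backend added by this PR is installed; the method names come directly from the diff, and the prints are illustrative:

import triton

# Prefer a GPU driver when one is active; set_active_to_gpu() itself checks
# that exactly one GPU driver is active and returns that backend's name.
if triton.runtime.driver.get_active_gpus():
    gpu_name = triton.runtime.driver.set_active_to_gpu()
    print(f"kernels will run on the {gpu_name} backend")
else:
    # No active GPU driver: route compilation and launches to the CPU driver.
    triton.runtime.driver.set_active_to_cpu()
    print("kernels will run on the CPU backend")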
7 changes: 4 additions & 3 deletions python/triton/runtime/jit.py
@@ -606,6 +606,7 @@ def run(self, *args, grid, warmup, **kwargs):
# parse options
device = driver.active.get_current_device()
stream = driver.active.get_current_stream(device)
target = driver.active.get_current_target()
kwargs["debug"] = self.debug

# Execute pre run hooks with args and kwargs
@@ -618,12 +619,12 @@ def run(self, *args, grid, warmup, **kwargs):
bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs)

# compute cache key
device_key = f"{target.backend}:{device}"
key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs))
kernel = self.cache[device].get(key, None)
kernel = self.cache[device_key].get(key, None)

if kernel is None:
# Kernel is not cached; we have to compile.
target = driver.active.get_current_target()
backend = self.make_backend(target)
options = backend.parse_options(kwargs)

@@ -664,7 +665,7 @@ def run(self, *args, grid, warmup, **kwargs):
target=target,
options=options.__dict__,
)
self.cache[device][key] = kernel
self.cache[device_key][key] = kernel

# Check that used global values have not changed.
not_present = object()
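The point of device_key above is that the in-memory kernel cache is now keyed by backend name as well as device index, so a kernel compiled while the CPU driver is active cannot be returned after switching to a GPU driver (or vice versa). A small illustration; the literal key values are hypothetical:

# Old cache key: just the device index, e.g. self.cache[0]
# New cache key: backend-qualified, e.g. self.cache["cuda:0"] under a GPU driver,
# and a separate entry such as self.cache["cpu:0"] once the CPU driver is active.
device_key = f"{target.backend}:{device}"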
55 changes: 45 additions & 10 deletions python/triton/testing.py
@@ -2,11 +2,34 @@
import os
import subprocess
import sys
import time
from contextlib import contextmanager
from typing import Any, Dict, List
from . import language as tl


class Event:

def __init__(self, is_cpu):
self.time = 0
self.is_cpu = is_cpu
if not is_cpu:
import torch
self.cuda_event = torch.cuda.Event(enable_timing=True)

def elapsed_time(self, end_event) -> float:
if self.is_cpu:
return (end_event.time - self.time) * 1000
else:
return self.cuda_event.elapsed_time(end_event.cuda_event)

def record(self):
if self.is_cpu:
self.time = time.perf_counter()
else:
self.cuda_event.record()


def nvsmi(attrs):
attrs = ','.join(attrs)
cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
@@ -79,7 +102,8 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, return_mode="mean"):
return getattr(torch, return_mode)(times).item()


def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
is_cpu=False):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
@@ -101,32 +125,42 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
import torch

fn()
torch.cuda.synchronize()
if not is_cpu:
torch.cuda.synchronize()

if not is_cpu:
cache_size = 256e6
device = 'cuda'
else:
# Currently, a typical L3 cache size for high-end server CPUs is ~400MB.
cache_size = 512e6
device = 'cpu'

# We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2
# doesn't contain any input data before the run
if fast_flush:
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
else:
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
cache = torch.empty(int(cache_size), dtype=torch.int8, device=device)

# Estimate the runtime of the function
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event = Event(is_cpu)
end_event = Event(is_cpu)
start_event.record()
for _ in range(5):
cache.zero_()
fn()
end_event.record()
torch.cuda.synchronize()
if not is_cpu:
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5

# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
start_event = [Event(is_cpu) for i in range(n_repeat)]
end_event = [Event(is_cpu) for i in range(n_repeat)]
# Warm-up
for _ in range(n_warmup):
fn()
@@ -145,7 +179,8 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
fn()
end_event[i].record()
# Record clocks
torch.cuda.synchronize()
if not is_cpu:
torch.cuda.synchronize()
times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
if quantiles is not None:
ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
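A minimal sketch of the CPU path through the updated do_bench (torch is assumed to be installed; the workload is an arbitrary element-wise add). With is_cpu=True, timing goes through the perf_counter-backed Event wrapper and the larger 512 MB flush buffer instead of CUDA events and torch.cuda.synchronize():

import torch
import triton

x = torch.rand(1 << 20, device='cpu')
y = torch.rand(1 << 20, device='cpu')

# Returns the mean runtime in milliseconds by default (return_mode="mean").
ms = triton.testing.do_bench(lambda: x + y, is_cpu=True)
print(f"{ms:.3f} ms per call")
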
75 changes: 55 additions & 20 deletions python/tutorials/01-vector-add.py
@@ -23,6 +23,8 @@
import triton
import triton.language as tl

BLOCK_SIZE = 1024


@triton.jit
def add_kernel(x_ptr, # *Pointer* to first input vector.
@@ -57,10 +59,10 @@ def add_kernel(x_ptr, # *Pointer* to first input vector.
# and (2) enqueue the above kernel with appropriate grid/block sizes:


def add(x: torch.Tensor, y: torch.Tensor):
def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
# We need to preallocate the output.
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
assert x.is_cpu == is_cpu and y.is_cpu == is_cpu and output.is_cpu == is_cpu
n_elements = output.numel()
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
@@ -78,17 +80,37 @@ def add(x: torch.Tensor, y: torch.Tensor):

# %%
# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:

torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}')

triton.runtime.driver.set_active_to_cpu()
x = torch.rand(size, device='cpu')
y = torch.rand(size, device='cpu')
output_torch_cpu = x + y
output_triton_cpu = add(x, y, is_cpu=True)
print(output_torch_cpu)
print(output_triton_cpu)
print(f'The maximum difference between torch-cpu and triton-cpu is '
f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}')

LINE_VALS = ['triton-cpu', 'torch-cpu']
LINE_NAMES = ['TritonCPU', 'TorchCPU']
LINE_STYLES = [('blue', '-'), ('green', '-')]

if triton.runtime.driver.get_active_gpus():
triton.runtime.driver.set_active_to_gpu()
x = x.to('cuda')
y = y.to('cuda')
output_torch_gpu = x + y
output_triton_gpu = add(x, y, is_cpu=False)
print(output_torch_gpu)
print(output_triton_gpu)
print(f'The maximum difference between torch-gpu and triton-gpu is '
f'{torch.max(torch.abs(output_torch_gpu - output_triton_gpu))}')

LINE_VALS += ['triton-gpu', 'torch-gpu']
LINE_NAMES += ['TritonGPU', 'TorchGPU']
LINE_STYLES += [('yellow', '-'), ('red', '-')]

# %%
# Seems like we're good to go!
@@ -108,21 +130,34 @@ def add(x: torch.Tensor, y: torch.Tensor):
x_vals=[2**i for i in range(12, 28, 1)], # Different possible values for `x_name`.
x_log=True, # x axis is logarithmic.
line_arg='provider', # Argument name whose value corresponds to a different line in the plot.
line_vals=['triton', 'torch'], # Possible values for `line_arg`.
line_names=['Triton', 'Torch'], # Label name for the lines.
styles=[('blue', '-'), ('green', '-')], # Line styles.
line_vals=LINE_VALS, # Possible values for `line_arg`.
line_names=LINE_NAMES, # Label name for the lines.
styles=LINE_STYLES, # Line styles.
ylabel='GB/s', # Label name for the y-axis.
plot_name='vector-add-performance', # Name for the plot. Used also as a file name for saving the plot.
plot_name=
# Name for the plot. Used also as a file name for saving the plot.
f'vector-add-performance (BLOCK_SIZE={BLOCK_SIZE})',
args={}, # Values for function arguments not in `x_names` and `y_name`.
))
def benchmark(size, provider):
x = torch.rand(size, device='cuda', dtype=torch.float32)
y = torch.rand(size, device='cuda', dtype=torch.float32)
device = 'cpu' if 'cpu' in provider else 'cuda'
x = torch.rand(size, device=device, dtype=torch.float32)
y = torch.rand(size, device=device, dtype=torch.float32)

if device == 'cpu':
triton.runtime.driver.set_active_to_cpu()
else:
triton.runtime.driver.set_active_to_gpu()

quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
if provider == 'torch-gpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
Collaborator author comment (@minjang, May 30, 2024):
Tested on a computer without GPU. And it turned out triton.testing.do_bench has hard-coded torch.cuda usages. I will also update it.
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)
elif provider == 'triton-gpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, False), quantiles=quantiles)
elif provider == 'torch-cpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, is_cpu=True)
elif provider == 'triton-cpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, True), quantiles=quantiles, is_cpu=True)
gbps = lambda ms: 12 * size / ms * 1e-6
return gbps(ms), gbps(max_ms), gbps(min_ms)

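As a sanity check of the gbps lambda above: each benchmarked call reads x and y and writes output, i.e. three float32 tensors of size elements, so 3 * 4 * size = 12 * size bytes move per call. Spelling out the arithmetic with a hypothetical measurement:

size = 2**20                           # element count, as in the benchmark's x_vals
ms = 1.0                               # hypothetical runtime in milliseconds
bytes_moved = 3 * 4 * size             # x and y read, output written, 4 bytes each
gbs = bytes_moved / (ms * 1e-3) / 1e9  # bytes / seconds / 1e9
assert abs(gbs - 12 * size / ms * 1e-6) < 1e-9  # matches gbps(ms) in the diff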