Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/triton/runtime/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
cc_cmd += [f"-I{dir}" for dir in include_dirs]
# CPU backend uses C++ (driver.cpp). Some old version compilers need a specific C++17 flag.
if src.endswith(".cpp") or src.endswith(".cc"):
cc_cmd += ["-std=c++17"]
cc_cmd += ["-std=c++17", "-fopenmp"]
ret = subprocess.check_call(cc_cmd)
if ret == 0:
return so
Expand Down
49 changes: 33 additions & 16 deletions python/tutorials/01-vector-add.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
import triton
import triton.language as tl

BLOCK_SIZE = 1024
GPU_BLOCK_SIZE = 1024
CPU_BLOCK_SIZE = 4096
USE_GPU = True


@triton.jit
Expand Down Expand Up @@ -59,10 +61,11 @@ def add_kernel(x_ptr, # *Pointer* to first input vector.
# and (2) enqueue the above kernel with appropriate grid/block sizes:


def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
# We need to preallocate the output.
output = torch.empty_like(x)
assert x.is_cpu == is_cpu and y.is_cpu == is_cpu and output.is_cpu == is_cpu
def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, is_cpu):
if output is None:
# We need to preallocate the output.
output = torch.empty_like(x)
assert x.is_cpu == is_cpu and y.is_cpu == is_cpu and output.is_cpu == is_cpu
n_elements = output.numel()
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
Expand All @@ -72,7 +75,7 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
# - Each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
# - Don't forget to pass meta-parameters as keywords arguments.
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=CPU_BLOCK_SIZE if is_cpu else GPU_BLOCK_SIZE)
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
# running asynchronously at this point.
return output
Expand All @@ -87,22 +90,22 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
x = torch.rand(size, device='cpu')
y = torch.rand(size, device='cpu')
output_torch_cpu = x + y
output_triton_cpu = add(x, y, is_cpu=True)
output_triton_cpu = add(x, y, None, is_cpu=True)
print(output_torch_cpu)
print(output_triton_cpu)
print(f'The maximum difference between torch-cpu and triton-cpu is '
f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}')

LINE_VALS = ['triton-cpu', 'torch-cpu']
LINE_NAMES = ['TritonCPU', 'TorchCPU']
LINE_STYLES = [('blue', '-'), ('green', '-')]
LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu']
LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU']
LINE_STYLES = [('blue', '-'), ('green', '-'), ('cyan', '-')]

if triton.runtime.driver.get_active_gpus():
if USE_GPU and triton.runtime.driver.get_active_gpus():
triton.runtime.driver.set_active_to_gpu()
x = x.to('cuda')
y = y.to('cuda')
output_torch_gpu = x + y
output_triton_gpu = add(x, y, is_cpu=False)
output_triton_gpu = add(x, y, None, is_cpu=False)
print(output_torch_gpu)
print(output_triton_gpu)
print(f'The maximum difference between torch-gpu and triton-gpu is '
Expand Down Expand Up @@ -136,28 +139,42 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
ylabel='GB/s', # Label name for the y-axis.
plot_name=
# Name for the plot. Used also as a file name for saving the plot.
f'vector-add-performance (BLOCK_SIZE={BLOCK_SIZE})',
f'vector-add-performance (CPU_BLOCK_SIZE={CPU_BLOCK_SIZE}, GPU_BLOCK_SIZE={GPU_BLOCK_SIZE})',
args={}, # Values for function arguments not in `x_names` and `y_name`.
))
def benchmark(size, provider):
import os

device = 'cpu' if 'cpu' in provider else 'cuda'
x = torch.rand(size, device=device, dtype=torch.float32)
y = torch.rand(size, device=device, dtype=torch.float32)

if device == 'cpu':
triton.runtime.driver.set_active_to_cpu()
if 'single' in provider:
os.environ['TRITON_CPU_SINGLE_CORE'] = '1'
else:
os.unsetenv('TRITON_CPU_SINGLE_CORE')
else:
triton.runtime.driver.set_active_to_gpu()

quantiles = [0.5, 0.2, 0.8]
if provider == 'torch-gpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
elif provider == 'triton-gpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, False), quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles)
elif provider == 'torch-cpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, is_cpu=True)
# Note that we preallocate the output buffer here to only measure the kernel performance
# without a large chunk of memory allocation.
output = torch.empty_like(x)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.add(x, y, out=output), quantiles=quantiles,
is_cpu=True)
elif provider == 'triton-cpu-single':
output = torch.empty_like(x)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True)
elif provider == 'triton-cpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, True), quantiles=quantiles, is_cpu=True)
output = torch.empty_like(x)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True)
gbps = lambda ms: 12 * size / ms * 1e-6
return gbps(ms), gbps(max_ms), gbps(min_ms)

Expand Down
75 changes: 67 additions & 8 deletions third_party/cpu/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def __new__(cls):
return cls.instance

def __init__(self):
pass
dirname = os.path.dirname(os.path.realpath(__file__))
mod = compile_module_from_src(Path(os.path.join(dirname, "driver.cpp")).read_text(), "cpu_utils")
self.load_binary = mod.load_binary
Expand Down Expand Up @@ -182,14 +181,39 @@ def format_of(ty):

# generate glue code
src = f"""
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <string>
#include <iostream>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <omp.h>
#include <optional>
#include <stdio.h>
#include <string>

#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <Python.h>
#include <stdio.h>

inline bool getBoolEnv(const std::string &env) {{
const char *s = std::getenv(env.c_str());
std::string str(s ? s : "");
std::transform(str.begin(), str.end(), str.begin(),
[](unsigned char c) {{ return std::tolower(c); }});
return str == "on" || str == "true" || str == "1";
}}

inline std::optional<int64_t> getIntEnv(const std::string &env) {{
const char *cstr = std::getenv(env.c_str());
if (!cstr)
return std::nullopt;

char *endptr;
long int result = std::strtol(cstr, &endptr, 10);
if (endptr == cstr)
assert(false && "invalid integer");
return result;
}}

using kernel_ptr_t = void(*)({kernel_fn_arg_types});

Expand Down Expand Up @@ -233,20 +257,55 @@ def format_of(ty):
return ptr_info;
}}

static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
// TODO: add OMP pragmas to run in parallel
static std::unique_ptr<uint32_t[][3]> get_all_grids(uint32_t gridX, uint32_t gridY, uint32_t gridZ) {{
std::unique_ptr<uint32_t[][3]> grids(new uint32_t[gridX * gridY * gridZ][3]);
// TODO: which order would be more effective for cache locality?
for (uint32_t z = 0; z < gridZ; ++z) {{
for (uint32_t y = 0; y < gridY; ++y) {{
for (uint32_t x = 0; x < gridX; ++x) {{
(*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z);
grids[z * gridY * gridX + y * gridX + x][0] = x;
grids[z * gridY * gridX + y * gridX + x][1] = y;
grids[z * gridY * gridX + y * gridX + x][2] = z;
}}
}}
}}
return grids;
}}

static PyObject* launch(PyObject* self, PyObject* args) {{
static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
// TODO: Consider using omp collapse(3) clause for simplicity?
auto all_grids = get_all_grids(gridX, gridY, gridZ);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we try collapse(3) OMP rule instead of building these grids?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's good idea. But I had some troubles to get it right with clang compiler with OpenMP. E.g., only one cpu was used etc. So, my suggestion is that keep the current approach for the time being. Once we confirm MacOS+Clang works well with OpenMP, then I will try collapse clause to make the code simpler.

size_t N = gridX * gridY * gridZ;

if (getBoolEnv("TRITON_CPU_SINGLE_CORE")) {{
if (getBoolEnv("TRITON_CPU_OMP_DEBUG"))
printf("Single core launcher\\n");

for (size_t i = 0; i < N; ++i) {{
const auto [x, y, z] = all_grids[i];
(*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z);
}}
return;
}}

std::optional<int> max_threads = getIntEnv("TRITON_CPU_MAX_THREADS");
if (max_threads.has_value())
max_threads = std::max(1, std::min(max_threads.value(), omp_get_max_threads()));
else
max_threads = omp_get_max_threads();

if (getBoolEnv("TRITON_CPU_OMP_DEBUG"))
printf("N: %zu, max_threads: %d\\n", N, max_threads.value());

// For now, use the default chunk size, total iterations / max_threads.
#pragma omp parallel for schedule(static) num_threads(max_threads.value())
for (size_t i = 0; i < N; ++i) {{
const auto [x, y, z] = all_grids[i];
(*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z);
}}
}}

static PyObject* launch(PyObject* self, PyObject* args) {{
int gridX, gridY, gridZ;
PyObject *launch_enter_hook = NULL;
PyObject *launch_exit_hook = NULL;
Expand Down