From c04588793a7ed4a23c223b0668a5fad55535184f Mon Sep 17 00:00:00 2001 From: minjang Date: Tue, 4 Jun 2024 11:22:06 -0700 Subject: [PATCH 1/4] [CPU] Add OpenMP launcher --- include/triton/Tools/Sys/GetEnv.hpp | 19 +++++++++++ python/triton/runtime/build.py | 2 +- python/tutorials/01-vector-add.py | 21 ++++++++---- third_party/cpu/backend/driver.py | 52 ++++++++++++++++++++++++++--- 4 files changed, 82 insertions(+), 12 deletions(-) diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index a01c4adbba64..08bbd0525b45 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -29,6 +30,9 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { inline const std::set CACHE_NEUTRAL_ENV_VARS = { "TRITON_REPRODUCER_PATH", + "TRITON_CPU_SINGLE_CORE", + "TRITON_CPU_MAX_THREADS", + "TRITON_CPU_OMP_DEBUG", }; namespace tools { @@ -52,6 +56,21 @@ inline std::string getStrEnv(const std::string &env) { return result; } +inline std::optional getIntEnv(const std::string &env) { + assertIsRecognized(env); + const char *cstr = std::getenv(env.c_str()); + if (!cstr) { + return std::nullopt; + } + + char *endptr; + long int result = std::strtol(cstr, &endptr, 10); + if (endptr == cstr) { + assert(false && "invalid integer"); + } + return result; +} + // return value of a cache-invalidating boolean environment variable inline bool getBoolEnv(const std::string &env) { assertIsRecognized(env); diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index 2d565539176c..bd4d94c5f6b4 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -47,7 +47,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): cc_cmd += [f"-I{dir}" for dir in include_dirs] # CPU backend uses C++ (driver.cpp). Some old version compilers need a specific C++17 flag. if src.endswith(".cpp") or src.endswith(".cc"): - cc_cmd += ["-std=c++17"] + cc_cmd += ["-std=c++17", "-fopenmp"] ret = subprocess.check_call(cc_cmd) if ret == 0: return so diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 57e8b0996a6b..ab7790aeae41 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -23,7 +23,8 @@ import triton import triton.language as tl -BLOCK_SIZE = 1024 +GPU_BLOCK_SIZE = 1024 +CPU_BLOCK_SIZE = 128 @triton.jit @@ -72,7 +73,7 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu): # - Each torch.tensor object is implicitly converted into a pointer to its first element. # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel. # - Don't forget to pass meta-parameters as keywords arguments. - add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=CPU_BLOCK_SIZE if is_cpu else GPU_BLOCK_SIZE) # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still # running asynchronously at this point. return output @@ -93,9 +94,9 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu): print(f'The maximum difference between torch-cpu and triton-cpu is ' f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}') -LINE_VALS = ['triton-cpu', 'torch-cpu'] -LINE_NAMES = ['TritonCPU', 'TorchCPU'] -LINE_STYLES = [('blue', '-'), ('green', '-')] +LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu'] +LINE_NAMES = ['TritonCPU 1-core', 'TritonCPU', 'TorchCPU'] +LINE_STYLES = [('blue', '-'), ('cyan', '-'), ('green', '-')] if triton.runtime.driver.get_active_gpus(): triton.runtime.driver.set_active_to_gpu() @@ -136,16 +137,22 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu): ylabel='GB/s', # Label name for the y-axis. plot_name= # Name for the plot. Used also as a file name for saving the plot. - f'vector-add-performance (BLOCK_SIZE={BLOCK_SIZE})', + f'vector-add-performance (CPU_BLOCK_SIZE={CPU_BLOCK_SIZE}, GPU_BLOCK_SIZE={GPU_BLOCK_SIZE})', args={}, # Values for function arguments not in `x_names` and `y_name`. )) def benchmark(size, provider): + import os + device = 'cpu' if 'cpu' in provider else 'cuda' x = torch.rand(size, device=device, dtype=torch.float32) y = torch.rand(size, device=device, dtype=torch.float32) if device == 'cpu': triton.runtime.driver.set_active_to_cpu() + if 'single' in provider: + os.environ['TRITON_CPU_SINGLE_CORE'] = '1' + else: + os.unsetenv('TRITON_CPU_SINGLE_CORE') else: triton.runtime.driver.set_active_to_gpu() @@ -156,6 +163,8 @@ def benchmark(size, provider): ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, False), quantiles=quantiles) elif provider == 'torch-cpu': ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, is_cpu=True) + elif provider == 'triton-cpu-single': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, True), quantiles=quantiles, is_cpu=True) elif provider == 'triton-cpu': ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, True), quantiles=quantiles, is_cpu=True) gbps = lambda ms: 12 * size / ms * 1e-6 diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 5783a0342dbd..6eb0acc2926c 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -16,6 +16,7 @@ include_dir = [ os.path.join(dirname, "include"), os.path.join(llvm_root, "include"), + os.path.join(".", "include"), ] library_dir = [os.path.join(dirname, "lib"), os.path.join(llvm_root, "lib")] libraries = [ @@ -77,6 +78,8 @@ "z", ] +MINIMUM_OMP_CHUNK_SIZE = 10 + def compile_module_from_src(src, name): key = hashlib.md5(src.encode("utf-8")).hexdigest() @@ -110,7 +113,6 @@ def __new__(cls): return cls.instance def __init__(self): - pass dirname = os.path.dirname(os.path.realpath(__file__)) mod = compile_module_from_src(Path(os.path.join(dirname, "driver.cpp")).read_text(), "cpu_utils") self.load_binary = mod.load_binary @@ -186,6 +188,10 @@ def format_of(ty): #include #include #include +#include +#include + +#include "triton/Tools/Sys/GetEnv.hpp" #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #include @@ -233,20 +239,56 @@ def format_of(ty): return ptr_info; }} -static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ - // TODO: add OMP pragmas to run in parallel +static std::unique_ptr get_all_grids(uint32_t gridX, uint32_t gridY, uint32_t gridZ) {{ + std::unique_ptr grids(new uint32_t[gridX * gridY * gridZ][3]); + // TODO: which order would be more effective for cache locality? for (uint32_t z = 0; z < gridZ; ++z) {{ for (uint32_t y = 0; y < gridY; ++y) {{ for (uint32_t x = 0; x < gridX; ++x) {{ - (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z); + grids[z * gridY * gridX + y * gridX + x][0] = x; + grids[z * gridY * gridX + y * gridX + x][1] = y; + grids[z * gridY * gridX + y * gridX + x][2] = z; }} }} }} + return grids; }} -static PyObject* launch(PyObject* self, PyObject* args) {{ +static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + auto all_grids = get_all_grids(gridX, gridY, gridZ); + size_t N = gridX * gridY * gridZ; + + if (mlir::triton::tools::getBoolEnv("TRITON_CPU_SINGLE_CORE")) {{ + if (mlir::triton::tools::getBoolEnv("TRITON_CPU_OMP_DEBUG")) + printf("Single core launcher\\n"); + for (uint32_t i = 0; i < N; ++i) {{ + const auto [x, y, z] = all_grids[i]; + (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z); + }} + return; + }} + // Use the default static scheduling with a simple chuck size policy. + std::optional max_threads = mlir::triton::tools::getIntEnv("TRITON_CPU_MAX_THREADS"); + if (max_threads.has_value()) + max_threads = std::max(1, std::min(max_threads.value(), omp_get_max_threads())); + else + max_threads = omp_get_max_threads(); + int chunk_size = std::ceil((double)N / (double)max_threads.value()); + chunk_size = std::max(chunk_size, {MINIMUM_OMP_CHUNK_SIZE}); + + if (mlir::triton::tools::getBoolEnv("TRITON_CPU_OMP_DEBUG")) + printf("N: %zu, max_threads: %d, chunk_size: %zu\\n", N, max_threads, chunk_size); + +#pragma omp parallel for schedule(static, chunk_size) num_threads(max_threads.value()) + for (uint32_t i = 0; i < N; ++i) {{ + const auto [x, y, z] = all_grids[i]; + (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z); + }} +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ int gridX, gridY, gridZ; PyObject *launch_enter_hook = NULL; PyObject *launch_exit_hook = NULL; From 7a22fc7cd58ddf189f6302c2ed9e0ee8157828a7 Mon Sep 17 00:00:00 2001 From: minjang Date: Fri, 7 Jun 2024 01:15:01 -0700 Subject: [PATCH 2/4] Address the comments --- include/triton/Tools/Sys/GetEnv.hpp | 19 ---------- python/tutorials/01-vector-add.py | 38 ++++++++++++-------- third_party/cpu/backend/driver.py | 55 +++++++++++++++++++---------- 3 files changed, 60 insertions(+), 52 deletions(-) diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 08bbd0525b45..a01c4adbba64 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include #include @@ -30,9 +29,6 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { inline const std::set CACHE_NEUTRAL_ENV_VARS = { "TRITON_REPRODUCER_PATH", - "TRITON_CPU_SINGLE_CORE", - "TRITON_CPU_MAX_THREADS", - "TRITON_CPU_OMP_DEBUG", }; namespace tools { @@ -56,21 +52,6 @@ inline std::string getStrEnv(const std::string &env) { return result; } -inline std::optional getIntEnv(const std::string &env) { - assertIsRecognized(env); - const char *cstr = std::getenv(env.c_str()); - if (!cstr) { - return std::nullopt; - } - - char *endptr; - long int result = std::strtol(cstr, &endptr, 10); - if (endptr == cstr) { - assert(false && "invalid integer"); - } - return result; -} - // return value of a cache-invalidating boolean environment variable inline bool getBoolEnv(const std::string &env) { assertIsRecognized(env); diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index ab7790aeae41..ca2d71c2746d 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -24,7 +24,8 @@ import triton.language as tl GPU_BLOCK_SIZE = 1024 -CPU_BLOCK_SIZE = 128 +CPU_BLOCK_SIZE = 4096 +USE_GPU = True @triton.jit @@ -60,10 +61,11 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. # and (2) enqueue the above kernel with appropriate grid/block sizes: -def add(x: torch.Tensor, y: torch.Tensor, is_cpu): - # We need to preallocate the output. - output = torch.empty_like(x) - assert x.is_cpu == is_cpu and y.is_cpu == is_cpu and output.is_cpu == is_cpu +def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, is_cpu): + if output is None: + # We need to preallocate the output. + output = torch.empty_like(x) + assert x.is_cpu == is_cpu and y.is_cpu == is_cpu and output.is_cpu == is_cpu n_elements = output.numel() # The SPMD launch grid denotes the number of kernel instances that run in parallel. # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]. @@ -88,22 +90,22 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu): x = torch.rand(size, device='cpu') y = torch.rand(size, device='cpu') output_torch_cpu = x + y -output_triton_cpu = add(x, y, is_cpu=True) +output_triton_cpu = add(x, y, None, is_cpu=True) print(output_torch_cpu) print(output_triton_cpu) print(f'The maximum difference between torch-cpu and triton-cpu is ' f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}') -LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu'] -LINE_NAMES = ['TritonCPU 1-core', 'TritonCPU', 'TorchCPU'] -LINE_STYLES = [('blue', '-'), ('cyan', '-'), ('green', '-')] +LINE_VALS = ['triton-cpu-single-prealloc', 'triton-cpu-prealloc', 'triton-cpu-single', 'triton-cpu', 'torch-cpu'] +LINE_NAMES = ['TritonCPU 1+pre', 'TritonCPU pre', 'TritonCPU 1', 'TritonCPU', 'TorchCPU'] +LINE_STYLES = [('blue', '--'), ('green', '--'), ('blue', '-'), ('green', '-'), ('cyan', '-')] -if triton.runtime.driver.get_active_gpus(): +if USE_GPU and triton.runtime.driver.get_active_gpus(): triton.runtime.driver.set_active_to_gpu() x = x.to('cuda') y = y.to('cuda') output_torch_gpu = x + y - output_triton_gpu = add(x, y, is_cpu=False) + output_triton_gpu = add(x, y, None, is_cpu=False) print(output_torch_gpu) print(output_triton_gpu) print(f'The maximum difference between torch-gpu and triton-gpu is ' @@ -160,13 +162,21 @@ def benchmark(size, provider): if provider == 'torch-gpu': ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles) elif provider == 'triton-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, False), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles) elif provider == 'torch-cpu': ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, is_cpu=True) elif provider == 'triton-cpu-single': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, True), quantiles=quantiles, is_cpu=True) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, True), quantiles=quantiles, is_cpu=True) + elif provider == 'triton-cpu-prealloc': + # Note that we preallocate the output buffer here to only measure the kernel performance + # without a large chunk of memory allocation. + output = torch.empty_like(x) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True) + elif provider == 'triton-cpu-single-prealloc': + output = torch.empty_like(x) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True) elif provider == 'triton-cpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, True), quantiles=quantiles, is_cpu=True) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, True), quantiles=quantiles, is_cpu=True) gbps = lambda ms: 12 * size / ms * 1e-6 return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 6eb0acc2926c..6278dbfc46b4 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -16,7 +16,6 @@ include_dir = [ os.path.join(dirname, "include"), os.path.join(llvm_root, "include"), - os.path.join(".", "include"), ] library_dir = [os.path.join(dirname, "lib"), os.path.join(llvm_root, "lib")] libraries = [ @@ -78,8 +77,6 @@ "z", ] -MINIMUM_OMP_CHUNK_SIZE = 10 - def compile_module_from_src(src, name): key = hashlib.md5(src.encode("utf-8")).hexdigest() @@ -184,18 +181,39 @@ def format_of(ty): # generate glue code src = f""" +#include +#include #include -#include -#include +#include #include +#include #include -#include - -#include "triton/Tools/Sys/GetEnv.hpp" +#include +#include +#include #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #include -#include + +inline bool getBoolEnv(const std::string &env) {{ + const char *s = std::getenv(env.c_str()); + std::string str(s ? s : ""); + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) {{ return std::tolower(c); }}); + return str == "on" || str == "true" || str == "1"; +}} + +inline std::optional getIntEnv(const std::string &env) {{ + const char *cstr = std::getenv(env.c_str()); + if (!cstr) + return std::nullopt; + + char *endptr; + long int result = std::strtol(cstr, &endptr, 10); + if (endptr == cstr) + assert(false && "invalid integer"); + return result; +}} using kernel_ptr_t = void(*)({kernel_fn_arg_types}); @@ -255,12 +273,14 @@ def format_of(ty): }} static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + // TODO: Consider using omp collapse(3) clause for simplicity? auto all_grids = get_all_grids(gridX, gridY, gridZ); size_t N = gridX * gridY * gridZ; - if (mlir::triton::tools::getBoolEnv("TRITON_CPU_SINGLE_CORE")) {{ - if (mlir::triton::tools::getBoolEnv("TRITON_CPU_OMP_DEBUG")) + if (getBoolEnv("TRITON_CPU_SINGLE_CORE")) {{ + if (getBoolEnv("TRITON_CPU_OMP_DEBUG")) printf("Single core launcher\\n"); + for (uint32_t i = 0; i < N; ++i) {{ const auto [x, y, z] = all_grids[i]; (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z); @@ -268,20 +288,17 @@ def format_of(ty): return; }} - // Use the default static scheduling with a simple chuck size policy. - std::optional max_threads = mlir::triton::tools::getIntEnv("TRITON_CPU_MAX_THREADS"); + std::optional max_threads = getIntEnv("TRITON_CPU_MAX_THREADS"); if (max_threads.has_value()) max_threads = std::max(1, std::min(max_threads.value(), omp_get_max_threads())); else max_threads = omp_get_max_threads(); - int chunk_size = std::ceil((double)N / (double)max_threads.value()); - chunk_size = std::max(chunk_size, {MINIMUM_OMP_CHUNK_SIZE}); - - if (mlir::triton::tools::getBoolEnv("TRITON_CPU_OMP_DEBUG")) - printf("N: %zu, max_threads: %d, chunk_size: %zu\\n", N, max_threads, chunk_size); + if (getBoolEnv("TRITON_CPU_OMP_DEBUG")) + printf("N: %zu, max_threads: %d\\n", N, max_threads.value()); -#pragma omp parallel for schedule(static, chunk_size) num_threads(max_threads.value()) + // For now, use the default chunk size, total iterations / max_threads. +#pragma omp parallel for schedule(static) num_threads(max_threads.value()) for (uint32_t i = 0; i < N; ++i) {{ const auto [x, y, z] = all_grids[i]; (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z); From 14388566e310a4e001051c6c8a8ae4da992894bb Mon Sep 17 00:00:00 2001 From: minjang Date: Fri, 7 Jun 2024 01:36:20 -0700 Subject: [PATCH 3/4] Fix induction variable type --- third_party/cpu/backend/driver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 6278dbfc46b4..1018f64d5b35 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -281,7 +281,7 @@ def format_of(ty): if (getBoolEnv("TRITON_CPU_OMP_DEBUG")) printf("Single core launcher\\n"); - for (uint32_t i = 0; i < N; ++i) {{ + for (size_t i = 0; i < N; ++i) {{ const auto [x, y, z] = all_grids[i]; (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z); }} @@ -299,7 +299,7 @@ def format_of(ty): // For now, use the default chunk size, total iterations / max_threads. #pragma omp parallel for schedule(static) num_threads(max_threads.value()) - for (uint32_t i = 0; i < N; ++i) {{ + for (size_t i = 0; i < N; ++i) {{ const auto [x, y, z] = all_grids[i]; (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z); }} From 065fdc9f7b3bf8f27d37355025755877ff571456 Mon Sep 17 00:00:00 2001 From: minjang Date: Fri, 7 Jun 2024 13:27:31 -0700 Subject: [PATCH 4/4] Always use preallocated output buffer for CPU with torch.add --- python/tutorials/01-vector-add.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index ca2d71c2746d..49aed23e023e 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -96,9 +96,9 @@ def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, is_cpu): print(f'The maximum difference between torch-cpu and triton-cpu is ' f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}') -LINE_VALS = ['triton-cpu-single-prealloc', 'triton-cpu-prealloc', 'triton-cpu-single', 'triton-cpu', 'torch-cpu'] -LINE_NAMES = ['TritonCPU 1+pre', 'TritonCPU pre', 'TritonCPU 1', 'TritonCPU', 'TorchCPU'] -LINE_STYLES = [('blue', '--'), ('green', '--'), ('blue', '-'), ('green', '-'), ('cyan', '-')] +LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu'] +LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU'] +LINE_STYLES = [('blue', '-'), ('green', '-'), ('cyan', '-')] if USE_GPU and triton.runtime.driver.get_active_gpus(): triton.runtime.driver.set_active_to_gpu() @@ -164,19 +164,17 @@ def benchmark(size, provider): elif provider == 'triton-gpu': ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles) elif provider == 'torch-cpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, is_cpu=True) - elif provider == 'triton-cpu-single': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, True), quantiles=quantiles, is_cpu=True) - elif provider == 'triton-cpu-prealloc': # Note that we preallocate the output buffer here to only measure the kernel performance # without a large chunk of memory allocation. output = torch.empty_like(x) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True) - elif provider == 'triton-cpu-single-prealloc': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.add(x, y, out=output), quantiles=quantiles, + is_cpu=True) + elif provider == 'triton-cpu-single': output = torch.empty_like(x) ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True) elif provider == 'triton-cpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, True), quantiles=quantiles, is_cpu=True) + output = torch.empty_like(x) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True) gbps = lambda ms: 12 * size / ms * 1e-6 return gbps(ms), gbps(max_ms), gbps(min_ms)