From 73f3526cc7246e6363bc334857e5108024a8051b Mon Sep 17 00:00:00 2001
From: Ruiqi Gao
Date: Fri, 14 Jun 2024 15:06:50 -0700
Subject: [PATCH] Update matrix-multiplication-cpu tutorial, use preallocated
 output buffer for CPU.

---
 .../tutorials/03-matrix-multiplication-cpu.py | 26 ++++++++++++-------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py
index 0cc90a474052..c20a36aab10e 100644
--- a/python/tutorials/03-matrix-multiplication-cpu.py
+++ b/python/tutorials/03-matrix-multiplication-cpu.py
@@ -249,15 +249,19 @@ def matmul_kernel(
 # and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
 
 
-def matmul(a, b):
+def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
     assert a.is_contiguous(), "Matrix A must be contiguous"
     M, K = a.shape
     K, N = b.shape
+    #TODO: Currently masked load is not supported yet.
     assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"
-    # Allocates output.
-    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+    if c is None:
+        # Allocates output.
+        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+    else:
+        assert c.shape == (M, N), "Incompatible dimensions"
     # 1D launch kernel where each block gets its own program.
     grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), )
     matmul_kernel[grid](
@@ -284,10 +288,9 @@ def matmul(a, b):
 
 triton.runtime.driver.set_active_to_cpu()
 
-
 a = torch.randn((512, 512), device='cpu', dtype=torch.float32)
 b = torch.randn((512, 512), device='cpu', dtype=torch.float32)
-triton_output = matmul(a, b)
+triton_output = matmul(a, b, None)
 torch_output = torch.matmul(a, b)
 print(f"triton_cpu_output_with_{a.dtype}_inputs={triton_output}")
 print(f"torch_cpu_output_with_{a.dtype}_inputs={torch_output}")
@@ -315,7 +318,7 @@ def matmul(a, b):
     triton.runtime.driver.set_active_to_gpu()
     a = a.to('cuda')
     b = b.to('cuda')
-    triton_output = matmul(a, b)
+    triton_output = matmul(a, b, None)
     torch_output = torch.matmul(a, b)
     print(f"triton_gpu_output_with_{a.dtype}_inputs={triton_output}")
     print(f"torch_gpu_output_with_{a.dtype}_inputs={torch_output}")
@@ -377,13 +380,16 @@ def benchmark(M, N, K, provider):
     if provider == 'torch-gpu':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
     elif provider == 'triton-gpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles)
     elif provider == 'torch-cpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
+        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles, is_cpu=True)
     elif provider == 'triton-cpu-single':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
+        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True)
     elif provider == 'triton-cpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
+        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True)
     perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3)
     return perf(ms), perf(max_ms), perf(min_ms)
 
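
A minimal usage sketch of the updated signature, assuming the tutorial's matmul wrapper and block-size constants are already in scope: the caller either passes None and lets matmul allocate the output, or hands in a preallocated buffer that is reused across calls (e.g. inside a benchmark loop), so allocation cost is not measured on every call.

    import torch
    import triton

    # Run on the CPU backend, as in the tutorial.
    triton.runtime.driver.set_active_to_cpu()

    a = torch.randn((512, 512), device='cpu', dtype=torch.float32)
    b = torch.randn((512, 512), device='cpu', dtype=torch.float32)

    # Option 1: pass None and let matmul allocate the output (previous behavior).
    c1 = matmul(a, b, None)

    # Option 2: preallocate the output once and reuse it across calls.
    out = torch.empty((512, 512), device='cpu', dtype=torch.float32)
    c2 = matmul(a, b, out)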