Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions python/tutorials/03-matrix-multiplication-cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,15 +249,19 @@ def matmul_kernel(
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.


def matmul(a, b):
def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
M, K = a.shape
K, N = b.shape
#TODO: Currently masked load is not supported yet.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really tiny nit: next time do like:

# TODO: Currently masked load is not supported yet.

assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"
# Allocates output.
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
if c is None:
# Allocates output.
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
else:
assert c.shape == (M, N), "Incompatible dimensions"
# 1D launch kernel where each block gets its own program.
grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), )
matmul_kernel[grid](
Expand All @@ -284,10 +288,9 @@ def matmul(a, b):

triton.runtime.driver.set_active_to_cpu()


a = torch.randn((512, 512), device='cpu', dtype=torch.float32)
b = torch.randn((512, 512), device='cpu', dtype=torch.float32)
triton_output = matmul(a, b)
triton_output = matmul(a, b, None)
torch_output = torch.matmul(a, b)
print(f"triton_cpu_output_with_{a.dtype}_inputs={triton_output}")
print(f"torch_cpu_output_with_{a.dtype}_inputs={torch_output}")
Expand Down Expand Up @@ -315,7 +318,7 @@ def matmul(a, b):
triton.runtime.driver.set_active_to_gpu()
a = a.to('cuda')
b = b.to('cuda')
triton_output = matmul(a, b)
triton_output = matmul(a, b, None)
torch_output = torch.matmul(a, b)
print(f"triton_gpu_output_with_{a.dtype}_inputs={triton_output}")
print(f"torch_gpu_output_with_{a.dtype}_inputs={torch_output}")
Expand Down Expand Up @@ -377,13 +380,16 @@ def benchmark(M, N, K, provider):
if provider == 'torch-gpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
elif provider == 'triton-gpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles)
elif provider == 'torch-cpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles, is_cpu=True)
elif provider == 'triton-cpu-single':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True)
elif provider == 'triton-cpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True)
perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)

Expand Down