Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 35 additions & 17 deletions python/tutorials/09-persistent-matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

from typing import Optional

DEVICE = triton.runtime.driver.active.get_active_torch_device()
Comment thread
etiotto marked this conversation as resolved.

if torch.cuda.is_available():
from triton._C.libtriton import nvidia
cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
Expand All @@ -43,6 +45,10 @@ def is_cuda():
return triton.runtime.driver.active.get_current_target().backend == "cuda"


def is_xpu():
    """Report whether the active Triton driver's current target uses the "xpu" backend."""
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "xpu"


def supports_tma():
    """Return True only on the CUDA backend with a major compute capability of 9 or higher."""
    # Bail out before touching torch.cuda when another backend is active.
    if not is_cuda():
        return False
    major = torch.cuda.get_device_capability()[0]
    return major >= 9

Expand All @@ -51,6 +57,14 @@ def supports_ws():
return is_cuda() and torch.cuda.get_device_capability()[0] >= 10


def num_sms():
    """Return the device unit count that callers store as ``NUM_SMS``.

    On CUDA this is ``multi_processor_count``; on XPU it is ``gpu_eu_count``;
    for any other backend a fixed value of 148 is returned.
    """
    if is_cuda():
        count = torch.cuda.get_device_properties("cuda").multi_processor_count
    elif is_xpu():
        count = torch.xpu.get_device_properties("xpu").gpu_eu_count
    else:
        # Neither backend is active: use the hard-coded default.
        count = 148
    return count


def _matmul_launch_metadata(grid, kernel, args):
ret = {}
M, N, K, WS = args["M"], args["N"], args["K"], args.get("WARP_SPECIALIZE", False)
Expand All @@ -66,7 +80,7 @@ def _matmul_launch_metadata(grid, kernel, args):


HAS_TMA_DESC = supports_tma() and hasattr(tl, "nv_tma_desc_type")
HAS_TENSOR_DESC = supports_tma() and hasattr(tl, "make_tensor_descriptor")
HAS_TENSOR_DESC = (is_xpu() or supports_tma()) and hasattr(tl, "make_tensor_descriptor")
HAS_WARP_SPECIALIZE = supports_ws() and HAS_TENSOR_DESC


Expand Down Expand Up @@ -390,7 +404,8 @@ def matmul_persistent(a, b):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.dtype == b.dtype, "Incompatible dtypes"
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
NUM_SMS = num_sms()

M, K = a.shape
K, N = b.shape
dtype = a.dtype
Expand Down Expand Up @@ -504,7 +519,7 @@ def matmul_tma_persistent(a, b, warp_specialize: bool):
desc_helper.init_tma_descriptor("b")
desc_helper.init_tma_descriptor("c")

NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
NUM_SMS = num_sms()

def grid(META):
nonlocal desc_helper
Expand Down Expand Up @@ -649,11 +664,11 @@ def matmul_descriptor_persistent(a, b, warp_specialize: bool):
dtype = a.dtype

c = torch.empty((M, N), device=a.device, dtype=dtype)
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
NUM_SMS = num_sms()

# TMA descriptors require a global memory allocation
def alloc_fn(size: int, alignment: int, stream: Optional[int]):
return torch.empty(size, device="cuda", dtype=torch.int8)
return torch.empty(size, device=DEVICE, dtype=torch.int8)

triton.set_allocator(alloc_fn)

Expand Down Expand Up @@ -706,17 +721,19 @@ def bench_fn(label, reps, warmup_reps, fn, *args):
print(f"Benchmarking {label}: ...", end="")
for _ in range(warmup_reps):
fn(*args)
with proton_context():
for _ in range(reps):
fn(*args)
# FIXME: Enable profiling on XPU once Proton supports it.
Comment thread
whitneywhtsang marked this conversation as resolved.
if is_cuda():
with proton_context():
for _ in range(reps):
fn(*args)
print(f"\rBenchmarking {label}: done")


def bench(K, dtype, reps=10000, warmup_reps=10000):
M = 8192
N = 8192
a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
a = torch.randn((M, K), device=DEVICE, dtype=torch.float16).to(dtype)
b = torch.randn((K, N), device=DEVICE, dtype=torch.float16).to(dtype)

b = b.T.contiguous()

Expand Down Expand Up @@ -750,8 +767,8 @@ def run_test(expect, fn, a, b, label, enabled=True):

def validate(M, N, K, dtype):
print(f"{M=}, {N=}, {K=}, verification naive vs: ")
a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
a = torch.randn((M, K), device=DEVICE, dtype=torch.float16).to(dtype)
b = torch.randn((K, N), device=DEVICE, dtype=torch.float16).to(dtype)
b = b.T.contiguous()

naive_result = matmul(a, b.T).to(torch.float16)
Expand Down Expand Up @@ -806,10 +823,11 @@ def show_profile(precision, profile_name):

validate(32, 32, 32, dtype)
validate(8192, 8192, args.K_range[0], dtype)

proton.start("matmul", hook="triton")
proton.deactivate()
if is_cuda():
proton.start("matmul", hook="triton")
proton.deactivate()
for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
bench(K, dtype)
proton.finalize()
show_profile(args.prec, "matmul")
if is_cuda():
proton.finalize()
show_profile(args.prec, "matmul")
1 change: 1 addition & 0 deletions scripts/skiplist/lts/tutorials.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
03-matrix-multiplication
06-fused-attention
08-grouped-gemm
09-persistent-matmul
10-experimental-block-pointer
10i-experimental-block-pointer
1 change: 1 addition & 0 deletions scripts/test-triton.sh
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ run_tutorial_tests() {
run_tutorial_test "06-fused-attention"
run_tutorial_test "07-extern-functions"
run_tutorial_test "08-grouped-gemm"
TRITON_TEST_REPORTS=false run_tutorial_test "09-persistent-matmul"
run_tutorial_test "10-experimental-block-pointer"
run_tutorial_test "10i-experimental-block-pointer"

Expand Down
1 change: 0 additions & 1 deletion test/TritonIntelGPU/combine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2410,7 +2410,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
// COM: Reproducer for issue #3817 (to ensure that the compiler doesn't crash).

// CHECK: #[[$BLOCKED1:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 8], warpsPerCTA = [4, 1], order = [1, 0]}>

#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
Expand Down
1 change: 0 additions & 1 deletion third_party/intel/lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1284,7 +1284,6 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
auto maskOrder = linAttr.getOrder();
if (maskOrder[0] >= axisInfo->getRank())
return 1;

auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
<< alignment);
Expand Down