diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index 2e7516fe69..f8eb9e9e6f 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -31,6 +31,8 @@ from typing import Optional +DEVICE = triton.runtime.driver.active.get_active_torch_device() + if torch.cuda.is_available(): from triton._C.libtriton import nvidia cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8) @@ -43,6 +45,10 @@ def is_cuda(): return triton.runtime.driver.active.get_current_target().backend == "cuda" +def is_xpu(): + return triton.runtime.driver.active.get_current_target().backend == "xpu" + + def supports_tma(): return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 @@ -51,6 +57,14 @@ def supports_ws(): return is_cuda() and torch.cuda.get_device_capability()[0] >= 10 +def num_sms(): + if is_cuda(): + return torch.cuda.get_device_properties("cuda").multi_processor_count + if is_xpu(): + return torch.xpu.get_device_properties("xpu").gpu_eu_count + return 148 + + def _matmul_launch_metadata(grid, kernel, args): ret = {} M, N, K, WS = args["M"], args["N"], args["K"], args.get("WARP_SPECIALIZE", False) @@ -66,7 +80,7 @@ def _matmul_launch_metadata(grid, kernel, args): HAS_TMA_DESC = supports_tma() and hasattr(tl, "nv_tma_desc_type") -HAS_TENSOR_DESC = supports_tma() and hasattr(tl, "make_tensor_descriptor") +HAS_TENSOR_DESC = (is_xpu() or supports_tma()) and hasattr(tl, "make_tensor_descriptor") HAS_WARP_SPECIALIZE = supports_ws() and HAS_TENSOR_DESC @@ -390,7 +404,8 @@ def matmul_persistent(a, b): # Check constraints. assert a.shape[1] == b.shape[0], "Incompatible dimensions" assert a.dtype == b.dtype, "Incompatible dtypes" - NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + NUM_SMS = num_sms() + M, K = a.shape K, N = b.shape dtype = a.dtype @@ -504,7 +519,7 @@ def matmul_tma_persistent(a, b, warp_specialize: bool): desc_helper.init_tma_descriptor("b") desc_helper.init_tma_descriptor("c") - NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + NUM_SMS = num_sms() def grid(META): nonlocal desc_helper @@ -649,11 +664,11 @@ def matmul_descriptor_persistent(a, b, warp_specialize: bool): dtype = a.dtype c = torch.empty((M, N), device=a.device, dtype=dtype) - NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + NUM_SMS = num_sms() # TMA descriptors require a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): - return torch.empty(size, device="cuda", dtype=torch.int8) + return torch.empty(size, device=DEVICE, dtype=torch.int8) triton.set_allocator(alloc_fn) @@ -706,17 +721,19 @@ def bench_fn(label, reps, warmup_reps, fn, *args): print(f"Benchmarking {label}: ...", end="") for _ in range(warmup_reps): fn(*args) - with proton_context(): - for _ in range(reps): - fn(*args) + #FIXME: Enable for XPU once proton support works. + if is_cuda(): + with proton_context(): + for _ in range(reps): + fn(*args) print(f"\rBenchmarking {label}: done") def bench(K, dtype, reps=10000, warmup_reps=10000): M = 8192 N = 8192 - a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype) - b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype) + a = torch.randn((M, K), device=DEVICE, dtype=torch.float16).to(dtype) + b = torch.randn((K, N), device=DEVICE, dtype=torch.float16).to(dtype) b = b.T.contiguous() @@ -750,8 +767,8 @@ def run_test(expect, fn, a, b, label, enabled=True): def validate(M, N, K, dtype): print(f"{M=}, {N=}, {K=}, verification naive vs: ") - a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype) - b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype) + a = torch.randn((M, K), device=DEVICE, dtype=torch.float16).to(dtype) + b = torch.randn((K, N), device=DEVICE, dtype=torch.float16).to(dtype) b = b.T.contiguous() naive_result = matmul(a, b.T).to(torch.float16) @@ -806,10 +823,11 @@ def show_profile(precision, profile_name): validate(32, 32, 32, dtype) validate(8192, 8192, args.K_range[0], dtype) - - proton.start("matmul", hook="triton") - proton.deactivate() + if is_cuda(): + proton.start("matmul", hook="triton") + proton.deactivate() for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): bench(K, dtype) - proton.finalize() - show_profile(args.prec, "matmul") + if is_cuda(): + proton.finalize() + show_profile(args.prec, "matmul") diff --git a/scripts/skiplist/lts/tutorials.txt b/scripts/skiplist/lts/tutorials.txt index e223307d02..bcabc74094 100644 --- a/scripts/skiplist/lts/tutorials.txt +++ b/scripts/skiplist/lts/tutorials.txt @@ -1,5 +1,6 @@ 03-matrix-multiplication 06-fused-attention 08-grouped-gemm +09-persistent-matmul 10-experimental-block-pointer 10i-experimental-block-pointer diff --git a/scripts/test-triton.sh b/scripts/test-triton.sh index aa6a1871b5..0edeacd96f 100755 --- a/scripts/test-triton.sh +++ b/scripts/test-triton.sh @@ -252,6 +252,7 @@ run_tutorial_tests() { run_tutorial_test "06-fused-attention" run_tutorial_test "07-extern-functions" run_tutorial_test "08-grouped-gemm" + TRITON_TEST_REPORTS=false run_tutorial_test "09-persistent-matmul" run_tutorial_test "10-experimental-block-pointer" run_tutorial_test "10i-experimental-block-pointer" diff --git a/test/TritonIntelGPU/combine.mlir b/test/TritonIntelGPU/combine.mlir index 72f74ab465..131de33d98 100644 --- a/test/TritonIntelGPU/combine.mlir +++ b/test/TritonIntelGPU/combine.mlir @@ -2410,7 +2410,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th // COM: Reproducer for issue #3817 (to ensure that the compiler doesn't crash). // CHECK: #[[$BLOCKED1:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 8], warpsPerCTA = [4, 1], order = [1, 0]}> - #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp index fab64685aa..e1fa9f885a 100644 --- a/third_party/intel/lib/Analysis/AxisInfo.cpp +++ b/third_party/intel/lib/Analysis/AxisInfo.cpp @@ -1284,7 +1284,6 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) { auto maskOrder = linAttr.getOrder(); if (maskOrder[0] >= axisInfo->getRank()) return 1; - auto alignment = std::max(axisInfo->getConstancy(maskOrder[0]), 1); LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment " << alignment);