From f11bf68683c8a7c857b62208339460ca5d62b7dd Mon Sep 17 00:00:00 2001 From: Zeng Wu Date: Wed, 17 Jul 2024 22:43:56 -0700 Subject: [PATCH 1/6] Emit SWP failure warning if enabled in the commandline --- .../Pipeliner/SoftwarePipeliner.cpp | 7 ++++ test/TritonGPU/swp-warning.mlir | 40 +++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 test/TritonGPU/swp-warning.mlir diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp index 8766e82b9f15..37f3f9d13429 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp @@ -25,6 +25,10 @@ // to create async operations and create a modulo schedule. Then we call the // expander to generate the prologue and new loop. //===----------------------------------------------------------------------===// +static llvm::cl::opt + DumpSWPFailure("dump-swp-failure", llvm::cl::Hidden, + llvm::cl::init(false), + llvm::cl::desc("dump warning if SWP fails")); namespace mlir { namespace triton { @@ -123,6 +127,9 @@ struct PipelinePass : public impl::TritonGPUPipelineBase { auto outerLoop = dyn_cast(forOp->getParentOp()); int loopNumStages = getNumStagesOrDefault(forOp); bool pipelined = pipelineLoop(forOp, loopNumStages); + if (DumpSWPFailure && !pipelined) { + forOp->emitRemark("SWP failes in inner most loop"); + } if (pipelined && outerLoop && getNumStagesOrDefault(outerLoop) > 1) outerLoops.insert(outerLoop); } diff --git a/test/TritonGPU/swp-warning.mlir b/test/TritonGPU/swp-warning.mlir new file mode 100644 index 000000000000..b262e9dd91c1 --- /dev/null +++ b/test/TritonGPU/swp-warning.mlir @@ -0,0 +1,40 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -dump-swp-failure | FileCheck %s + +// CHECK-LABEL: @dont_pipeline_128x1 +// CHECK-NOT: local_load{{.*}}128x1 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @dont_pipeline_128x1(%arg6: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false, swp = true} { + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c128_i32 = arith.constant 128 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %cst_4 = arith.constant dense<-1.000000e+30> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + + %99:1 = scf.for %arg25 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg31 = %cst_4) -> (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) : i32 { + %94 = tt.splat %arg6 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %151 = tt.load %94 : tensor<128x1x!tt.ptr, #blocked> + %161 = triton_gpu.convert_layout %151 : tensor<128x1xi32, #blocked> -> tensor<128x1xi32, #mma> + %162 = tt.broadcast %161 : tensor<128x1xi32, #mma> -> tensor<128x64xi32, #mma> + %170 = arith.sitofp %162 : tensor<128x64xi32, #mma> to tensor<128x64xf32, #mma> + + %173 = "tt.reduce"(%170) <{axis = 1 : i32}> ({ + ^bb0(%arg33: f32, %arg34: f32): + %207 = arith.maxnumf %arg33, %arg34 : f32 + tt.reduce.return %207 : f32 + }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %175 = arith.maxnumf %arg31, %173 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + + %201 = arith.truncf %170 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> + %202 = triton_gpu.convert_layout %201 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + + %192 = arith.constant dense<0.> : tensor<128x64xf32, #mma> + %203 = arith.constant dense<0.> : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %204 = tt.dot %202, %203, %192 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + + scf.yield %175 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + } + tt.return + } +} From 4b075ec15ba8e34117a7d0d6d1758b920f0098e2 Mon Sep 17 00:00:00 2001 From: Zeng Wu Date: Wed, 17 Jul 2024 22:43:56 -0700 Subject: [PATCH 2/6] Emit perf warning for the failure of SWP, vectorization --- .../Pipeliner/SoftwarePipeliner.cpp | 26 ++-- python/test/unit/test_perf_warning.py | 111 +++++++++++++++++- test/TritonGPU/swp-warning.mlir | 40 ------- .../LoadStoreOpToLLVM.cpp | 13 ++ 4 files changed, 137 insertions(+), 53 deletions(-) delete mode 100644 test/TritonGPU/swp-warning.mlir diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp index 37f3f9d13429..3a9a7e68094f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp @@ -25,10 +25,6 @@ // to create async operations and create a modulo schedule. Then we call the // expander to generate the prologue and new loop. //===----------------------------------------------------------------------===// -static llvm::cl::opt - DumpSWPFailure("dump-swp-failure", llvm::cl::Hidden, - llvm::cl::init(false), - llvm::cl::desc("dump warning if SWP fails")); namespace mlir { namespace triton { @@ -45,8 +41,11 @@ static bool preCondition(scf::ForOp forOp) { [](Value operand) { Operation *def = operand.getDefiningOp(); return !def; - })) + })) { + forOp->emitRemark() << "Warning: SWP fails due to loop distance is greater than 1"; return false; + } + // Don't pipeline outer loops. if (forOp ->walk([&](Operation *op) { @@ -56,8 +55,10 @@ static bool preCondition(scf::ForOp forOp) { return WalkResult::interrupt(); return WalkResult::advance(); }) - .wasInterrupted()) + .wasInterrupted()) { + forOp->emitRemark() << "Warning: SWP fails on the outer loop"; return false; + } return true; } @@ -106,9 +107,10 @@ struct PipelinePass : public impl::TritonGPUPipelineBase { // global control. if (!forOp->hasAttr(mlir::triton::kNumStagesAttrName)) return numStages; + return mlir::cast( - forOp->getAttr(mlir::triton::kNumStagesAttrName)) - .getInt(); + forOp->getAttr(mlir::triton::kNumStagesAttrName)) + .getInt(); } void runOnOperation() override { @@ -119,17 +121,17 @@ struct PipelinePass : public impl::TritonGPUPipelineBase { loops.push_back(forOp); }); - if (loops.empty()) + if (loops.empty()) { + auto op = getOperation(); + op->emitRemark() << "Warning: SWP fails. There is no loop with num_stages greater than 1"; return; + } llvm::SmallSetVector outerLoops; for (scf::ForOp forOp : loops) { auto outerLoop = dyn_cast(forOp->getParentOp()); int loopNumStages = getNumStagesOrDefault(forOp); bool pipelined = pipelineLoop(forOp, loopNumStages); - if (DumpSWPFailure && !pipelined) { - forOp->emitRemark("SWP failes in inner most loop"); - } if (pipelined && outerLoop && getNumStagesOrDefault(outerLoop) > 1) outerLoops.insert(outerLoop); } diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py index 871bc6ba294b..8504a120eafa 100644 --- a/python/test/unit/test_perf_warning.py +++ b/python/test/unit/test_perf_warning.py @@ -13,7 +13,7 @@ def is_cuda(): return triton.runtime.driver.active.get_current_target().backend == "cuda" -def test_mma_remark(capfd): +def test_remark_mma(capfd): if is_cuda(): capability = torch.cuda.get_device_capability() if capability[0] < 9: @@ -45,3 +45,112 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, assert "remark: Warning: can't use MMA V3 for the dot op" in captured.err, "expect MMA V3 remark" assert "note: see current operation:" in captured.err os.environ['MLIR_ENABLE_REMARK'] = '0' + + +def test_remark_swp_1stage(capfd): + os.environ['MLIR_ENABLE_REMARK'] = '1' + + @triton.jit + def vecadd_kernel(a_ptr, b_ptr, output_ptr, n_elements, num_blocks): + pid = tl.program_id(axis=0) + block_start = pid * 128 * num_blocks + offsets = block_start + tl.arange(0, 128) + for _ in tl.range(0, num_blocks, num_stages=1): + mask = offsets < n_elements + x = tl.load(a_ptr + offsets, mask=mask) + y = tl.load(b_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + offsets += 128 + + triton.compile( + triton.compiler.ASTSource( + fn=vecadd_kernel, signature={ + 0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, + constants={})) + _, err = capfd.readouterr() + + assert "remark: Warning: SWP fails. There is no loop with num_stages greater than 1" in err, "expect SWP failure remark" + # assert "note: see current operation:" in captured.err + # assert "numstages in loop" in captured.err + os.environ['MLIR_ENABLE_REMARK'] = '0' + + +def test_remark_swp_dep_distance(capfd): + os.environ['MLIR_ENABLE_REMARK'] = '1' + + @triton.jit + def vecadd_kernel(a_ptr, b_ptr, output_ptr, n_elements, num_blocks): + pid = tl.program_id(axis=0) + block_start = pid * 128 * num_blocks + offsets = block_start + tl.arange(0, 128) + offsets_0 = block_start + tl.arange(2, 130) + for _ in tl.range(0, num_blocks): + mask = offsets < n_elements + x = tl.load(a_ptr + offsets, mask=mask) + x_0 = tl.load(a_ptr + offsets_0, mask=mask) + output = x + x_0 + tl.store(output_ptr + offsets, output, mask=mask) + offsets += 128 + offsets_0 += 128 + + triton.compile( + triton.compiler.ASTSource( + fn=vecadd_kernel, signature={ + 0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, + constants={})) + stdout, stderr = capfd.readouterr() + + # TODO: to fix this kernel + # assert "remark: Warning: SWP fails due to loop distance is greater than" in stdout, "expect SWP failure remark" + # assert "note: see current operation:" in captured.err + # assert "numstages in loop" in captured.err + os.environ['MLIR_ENABLE_REMARK'] = '0' + + +def test_remark_swp_outerloop(capfd): + os.environ['MLIR_ENABLE_REMARK'] = '1' + + @triton.jit + def binary_elemwise_kernel_2d(x_ptr, y_ptr, o_ptr, num_elem, shape_x0, shape_x1, stride): + pid = tl.program_id(axis=0) + start_elem = min(pid * num_elem, shape_x0) + end_elem = min(pid * num_elem + num_elem, shape_x0) + + # The outer-loop will not be pipelined in the precondition check, + # There will be warning for the outer loop. + for i in range(start_elem, end_elem): + x_addr = x_ptr + i * stride + y_addr = y_ptr + i * stride + o_addr = o_ptr + i * stride + + elem_offset = tl.arange(0, 128) + x_blk_ptr = x_addr + elem_offset + y_blk_ptr = y_addr + elem_offset + o_blk_ptr = o_addr + elem_offset + + block_start = 0 + for _ in range(0, shape_x1, 128): + elem_offset = block_start + tl.arange(0, 128) + mask = elem_offset < shape_x1 + + x = tl.load(x_blk_ptr, mask=mask) + y = tl.load(y_blk_ptr, mask=mask) + + output = x + y + + tl.store(o_blk_ptr, output, mask=mask) + + x_blk_ptr += 128 + y_blk_ptr += 128 + o_blk_ptr += 128 + block_start += 128 + + triton.compile( + triton.compiler.ASTSource( + fn=binary_elemwise_kernel_2d, signature={ + 0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32', 5: 'i32', 6: 'i32'}, constants={})) + _, err = capfd.readouterr() + + assert "remark: Warning: SWP fails on the outer loop" in err, "expect SWP failure remark" + os.environ['MLIR_ENABLE_REMARK'] = '0' diff --git a/test/TritonGPU/swp-warning.mlir b/test/TritonGPU/swp-warning.mlir deleted file mode 100644 index b262e9dd91c1..000000000000 --- a/test/TritonGPU/swp-warning.mlir +++ /dev/null @@ -1,40 +0,0 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -dump-swp-failure | FileCheck %s - -// CHECK-LABEL: @dont_pipeline_128x1 -// CHECK-NOT: local_load{{.*}}128x1 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func public @dont_pipeline_128x1(%arg6: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false, swp = true} { - %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %c128_i32 = arith.constant 128 : i32 - %c0_i32 = arith.constant 0 : i32 - %c64_i32 = arith.constant 64 : i32 - %cst_4 = arith.constant dense<-1.000000e+30> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - - %99:1 = scf.for %arg25 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg31 = %cst_4) -> (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) : i32 { - %94 = tt.splat %arg6 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %151 = tt.load %94 : tensor<128x1x!tt.ptr, #blocked> - %161 = triton_gpu.convert_layout %151 : tensor<128x1xi32, #blocked> -> tensor<128x1xi32, #mma> - %162 = tt.broadcast %161 : tensor<128x1xi32, #mma> -> tensor<128x64xi32, #mma> - %170 = arith.sitofp %162 : tensor<128x64xi32, #mma> to tensor<128x64xf32, #mma> - - %173 = "tt.reduce"(%170) <{axis = 1 : i32}> ({ - ^bb0(%arg33: f32, %arg34: f32): - %207 = arith.maxnumf %arg33, %arg34 : f32 - tt.reduce.return %207 : f32 - }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %175 = arith.maxnumf %arg31, %173 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - - %201 = arith.truncf %170 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - %202 = triton_gpu.convert_layout %201 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - - %192 = arith.constant dense<0.> : tensor<128x64xf32, #mma> - %203 = arith.constant dense<0.> : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %204 = tt.dot %202, %203, %192 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> - - scf.yield %175 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - } - tt.return - } -} diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 052930c2829d..75434f2f7d95 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -173,6 +173,9 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, LLVM_DEBUG(llvm::dbgs() << " vec = " << vec << '\n'); } + if (vec == 1 && numElems > 1) + op->emitRemark() << "Warning: vectorization fails due to vec is 1"; + // Get the LLVM values for pointers auto ptrElems = unpackLLElements(loc, llPtr, rewriter); assert(ptrElems.size() == numElems); @@ -388,6 +391,9 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, vec = std::min(vec, maskAlign); } + if (vec == 1 && elemsPerThread > 1) + op->emitRemark() << "Warning: vectorization fails due to vec is 1"; + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); const size_t dtsize = std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); @@ -522,6 +528,9 @@ struct AtomicCASOpConversion vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); } + if (vec == 1 && elemsPerThread > 1) + moduleOp->emitRemark() << "Warning: vectorization fails due to vec is 1"; + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); auto vecTy = vec_ty(valueElemTy, vec); SmallVector resultVals(elemsPerThread); @@ -645,6 +654,10 @@ struct AtomicRMWOpConversion // mask numElems = tensorTy.getNumElements(); } + + if (vec == 1 && numElems > 1) + moduleOp->emitRemark() << "Warning: vectorization fails due to vec is 1"; + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); auto vecTy = vec_ty(valueElemTy, vec); From 454693d7cb88be756de7210ef4115c0342413356 Mon Sep 17 00:00:00 2001 From: Zeng Wu Date: Wed, 17 Jul 2024 22:43:56 -0700 Subject: [PATCH 3/6] Emit perf warning for the failure of SWP, vectorization --- .../Pipeliner/SoftwarePipeliner.cpp | 1 + python/test/unit/test_perf_warning.py | 92 ++++++++----------- third_party/nvidia/backend/compiler.py | 5 + .../LoadStoreOpToLLVM.cpp | 12 ++- 4 files changed, 50 insertions(+), 60 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp index 3a9a7e68094f..351f2753be03 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp @@ -127,6 +127,7 @@ struct PipelinePass : public impl::TritonGPUPipelineBase { return; } + llvm::SmallSetVector outerLoops; for (scf::ForOp forOp : loops) { auto outerLoop = dyn_cast(forOp->getParentOp()); diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py index 8504a120eafa..38f93f0e5f0d 100644 --- a/python/test/unit/test_perf_warning.py +++ b/python/test/unit/test_perf_warning.py @@ -3,7 +3,7 @@ import os import pytest import torch - +import tempfile def is_perf_warning_enabled(): return os.environ.get('MLIR_ENABLE_REMARK', '0') == '1' @@ -47,64 +47,44 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, os.environ['MLIR_ENABLE_REMARK'] = '0' -def test_remark_swp_1stage(capfd): +def test_remark_vectorization(capfd): os.environ['MLIR_ENABLE_REMARK'] = '1' @triton.jit - def vecadd_kernel(a_ptr, b_ptr, output_ptr, n_elements, num_blocks): - pid = tl.program_id(axis=0) - block_start = pid * 128 * num_blocks - offsets = block_start + tl.arange(0, 128) - for _ in tl.range(0, num_blocks, num_stages=1): - mask = offsets < n_elements - x = tl.load(a_ptr + offsets, mask=mask) - y = tl.load(b_ptr + offsets, mask=mask) - output = x + y - tl.store(output_ptr + offsets, output, mask=mask) - offsets += 128 - + def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, xnumel, XBLOCK: tl.constexpr): + xnumel = 28311552 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex % 9 + x2 = (xindex // 3456) % 512 + x1 = (xindex // 9) % 384 + x4 = xindex + tmp0 = tl.load(in_ptr0 + (x2 + (512*x0)), None, eviction_policy='evict_last') + tmp1 = tmp0 + 520 + tmp2 = tmp0 < 0 + tmp3 = tl.where(tmp2, tmp1, tmp0) + tmp9 = (-4) + tmp3 + tmp12 = tl.full([1], 512, tl.int64) + tmp14 = tmp9 < tmp12 + tmp16 = tl.load(in_ptr3 + (x1), tmp14, eviction_policy='evict_last', other=0.0) + tmp18 = tmp16.to(tl.float32) + tmp19 = tmp18.to(tl.float32) + tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype) + tmp21 = tl.where(tmp14, tmp19, tmp20) + tmp22 = tmp21.to(tl.float32) + tl.store(out_ptr0 + (x4), tmp22, None) + + XBLOCK = 1024 + XSIZE = 2048 + MAX_ELEM = 65536 triton.compile( triton.compiler.ASTSource( - fn=vecadd_kernel, signature={ - 0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, - constants={})) - _, err = capfd.readouterr() - - assert "remark: Warning: SWP fails. There is no loop with num_stages greater than 1" in err, "expect SWP failure remark" - # assert "note: see current operation:" in captured.err - # assert "numstages in loop" in captured.err - os.environ['MLIR_ENABLE_REMARK'] = '0' - - -def test_remark_swp_dep_distance(capfd): - os.environ['MLIR_ENABLE_REMARK'] = '1' + fn=ldst_vec, signature={0: '*i64', 1: '*i64', 2: '*fp16', 3: '*fp32', 4: '*fp16', 5: 'i32'}, + constants={"XBLOCK": XBLOCK}), options={"num_warps": 1}) - @triton.jit - def vecadd_kernel(a_ptr, b_ptr, output_ptr, n_elements, num_blocks): - pid = tl.program_id(axis=0) - block_start = pid * 128 * num_blocks - offsets = block_start + tl.arange(0, 128) - offsets_0 = block_start + tl.arange(2, 130) - for _ in tl.range(0, num_blocks): - mask = offsets < n_elements - x = tl.load(a_ptr + offsets, mask=mask) - x_0 = tl.load(a_ptr + offsets_0, mask=mask) - output = x + x_0 - tl.store(output_ptr + offsets, output, mask=mask) - offsets += 128 - offsets_0 += 128 - - triton.compile( - triton.compiler.ASTSource( - fn=vecadd_kernel, signature={ - 0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, - constants={})) - stdout, stderr = capfd.readouterr() - - # TODO: to fix this kernel - # assert "remark: Warning: SWP fails due to loop distance is greater than" in stdout, "expect SWP failure remark" - # assert "note: see current operation:" in captured.err - # assert "numstages in loop" in captured.err + _, err = capfd.readouterr() + assert "remark: Warning: vectorization fails" in err, "expect vectorization failure remark" os.environ['MLIR_ENABLE_REMARK'] = '0' @@ -147,9 +127,9 @@ def binary_elemwise_kernel_2d(x_ptr, y_ptr, o_ptr, num_elem, shape_x0, shape_x1, block_start += 128 triton.compile( - triton.compiler.ASTSource( - fn=binary_elemwise_kernel_2d, signature={ - 0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32', 5: 'i32', 6: 'i32'}, constants={})) + triton.compiler.ASTSource(fn=binary_elemwise_kernel_2d, signature={ + 0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32', 5: 'i32', 6: 'i32'}, constants={})) + _, err = capfd.readouterr() assert "remark: Warning: SWP fails on the outer loop" in err, "expect SWP failure remark" diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 7398ea30b691..f74176742b15 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -229,6 +229,11 @@ def make_llir(src, metadata, options, capability): # TritonGPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() + # Set up Diagnostic + if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1": + srcMgr = llvm.source_mgr() + diag = ir.source_mgr_diag(srcMgr, mod.context) + mod.context.printOpOnDiagnostic(True) nvidia.passes.ttgpuir.add_decompose_unsupported_conversions(pm) passes.ttgpuir.add_combine_tensor_select_and_if(pm) passes.convert.add_scf_to_cf(pm) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 75434f2f7d95..78dc7252e6d7 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -174,7 +174,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, } if (vec == 1 && numElems > 1) - op->emitRemark() << "Warning: vectorization fails due to vec is 1"; + op->emitRemark() << "Warning: vectorization fails vec = " + << vec << " numElems = " << numElems << "\n"; // Get the LLVM values for pointers auto ptrElems = unpackLLElements(loc, llPtr, rewriter); @@ -392,7 +393,8 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, } if (vec == 1 && elemsPerThread > 1) - op->emitRemark() << "Warning: vectorization fails due to vec is 1"; + op->emitRemark() << "Warning: vectorization fails vec = " + << vec << " elemsPerThread = " << elemsPerThread; Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); const size_t dtsize = @@ -529,7 +531,8 @@ struct AtomicCASOpConversion } if (vec == 1 && elemsPerThread > 1) - moduleOp->emitRemark() << "Warning: vectorization fails due to vec is 1"; + op->emitRemark() << "Warning: vectorization fails vec = " + << vec << " elemsPerThread = " << elemsPerThread; Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); auto vecTy = vec_ty(valueElemTy, vec); @@ -656,7 +659,8 @@ struct AtomicRMWOpConversion } if (vec == 1 && numElems > 1) - moduleOp->emitRemark() << "Warning: vectorization fails due to vec is 1"; + op->emitRemark() << "Warning: vectorization fails vec = " + << vec << " numElems = " << numElems; Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); From 76b8c3a2335ecb02024d4e31a4f87e0d36dda1a4 Mon Sep 17 00:00:00 2001 From: Zeng Wu Date: Wed, 17 Jul 2024 22:43:56 -0700 Subject: [PATCH 4/6] Emit perf warning for the failure of SWP, vectorization --- .../Pipeliner/SoftwarePipeliner.cpp | 18 +++++---------- .../LoadStoreOpToLLVM.cpp | 22 ++++++++++++------- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp index 351f2753be03..56b1ba42e2c8 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp @@ -41,11 +41,8 @@ static bool preCondition(scf::ForOp forOp) { [](Value operand) { Operation *def = operand.getDefiningOp(); return !def; - })) { - forOp->emitRemark() << "Warning: SWP fails due to loop distance is greater than 1"; + })) return false; - } - // Don't pipeline outer loops. if (forOp ->walk([&](Operation *op) { @@ -55,7 +52,7 @@ static bool preCondition(scf::ForOp forOp) { return WalkResult::interrupt(); return WalkResult::advance(); }) - .wasInterrupted()) { + .wasInterrupted()) { forOp->emitRemark() << "Warning: SWP fails on the outer loop"; return false; } @@ -107,10 +104,9 @@ struct PipelinePass : public impl::TritonGPUPipelineBase { // global control. if (!forOp->hasAttr(mlir::triton::kNumStagesAttrName)) return numStages; - return mlir::cast( - forOp->getAttr(mlir::triton::kNumStagesAttrName)) - .getInt(); + forOp->getAttr(mlir::triton::kNumStagesAttrName)) + .getInt(); } void runOnOperation() override { @@ -121,12 +117,8 @@ struct PipelinePass : public impl::TritonGPUPipelineBase { loops.push_back(forOp); }); - if (loops.empty()) { - auto op = getOperation(); - op->emitRemark() << "Warning: SWP fails. There is no loop with num_stages greater than 1"; + if (loops.empty()) return; - } - llvm::SmallSetVector outerLoops; for (scf::ForOp forOp : loops) { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 07bd16501f96..d120939dceab 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -173,9 +173,12 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, LLVM_DEBUG(llvm::dbgs() << " vec = " << vec << '\n'); } - if (vec == 1 && numElems > 1) - op->emitRemark() << "Warning: vectorization fails vec = " - << vec << " numElems = " << numElems << "\n"; + if (vec == 1 && numElems > 1) { + auto maskStr = !llMask ? "no mask" : std::to_string(getMaskAlignment(mask)); + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " numElems = " << numElems + << " mask is " << maskStr << "\n"; + } // Get the LLVM values for pointers auto ptrElems = unpackLLElements(loc, llPtr, rewriter); @@ -392,9 +395,12 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, vec = std::min(vec, maskAlign); } - if (vec == 1 && elemsPerThread > 1) - op->emitRemark() << "Warning: vectorization fails vec = " - << vec << " elemsPerThread = " << elemsPerThread; + if (vec == 1 && elemsPerThread > 1) { + auto maskStr = !llMask ? "no mask" : std::to_string(getMaskAlignment(op.getMask())); + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " elemsPerThread = " << elemsPerThread + << " mask is " << maskStr << "\n"; + } Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); const size_t dtsize = @@ -531,8 +537,8 @@ struct AtomicCASOpConversion } if (vec == 1 && elemsPerThread > 1) - op->emitRemark() << "Warning: vectorization fails vec = " - << vec << " elemsPerThread = " << elemsPerThread; + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " elemsPerThread = " << elemsPerThread << "\n"; Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); auto vecTy = vec_ty(valueElemTy, vec); From 93508380e4b918b84317a78011d33ea861e2b911 Mon Sep 17 00:00:00 2001 From: Zeng Wu Date: Wed, 17 Jul 2024 22:43:56 -0700 Subject: [PATCH 5/6] Emit perf warning for the failure of SWP, vectorization --- python/test/unit/test_perf_warning.py | 36 +++++++++---------- .../LoadStoreOpToLLVM.cpp | 25 ++++++++----- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py index 38f93f0e5f0d..b08983aa7df2 100644 --- a/python/test/unit/test_perf_warning.py +++ b/python/test/unit/test_perf_warning.py @@ -3,7 +3,7 @@ import os import pytest import torch -import tempfile + def is_perf_warning_enabled(): return os.environ.get('MLIR_ENABLE_REMARK', '0') == '1' @@ -13,7 +13,7 @@ def is_cuda(): return triton.runtime.driver.active.get_current_target().backend == "cuda" -def test_remark_mma(capfd): +def test_mma_remark(capfd): if is_cuda(): capability = torch.cuda.get_device_capability() if capability[0] < 9: @@ -48,26 +48,24 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, def test_remark_vectorization(capfd): - os.environ['MLIR_ENABLE_REMARK'] = '1' + os.environ["MLIR_ENABLE_REMARK"] = "1" @triton.jit - def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, xnumel, XBLOCK: tl.constexpr): - xnumel = 28311552 + def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr): xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] - xmask = xindex < xnumel x0 = xindex % 9 x2 = (xindex // 3456) % 512 x1 = (xindex // 9) % 384 x4 = xindex - tmp0 = tl.load(in_ptr0 + (x2 + (512*x0)), None, eviction_policy='evict_last') + tmp0 = tl.load(in_ptr0 + (x2 + (512 * x0)), None, eviction_policy="evict_last") tmp1 = tmp0 + 520 tmp2 = tmp0 < 0 tmp3 = tl.where(tmp2, tmp1, tmp0) tmp9 = (-4) + tmp3 tmp12 = tl.full([1], 512, tl.int64) tmp14 = tmp9 < tmp12 - tmp16 = tl.load(in_ptr3 + (x1), tmp14, eviction_policy='evict_last', other=0.0) + tmp16 = tl.load(in_ptr3 + (x1), tmp14, eviction_policy="evict_last", other=0.0) tmp18 = tmp16.to(tl.float32) tmp19 = tmp18.to(tl.float32) tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype) @@ -76,20 +74,17 @@ def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, xnumel, XBLOCK: tl.co tl.store(out_ptr0 + (x4), tmp22, None) XBLOCK = 1024 - XSIZE = 2048 - MAX_ELEM = 65536 triton.compile( - triton.compiler.ASTSource( - fn=ldst_vec, signature={0: '*i64', 1: '*i64', 2: '*fp16', 3: '*fp32', 4: '*fp16', 5: 'i32'}, - constants={"XBLOCK": XBLOCK}), options={"num_warps": 1}) + triton.compiler.ASTSource(fn=ldst_vec, signature={0: '*i64', 1: '*i64', 2: '*fp16', 3: '*fp32', 4: '*fp16'}, + constants={"XBLOCK": XBLOCK}), options={"num_warps": 1}) _, err = capfd.readouterr() - assert "remark: Warning: vectorization fails" in err, "expect vectorization failure remark" - os.environ['MLIR_ENABLE_REMARK'] = '0' + assert ("remark: Warning: vectorization fails" in err), "expect vectorization failure remark" + os.environ["MLIR_ENABLE_REMARK"] = "0" def test_remark_swp_outerloop(capfd): - os.environ['MLIR_ENABLE_REMARK'] = '1' + os.environ["MLIR_ENABLE_REMARK"] = "1" @triton.jit def binary_elemwise_kernel_2d(x_ptr, y_ptr, o_ptr, num_elem, shape_x0, shape_x1, stride): @@ -127,10 +122,11 @@ def binary_elemwise_kernel_2d(x_ptr, y_ptr, o_ptr, num_elem, shape_x0, shape_x1, block_start += 128 triton.compile( - triton.compiler.ASTSource(fn=binary_elemwise_kernel_2d, signature={ - 0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32', 5: 'i32', 6: 'i32'}, constants={})) + triton.compiler.ASTSource( + fn=binary_elemwise_kernel_2d, + signature={0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32', 5: 'i32', 6: 'i32'}, constants={})) _, err = capfd.readouterr() - assert "remark: Warning: SWP fails on the outer loop" in err, "expect SWP failure remark" - os.environ['MLIR_ENABLE_REMARK'] = '0' + assert ("remark: Warning: SWP fails on the outer loop" in err), "expect SWP failure remark" + os.environ["MLIR_ENABLE_REMARK"] = "0" diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 35d5b4fd1ba6..faa7b7d2e935 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -166,6 +166,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, typeConverter->convertType(getElementTypeOrSelf(op.getType())); unsigned vec = getVectorSize(ptr); unsigned numElems = getTotalElemsPerThread(ptr.getType()); + unsigned vecOrig = vec; if (llMask) { LLVM_DEBUG(DBGS() << "vec = " << vec << " mask_alignment = " << getMaskAlignment(mask)); @@ -174,12 +175,12 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, } if (vec == 1 && numElems > 1) { - auto maskStr = !llMask ? "no mask" : std::to_string(getMaskAlignment(mask)); + int maskValue = !llMask ? -1 : getMaskAlignment(mask); op->emitRemark() << "Warning: vectorization fails vec = " << vec - << " numElems = " << numElems - << " mask is " << maskStr << "\n"; + << " origin vec = " << vecOrig + << " numElems = " << numElems << " mask is " << maskValue + << "\n"; } - // Get the LLVM values for pointers auto ptrElems = unpackLLElements(loc, llPtr, rewriter); assert(ptrElems.size() == numElems); @@ -385,6 +386,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, assert(ptrElems.size() == valueElems.size()); // Determine the vectorization size + unsigned vecOrig = vec; SmallVector maskElems; if (llMask) { Value mask = op.getMask(); @@ -396,10 +398,11 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, } if (vec == 1 && elemsPerThread > 1) { - auto maskStr = !llMask ? "no mask" : std::to_string(getMaskAlignment(op.getMask())); + int mask = !llMask ? -1 : getMaskAlignment(op.getMask()); op->emitRemark() << "Warning: vectorization fails vec = " << vec - << " elemsPerThread = " << elemsPerThread - << " mask is " << maskStr << "\n"; + << " origin vec = " << vecOrig + << " elemsPerThread = " << elemsPerThread << " mask is " + << mask << "\n"; } Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); @@ -530,6 +533,7 @@ struct AtomicCASOpConversion auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType()); // vec = 1 for scalar auto vec = getVectorSize(op.getPtr()); + auto vecOrig = vec; // tensor if (tensorTy) { auto valTy = cast(op.getVal().getType()); @@ -538,6 +542,7 @@ struct AtomicCASOpConversion if (vec == 1 && elemsPerThread > 1) op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig << " elemsPerThread = " << elemsPerThread << "\n"; Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); @@ -657,6 +662,7 @@ struct AtomicRMWOpConversion auto elemsPerThread = getTotalElemsPerThread(val.getType()); // vec = 1, numElements = 1 for scalar auto vec = getVectorSize(ptr); + auto vecOrig = vec; int numElems = 1; // tensor if (tensorTy) { @@ -667,8 +673,9 @@ struct AtomicRMWOpConversion } if (vec == 1 && numElems > 1) - op->emitRemark() << "Warning: vectorization fails vec = " - << vec << " numElems = " << numElems; + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig + << " numElems = " << numElems; Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); From 85a2f93eeaee3e06447681331b7d62d44a0f109a Mon Sep 17 00:00:00 2001 From: Zeng Wu Date: Wed, 17 Jul 2024 22:43:56 -0700 Subject: [PATCH 6/6] Emit perf warning for the failure of SWP, vectorization --- .../Pipeliner/SoftwarePipeliner.cpp | 5 +- python/test/unit/test_perf_warning.py | 73 +++---------------- .../LoadStoreOpToLLVM.cpp | 25 ++++--- 3 files changed, 28 insertions(+), 75 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp index 56b1ba42e2c8..0bdf600194d2 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp @@ -52,10 +52,9 @@ static bool preCondition(scf::ForOp forOp) { return WalkResult::interrupt(); return WalkResult::advance(); }) - .wasInterrupted()) { - forOp->emitRemark() << "Warning: SWP fails on the outer loop"; + .wasInterrupted()) return false; - } + return true; } diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py index 38f93f0e5f0d..8b793dd36095 100644 --- a/python/test/unit/test_perf_warning.py +++ b/python/test/unit/test_perf_warning.py @@ -3,7 +3,7 @@ import os import pytest import torch -import tempfile + def is_perf_warning_enabled(): return os.environ.get('MLIR_ENABLE_REMARK', '0') == '1' @@ -13,7 +13,7 @@ def is_cuda(): return triton.runtime.driver.active.get_current_target().backend == "cuda" -def test_remark_mma(capfd): +def test_mma_remark(capfd): if is_cuda(): capability = torch.cuda.get_device_capability() if capability[0] < 9: @@ -48,26 +48,24 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, def test_remark_vectorization(capfd): - os.environ['MLIR_ENABLE_REMARK'] = '1' + os.environ["MLIR_ENABLE_REMARK"] = "1" @triton.jit - def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, xnumel, XBLOCK: tl.constexpr): - xnumel = 28311552 + def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr): xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] - xmask = xindex < xnumel x0 = xindex % 9 x2 = (xindex // 3456) % 512 x1 = (xindex // 9) % 384 x4 = xindex - tmp0 = tl.load(in_ptr0 + (x2 + (512*x0)), None, eviction_policy='evict_last') + tmp0 = tl.load(in_ptr0 + (x2 + (512 * x0)), None, eviction_policy="evict_last") tmp1 = tmp0 + 520 tmp2 = tmp0 < 0 tmp3 = tl.where(tmp2, tmp1, tmp0) tmp9 = (-4) + tmp3 tmp12 = tl.full([1], 512, tl.int64) tmp14 = tmp9 < tmp12 - tmp16 = tl.load(in_ptr3 + (x1), tmp14, eviction_policy='evict_last', other=0.0) + tmp16 = tl.load(in_ptr3 + (x1), tmp14, eviction_policy="evict_last", other=0.0) tmp18 = tmp16.to(tl.float32) tmp19 = tmp18.to(tl.float32) tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype) @@ -76,61 +74,10 @@ def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, xnumel, XBLOCK: tl.co tl.store(out_ptr0 + (x4), tmp22, None) XBLOCK = 1024 - XSIZE = 2048 - MAX_ELEM = 65536 - triton.compile( - triton.compiler.ASTSource( - fn=ldst_vec, signature={0: '*i64', 1: '*i64', 2: '*fp16', 3: '*fp32', 4: '*fp16', 5: 'i32'}, - constants={"XBLOCK": XBLOCK}), options={"num_warps": 1}) - - _, err = capfd.readouterr() - assert "remark: Warning: vectorization fails" in err, "expect vectorization failure remark" - os.environ['MLIR_ENABLE_REMARK'] = '0' - - -def test_remark_swp_outerloop(capfd): - os.environ['MLIR_ENABLE_REMARK'] = '1' - - @triton.jit - def binary_elemwise_kernel_2d(x_ptr, y_ptr, o_ptr, num_elem, shape_x0, shape_x1, stride): - pid = tl.program_id(axis=0) - start_elem = min(pid * num_elem, shape_x0) - end_elem = min(pid * num_elem + num_elem, shape_x0) - - # The outer-loop will not be pipelined in the precondition check, - # There will be warning for the outer loop. - for i in range(start_elem, end_elem): - x_addr = x_ptr + i * stride - y_addr = y_ptr + i * stride - o_addr = o_ptr + i * stride - - elem_offset = tl.arange(0, 128) - x_blk_ptr = x_addr + elem_offset - y_blk_ptr = y_addr + elem_offset - o_blk_ptr = o_addr + elem_offset - - block_start = 0 - for _ in range(0, shape_x1, 128): - elem_offset = block_start + tl.arange(0, 128) - mask = elem_offset < shape_x1 - - x = tl.load(x_blk_ptr, mask=mask) - y = tl.load(y_blk_ptr, mask=mask) - - output = x + y - - tl.store(o_blk_ptr, output, mask=mask) - - x_blk_ptr += 128 - y_blk_ptr += 128 - o_blk_ptr += 128 - block_start += 128 - triton.compile( - triton.compiler.ASTSource(fn=binary_elemwise_kernel_2d, signature={ - 0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32', 5: 'i32', 6: 'i32'}, constants={})) + triton.compiler.ASTSource(fn=ldst_vec, signature={0: '*i64', 1: '*i64', 2: '*fp16', 3: '*fp32', 4: '*fp16'}, + constants={"XBLOCK": XBLOCK}), options={"num_warps": 1}) _, err = capfd.readouterr() - - assert "remark: Warning: SWP fails on the outer loop" in err, "expect SWP failure remark" - os.environ['MLIR_ENABLE_REMARK'] = '0' + assert ("remark: Warning: vectorization fails" in err), "expect vectorization failure remark" + os.environ["MLIR_ENABLE_REMARK"] = "0" diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 35d5b4fd1ba6..faa7b7d2e935 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -166,6 +166,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, typeConverter->convertType(getElementTypeOrSelf(op.getType())); unsigned vec = getVectorSize(ptr); unsigned numElems = getTotalElemsPerThread(ptr.getType()); + unsigned vecOrig = vec; if (llMask) { LLVM_DEBUG(DBGS() << "vec = " << vec << " mask_alignment = " << getMaskAlignment(mask)); @@ -174,12 +175,12 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, } if (vec == 1 && numElems > 1) { - auto maskStr = !llMask ? "no mask" : std::to_string(getMaskAlignment(mask)); + int maskValue = !llMask ? -1 : getMaskAlignment(mask); op->emitRemark() << "Warning: vectorization fails vec = " << vec - << " numElems = " << numElems - << " mask is " << maskStr << "\n"; + << " origin vec = " << vecOrig + << " numElems = " << numElems << " mask is " << maskValue + << "\n"; } - // Get the LLVM values for pointers auto ptrElems = unpackLLElements(loc, llPtr, rewriter); assert(ptrElems.size() == numElems); @@ -385,6 +386,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, assert(ptrElems.size() == valueElems.size()); // Determine the vectorization size + unsigned vecOrig = vec; SmallVector maskElems; if (llMask) { Value mask = op.getMask(); @@ -396,10 +398,11 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, } if (vec == 1 && elemsPerThread > 1) { - auto maskStr = !llMask ? "no mask" : std::to_string(getMaskAlignment(op.getMask())); + int mask = !llMask ? -1 : getMaskAlignment(op.getMask()); op->emitRemark() << "Warning: vectorization fails vec = " << vec - << " elemsPerThread = " << elemsPerThread - << " mask is " << maskStr << "\n"; + << " origin vec = " << vecOrig + << " elemsPerThread = " << elemsPerThread << " mask is " + << mask << "\n"; } Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); @@ -530,6 +533,7 @@ struct AtomicCASOpConversion auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType()); // vec = 1 for scalar auto vec = getVectorSize(op.getPtr()); + auto vecOrig = vec; // tensor if (tensorTy) { auto valTy = cast(op.getVal().getType()); @@ -538,6 +542,7 @@ struct AtomicCASOpConversion if (vec == 1 && elemsPerThread > 1) op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig << " elemsPerThread = " << elemsPerThread << "\n"; Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); @@ -657,6 +662,7 @@ struct AtomicRMWOpConversion auto elemsPerThread = getTotalElemsPerThread(val.getType()); // vec = 1, numElements = 1 for scalar auto vec = getVectorSize(ptr); + auto vecOrig = vec; int numElems = 1; // tensor if (tensorTy) { @@ -667,8 +673,9 @@ struct AtomicRMWOpConversion } if (vec == 1 && numElems > 1) - op->emitRemark() << "Warning: vectorization fails vec = " - << vec << " numElems = " << numElems; + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig + << " numElems = " << numElems; Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);