diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py
index 871bc6ba294b..8b793dd36095 100644
--- a/python/test/unit/test_perf_warning.py
+++ b/python/test/unit/test_perf_warning.py
@@ -45,3 +45,56 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk,
     assert "remark: Warning: can't use MMA V3 for the dot op" in captured.err, "expect MMA V3 remark"
     assert "note: see current operation:" in captured.err
     os.environ['MLIR_ENABLE_REMARK'] = '0'
+
+
+def test_remark_vectorization(capfd):
+    """Check that compiling ``ldst_vec`` emits the "vectorization fails" remark.
+
+    Compiles the kernel with ``MLIR_ENABLE_REMARK=1`` and asserts that the
+    stderr text captured by ``capfd`` contains the remark.
+    """
+    # Restore the caller's setting in ``finally`` so that a failing compile or
+    # assertion cannot leak MLIR_ENABLE_REMARK=1 into later tests.
+    prev_remark = os.environ.get("MLIR_ENABLE_REMARK")
+    os.environ["MLIR_ENABLE_REMARK"] = "1"
+    try:
+        @triton.jit
+        def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr):
+            xoffset = tl.program_id(0) * XBLOCK
+            xindex = xoffset + tl.arange(0, XBLOCK)[:]
+            x0 = xindex % 9
+            x2 = (xindex // 3456) % 512
+            x1 = (xindex // 9) % 384
+            x4 = xindex
+            tmp0 = tl.load(in_ptr0 + (x2 + (512 * x0)), None, eviction_policy="evict_last")
+            tmp1 = tmp0 + 520
+            tmp2 = tmp0 < 0
+            tmp3 = tl.where(tmp2, tmp1, tmp0)
+            tmp9 = (-4) + tmp3
+            tmp12 = tl.full([1], 512, tl.int64)
+            tmp14 = tmp9 < tmp12
+            tmp16 = tl.load(in_ptr3 + (x1), tmp14, eviction_policy="evict_last", other=0.0)
+            tmp18 = tmp16.to(tl.float32)
+            tmp19 = tmp18.to(tl.float32)
+            tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
+            tmp21 = tl.where(tmp14, tmp19, tmp20)
+            tmp22 = tmp21.to(tl.float32)
+            tl.store(out_ptr0 + (x4), tmp22, None)
+
+        XBLOCK = 1024
+        triton.compile(
+            triton.compiler.ASTSource(
+                fn=ldst_vec,
+                signature={0: '*i64', 1: '*i64', 2: '*fp16', 3: '*fp32', 4: '*fp16'},
+                constants={"XBLOCK": XBLOCK},
+            ),
+            options={"num_warps": 1},
+        )
+
+        _, err = capfd.readouterr()
+        assert ("remark: Warning: vectorization fails" in err), "expect vectorization failure remark"
+    finally:
+        if prev_remark is None:
+            os.environ.pop("MLIR_ENABLE_REMARK", None)
+        else:
+            os.environ["MLIR_ENABLE_REMARK"] = prev_remark
diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index 99eeae06f721..ea1d79f9ba93 100644
--- a/third_party/nvidia/backend/compiler.py
+++ 
b/third_party/nvidia/backend/compiler.py @@ -237,6 +237,11 @@ def make_llir(src, metadata, options, capability): # TritonGPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() + # Set up Diagnostic + if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1": + srcMgr = llvm.source_mgr() + diag = ir.source_mgr_diag(srcMgr, mod.context) + mod.context.printOpOnDiagnostic(True) nvidia.passes.ttgpuir.add_decompose_unsupported_conversions(pm) passes.ttgpuir.add_combine_tensor_select_and_if(pm) passes.convert.add_scf_to_cf(pm) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 25a176dd1418..27d7dd69b8c0 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -166,6 +166,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, typeConverter->convertType(getElementTypeOrSelf(op.getType())); unsigned vec = getVectorSize(ptr); unsigned numElems = getTotalElemsPerThread(ptr.getType()); + unsigned vecOrig = vec; if (llMask) { LLVM_DEBUG(DBGS() << "vec = " << vec << " mask_alignment = " << getMaskAlignment(mask)); @@ -173,6 +174,13 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, LLVM_DEBUG(llvm::dbgs() << " vec = " << vec << '\n'); } + if (vec == 1 && numElems > 1) { + int maskValue = !llMask ? 
-1 : getMaskAlignment(mask); + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig + << " numElems = " << numElems << " mask is " << maskValue + << "\n"; + } // Get the LLVM values for pointers auto ptrElems = unpackLLElements(loc, llPtr, rewriter); assert(ptrElems.size() == numElems); @@ -380,6 +388,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, assert(ptrElems.size() == valueElems.size()); // Determine the vectorization size + unsigned vecOrig = vec; SmallVector maskElems; if (llMask) { Value mask = op.getMask(); @@ -390,6 +399,14 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, vec = std::min(vec, maskAlign); } + if (vec == 1 && elemsPerThread > 1) { + int mask = !llMask ? -1 : getMaskAlignment(op.getMask()); + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig + << " elemsPerThread = " << elemsPerThread << " mask is " + << mask << "\n"; + } + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); const size_t dtsize = std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); @@ -518,12 +535,18 @@ struct AtomicCASOpConversion auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType()); // vec = 1 for scalar auto vec = getVectorSize(op.getPtr()); + auto vecOrig = vec; // tensor if (tensorTy) { auto valTy = cast(op.getVal().getType()); vec = std::min(vec, valTy.getElementType().isF16() ? 
2 : 1); } + if (vec == 1 && elemsPerThread > 1) + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig + << " elemsPerThread = " << elemsPerThread << "\n"; + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); auto vecTy = vec_ty(valueElemTy, vec); SmallVector resultVals(elemsPerThread); @@ -641,6 +664,7 @@ struct AtomicRMWOpConversion auto elemsPerThread = getTotalElemsPerThread(val.getType()); // vec = 1, numElements = 1 for scalar auto vec = getVectorSize(ptr); + auto vecOrig = vec; int numElems = 1; // tensor if (tensorTy) { @@ -649,6 +673,12 @@ struct AtomicRMWOpConversion // mask numElems = tensorTy.getNumElements(); } + + if (vec == 1 && numElems > 1) + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig + << " numElems = " << numElems; + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); auto vecTy = vec_ty(valueElemTy, vec);