diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 612404fd6dab..cb2a7585b42d 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -1042,7 +1042,7 @@ struct ConvertLayoutOpConversion } return res; } -}; // namespace triton::gpu::ConvertLayoutOp> +}; void populateConvertLayoutOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 087f971290c9..394c88573b3c 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -2,6 +2,8 @@ #include "Utility.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +namespace { + using namespace mlir; using namespace mlir::triton; @@ -10,6 +12,23 @@ using ::mlir::LLVM::getSRegValue; using ::mlir::triton::gpu::getTotalElemsPerThread; using ::mlir::triton::gpu::SharedEncodingAttr; +Value llGetPid(int axis, Location loc, ModuleOp moduleOp, + ConversionPatternRewriter &rewriter) { + assert(axis >= 0); + assert(axis < 3); + assert(moduleOp); + + // It is not easy to get the compute capability here, so we use numCTAs to + // decide the semantic of GetProgramIdOp. If numCTAs = 1, then + // GetProgramIdOp is converted to "%ctaid", otherwise it is converted to + // "%clusterid". + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp); + + std::string sreg = numCTAs == 1 ? "%ctaid." : "%clusterid."; + sreg.append(1, 'x' + axis); // 0 -> 'x', 1 -> 'y', 2 -> 'z' + return getSRegValue(rewriter, loc, sreg); +} + struct ReturnOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -91,6 +110,12 @@ struct BroadcastOpConversion } }; +// The input print op contains: +// - a "prefix" (string) specified by the user, and +// - one or more "operands" (tensors). +// +// For each operand, we print all of the values contained in this GPU thread, +// one per line, along with the index of the value in its tensor. struct PrintOpConversion : public ConvertTritonGPUOpToLLVMPattern { using ConvertTritonGPUOpToLLVMPattern< @@ -100,45 +125,169 @@ struct PrintOpConversion matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); - SmallVector operands; + Value prefixStr = + LLVM::addStringToModule(loc, rewriter, "printfPrefix_", op.getPrefix()); + + auto getPid = [&](int axis) { + return llGetPid(axis, loc, op->getParentOfType(), rewriter); + }; + std::array pid = {getPid(0), getPid(1), getPid(2)}; + + // Simple printf of a string without any tensors. + if (op.getNumOperands() == 0) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << "pid (" << getFormatSubstr(pid[0]) << ", " + << getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")%s"; + llPrintf(formatStr, {pid[0], pid[1], pid[2], prefixStr}, rewriter); + return success(); + } + for (size_t i = 0; i < op.getNumOperands(); i++) { - auto sub_operands = getTypeConverter()->unpackLLElements( + // Elements of the tensor that are resident in this GPU thread. + auto elems = getTypeConverter()->unpackLLElements( loc, adaptor.getOperands()[i], rewriter, op.getOperand(i).getType()); - for (auto elem : sub_operands) { - operands.push_back(elem); + + // Get the indices of `elems` within the tensor. Note that if `elems` has + // an "interesting" layout, then these will not be in any particularly + // nice order. + + // Extract the shape of the tensor being printed and use it to figure out + // how many digits we need for each of the dimensions. + SmallVector dimWidths; + SmallVector> indices; + if (auto rankedTy = + op.getOperand(i).getType().dyn_cast()) { + indices = emitIndices(loc, rewriter, rankedTy.getEncoding(), rankedTy); + for (int64_t dim : rankedTy.getShape()) { + if (dim > 0) { + dimWidths.push_back(static_cast(std::ceil(std::log10(dim)))); + } else { + dimWidths.push_back(0); + } + } + } else { + // We're printing a scalar. + assert(elems.size() == 1); + indices.push_back({}); } - } - std::string formatStr; - llvm::raw_string_ostream os(formatStr); - os << op.getPrefix(); - if (!operands.empty()) { - os << getFormatSubstr(operands[0]); - } - for (size_t i = 1; i < operands.size(); ++i) { - os << ", " << getFormatSubstr(operands[i]); + if (!elems.empty()) { + printTensor(prefixStr, /*operand=*/i, + /*numOperands=*/op.getNumOperands(), elems, pid, indices, + dimWidths, rewriter); + } } - llPrintf(formatStr, operands, rewriter); rewriter.eraseOp(op); return success(); } - std::string getFormatSubstr(Value value) const { + void printTensor(Value prefixStr, size_t operand, size_t numOperands, + ArrayRef elems, std::array pid, + ArrayRef> indices, + ArrayRef dimWidths, + ConversionPatternRewriter &rewriter) const { + assert(!elems.empty()); + assert(elems.size() == indices.size()); + assert(dimWidths.size() == indices.front().size()); + + size_t rank = dimWidths.size(); + + // Format is: + // pid (, , ) idx (, , ...) (operand ) + // where we leave off "(operand )" if there's only one operand. + // + // The Python wrapper munges `prefix` so that it prints nicely (e.g. starts + // with " " and ends with ": "). + + Value formatStrValue; + for (int i = 0; i < elems.size(); i++) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + + // nvptx printf can only accept 32 args; if we pass more than that, it + // will print garbage for the trailing args. + constexpr int kMaxPrintfOperands = 32; + SmallVector printfOperands; + + // TODO(jlebar): We really should pad the pid, but because the max pid is + // not known at compile-time, this would require nontrivial device-side + // work. + os << "pid ("; + for (int j = 0; j < pid.size(); j++) { + if (j != 0) { + os << ", "; + } + os << getFormatSubstr(pid[j]); + printfOperands.push_back(pid[j]); + } + os << ") "; + + // If `rank` is large enough, we could end up exceeding + // kMaxPrintfOperands. In that case, just truncate the index. + // (Subtract 2 because we're going to add two operands after the index.) + int maxAllowedRank = kMaxPrintfOperands - printfOperands.size() - 2; + + os << "idx ("; + const auto &index = indices[i]; + for (size_t dim = 0; dim < index.size(); dim++) { + if (dim != 0) { + os << ", "; + } + if (dim == maxAllowedRank) { + os << "... (truncated)"; + break; + } + os << getFormatSubstr(index[dim], /*width=*/dimWidths[dim]); + printfOperands.push_back(index[dim]); + } + os << ")"; + + os << "%s"; + printfOperands.push_back(prefixStr); + + if (numOperands > 1) { + os << "(operand " << operand << ") "; + } + + auto elem = elems[i]; + os << getFormatSubstr(elem); + printfOperands.push_back(elem); + + // It's the same format string each iteration, but it's a lot easier if we + // construct the format string at the same time as we populate + // printfOperands. But we don't want to create BLOCK_SIZE duplicate + // strings, so we cache the Value. + if (i == 0) { + formatStrValue = llPrintf(formatStr, printfOperands, rewriter); + } else { + llPrintf(formatStrValue, printfOperands, rewriter); + } + } + } + + std::string getFormatSubstr(Value value, + std::optional width = std::nullopt) const { + std::string prefix = "%"; + if (width.has_value()) { + prefix += std::to_string(*width); + } + Type type = value.getType(); if (type.isa()) { - return "%p"; + return prefix + "p"; } else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { - return "%f"; + return prefix + "f"; } else if (type.isSignedInteger()) { if (type.getIntOrFloatBitWidth() == 64) - return "%lli"; + return prefix + "lli"; else - return "%i"; + return prefix + "i"; } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { if (type.getIntOrFloatBitWidth() == 64) - return "%llu"; + return prefix + "llu"; else - return "%u"; + return prefix + "u"; } assert(false && "not supported type"); return ""; @@ -194,9 +343,22 @@ struct PrintOpConversion return {newType, newOp}; } - static void llPrintf(StringRef msg, ValueRange args, - ConversionPatternRewriter &rewriter) { + // Returns a Value for the format string, which you can reuse. + static Value llPrintf(StringRef msg, ValueRange args, + ConversionPatternRewriter &rewriter) { assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), + rewriter, "printfFormat_", msgNewline); + llPrintf(msgValue, args, rewriter); + return msgValue; + } + + static void llPrintf(Value msg, ValueRange args, + ConversionPatternRewriter &rewriter) { Type int8Ptr = ptr_ty(i8_ty); auto *ctx = rewriter.getContext(); @@ -208,11 +370,6 @@ struct PrintOpConversion Value one = i32_val(1); Value zero = i32_val(0); - llvm::SmallString<64> msgNewline(msg); - msgNewline.push_back('\n'); - msgNewline.push_back('\0'); - Value prefixString = - LLVM::addStringToModule(loc, rewriter, "printfFormat_", msgNewline); Value bufferPtr = null(int8Ptr); SmallVector newArgs; @@ -240,7 +397,7 @@ struct PrintOpConversion bufferPtr = bitcast(allocated, int8Ptr); } - SmallVector operands{prefixString, bufferPtr}; + SmallVector operands{msg, bufferPtr}; call(funcOp, operands); } }; @@ -390,20 +547,8 @@ struct GetProgramIdOpConversion LogicalResult matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // It is not easy to get the compute capability here, so we use numCTAs to - // decide the semantic of GetProgramIdOp. If numCTAs = 1, then - // GetProgramIdOp is converted to "%ctaid", otherwise it is converted to - // "%clusterid". - auto moduleOp = op->getParentOfType(); - assert(moduleOp && "Parent ModuleOp not found for GetProgramIdOp"); - int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp); - - Location loc = op->getLoc(); - assert(op.getAxisAsInt() < 3); - std::string sreg = numCTAs == 1 ? "%ctaid." : "%clusterid."; - sreg.append(1, 'x' + op.getAxisAsInt()); // 0 -> 'x', 1 -> 'y', 2 -> 'z' - - Value programId = getSRegValue(rewriter, loc, sreg); + Value programId = llGetPid(op.getAxisAsInt(), op->getLoc(), + op->getParentOfType(), rewriter); rewriter.replaceOp(op, programId); return success(); } @@ -685,6 +830,10 @@ struct AsyncBulkCommitGroupOpConversion } }; +} // namespace + +namespace mlir::triton { + void populateTritonGPUToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, @@ -710,3 +859,5 @@ void populateTritonGPUToLLVMPatterns( patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); } + +} // namespace mlir::triton diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h index 9019073584c0..b49791543c21 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h @@ -6,6 +6,8 @@ using namespace mlir; using namespace mlir::triton; +namespace mlir::triton { + void populateTritonGPUToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, @@ -13,4 +15,6 @@ void populateTritonGPUToLLVMPatterns( ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit); +} // namespace mlir::triton + #endif diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index 8d7a772b1231..d5378f25d35f 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -179,10 +179,10 @@ class ConvertTritonGPUOpToLLVMPatternBase { // Key: {layout, shape, withCTAOffset} struct IndexCacheInfo { DenseMap, CacheKeyDenseMapInfo> - *baseIndexCache; + *baseIndexCache = nullptr; DenseMap>, - CacheKeyDenseMapInfo> *indexCache; - OpBuilder::InsertPoint *indexInsertPoint; + CacheKeyDenseMapInfo> *indexCache = nullptr; + OpBuilder::InsertPoint *indexInsertPoint = nullptr; }; explicit ConvertTritonGPUOpToLLVMPatternBase( @@ -778,7 +778,7 @@ class ConvertTritonGPUOpToLLVMPatternBase { emitIndicesForDistributedLayout(loc, b, slice, type, withCTAOffset); } else { llvm_unreachable( - "emitIndices for layouts other than blocked & slice not " + "emitIndices for layouts other than blocked, mma, and slice not " "implemented yet"); } if (cache) { diff --git a/python/test/unit/language/print_helper.py b/python/test/unit/language/print_helper.py index c33ee9cee674..71d9c9031551 100644 --- a/python/test/unit/language/print_helper.py +++ b/python/test/unit/language/print_helper.py @@ -11,21 +11,47 @@ @triton.jit def kernel_device_print(X, Y, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) - tl.device_print("", x) + tl.device_print("x: ", x) tl.store(Y + tl.arange(0, BLOCK), x) @triton.jit def kernel_print(X, Y, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) - print("", x) + # Triton should add a space after this prefix. + print("x:", x) tl.store(Y + tl.arange(0, BLOCK), x) -# Take an extra value as a tl.constexpr so this kernel is not cached. This way -# the static print is run every time. +@triton.jit +def kernel_device_print_large( + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + x = tl.full([BLOCK_M, BLOCK_N], 1, tl.int32) + # Triton should change this prefix to "x: ". + tl.device_print("x ", x) + + +@triton.jit +def kernel_print_multiple_args(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.full((BLOCK,), 1, tl.int32) + print("", x, y) + + +@triton.jit +def kernel_device_print_multiple_args(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.full((BLOCK,), 1, tl.int32) + tl.device_print("", x, y) + tl.store(Y + tl.arange(0, BLOCK), y) + + @triton.jit def kernel_static_print(X, Y, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr): + # This function takes an extra value as a tl.constexpr so this kernel is not + # cached. This way the static print is run every time. x = tl.load(X + tl.arange(0, BLOCK)) tl.static_print("", x) tl.store(Y + tl.arange(0, BLOCK), x) @@ -38,19 +64,27 @@ def kernel_no_arg_print(): def test_print(func: str, data_type: str): shape = (128, ) - # limit the range of integers so that the sum does not overflow x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type)) y = torch.zeros(shape, dtype=x.dtype, device="cuda") if func == "device_print": kernel_device_print[(1,)](x, y, BLOCK=shape[0]) elif func == "print": kernel_print[(1,)](x, y, BLOCK=shape[0]) + elif func == "device_print_large": + kernel_device_print_large[(1, 2)](BLOCK_M=64, BLOCK_N=128) + elif func == "print_multiple_args": + kernel_print_multiple_args[(1,)](x, y, BLOCK=shape[0]) + elif func == "device_print_multiple_args": + kernel_device_print_multiple_args[(1,)](x, y, BLOCK=shape[0]) elif func == "static_print": kernel_static_print[(1,)](x, y, BLOCK=shape[0], PLACEHOLDER=uuid.uuid4()) elif func == "no_arg_print": kernel_no_arg_print[(1,)](num_warps=4) + else: + assert f"Unknown kernel: {func}" - if func != "no_arg_print": + if func != "no_arg_print" and func != "device_print_large" and \ + func != "print_multiple_args" and func != "device_print_multiple_args": assert_close(y, x) diff --git a/python/test/unit/language/test_subprocess.py b/python/test/unit/language/test_subprocess.py index 56b88f8a188f..d91d9fc89aad 100644 --- a/python/test/unit/language/test_subprocess.py +++ b/python/test/unit/language/test_subprocess.py @@ -1,6 +1,8 @@ +import itertools import os import subprocess import sys +from collections import Counter import pytest @@ -14,26 +16,53 @@ torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"] +# TODO: Print with multiple operands @pytest.mark.parametrize("func_type, data_type", - [("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32"), ("no_arg_print", "int32")]) + [("device_print", data_type) for data_type in torch_types] + [ + ("print", "int32"), + ("static_print", "int32"), + ("no_arg_print", "int32"), + ("device_print_large", "int32"), + ("print_multiple_args", "int32"), + ("device_print_multiple_args", "int32"), + ]) def test_print(func_type: str, data_type: str): proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, shell=False) outs, _ = proc.communicate() - outs = outs.split() - new_lines = set() - for line in outs: - try: - value = line - if func_type != "static_print": - value = int(float(line)) - new_lines.add(value) - except Exception as e: - print(e) - if func_type != "static_print" and func_type != "no_arg_print": + outs = [line for line in outs.decode("UTF-8").split("\n") if line] + + # Format is + # pid (, , ) idx (, , ...) (operand ) + expected_lines = Counter() + if func_type == "print" or func_type == "device_print": for i in range(128): - assert i in new_lines - else: - assert len(new_lines) == 1 + line = f"pid (0, 0, 0) idx ({i:3}) x: {i}" + if data_type.startswith("float"): + line += ".000000" + expected_lines[line] = 1 + elif func_type == "static_print": + expected_lines[" int32[constexpr[128]]"] = 1 + elif func_type == "no_arg_print": + expected_lines["pid (0, 0, 0) idx (): 0"] = 128 + elif func_type == "device_print_large": + for i, j, k in itertools.product(range(2), range(64), range(128)): + expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1 + elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args": + for i in range(128): + expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 0) {i}"] = 1 + expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 1) 1"] = 1 + + actual_lines = Counter() + for line in outs: + actual_lines[line] += 1 + + diff = Counter(actual_lines) + diff.subtract(expected_lines) + for line, delta in diff.items(): + if delta == 0: + continue + print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)') + assert all(delta == 0 for delta in diff.values()) @pytest.mark.parametrize("func_type", assert_types) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 835fef198bd3..fa7ce5d9ccec 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1577,6 +1577,15 @@ def debug_barrier(builder: ir.builder) -> tl.tensor: def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor: + # It makes sense visually for prefix to end in ": "; make it so. Also, + # non-empty prefixes should start with " ". + if not prefix.endswith(" "): + prefix += " " + if not prefix.endswith(": "): + prefix = prefix[:-1] + ": " + if len(prefix) > 2 and not prefix.startswith(" "): + prefix = " " + prefix + new_args = [] for arg in args: new_args.append(arg.handle)