diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 1cd2d2955a8e..6c93538a24f5 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -106,7 +106,7 @@ jobs: run: | if [ x"${{ github.repository }}" == x"triton-lang/triton" ]; then echo '::set-output name=matrix-CUDA::[["a100-runner-set"], ["h100-runner-set"]]' - echo '::set-output name=matrix-HIP::[["self-hosted", "gfx942"]]' + echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["self-hosted", "gfx942"]]' echo '::set-output name=matrix-MACOS::[["macos-latest"]]' else echo '::set-output name=matrix-CUDA::["ubuntu-latest"]' diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index 851ee4928cdb..1b4c46a26c5b 100644 --- a/.github/workflows/integration-tests.yml.in +++ b/.github/workflows/integration-tests.yml.in @@ -115,7 +115,7 @@ jobs: run: | if [ x"${{ github.repository }}" == x"triton-lang/triton" ]; then echo '::set-output name=matrix-CUDA::[["a100-runner-set"], ["h100-runner-set"]]' - echo '::set-output name=matrix-HIP::[["self-hosted", "gfx942"]]' + echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["self-hosted", "gfx942"]]' echo '::set-output name=matrix-MACOS::[["macos-latest"]]' else echo '::set-output name=matrix-CUDA::["ubuntu-latest"]' diff --git a/README.md b/README.md index b332edb653d6..a11e8b9c7711 100644 --- a/README.md +++ b/README.md @@ -117,9 +117,10 @@ arbitrary LLVM version. (probably because, in our build, users don't invoke cmake directly, but instead use setup.py). Teach vscode how to compile Triton as follows. - - Do a local build. + - Do a local build. Run command `pip install -e python` - Get the full path to the `compile_commands.json` file produced by the build: - `find python/build -name 'compile_commands.json | xargs readlink -f'` + `find python/build -name 'compile_commands.json' | xargs readlink -f`. + You might get a full path similar to `/Users/{username}/triton/python/build/cmake.macosx-11.1-arm64-cpython-3.12/compile_commands.json` - In vscode, install the [C/C++ extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode.cpptools), diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 89e0b23e4ce2..b209a02b4bb3 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -388,19 +388,12 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, /* ------------------------------------ */ // Returns CTA level thread idx -inline Value getThreadIdInCTA(RewriterBase &rewriter, Location loc) { +inline Value getThreadId(RewriterBase &rewriter, Location loc) { Value tid = rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x); return rewriter.create(loc, i32_ty, tid); } -// Returns CTA level thread idx. -inline Value getThreadId(RewriterBase &rewriter, Location loc) { - Value tid = getThreadIdInCTA(rewriter, loc); - auto mod = rewriter.getBlock()->getParent()->getParentOfType(); - return tid; -} - // ----------------------------------------------------------------------- // Shared memory utilities // ----------------------------------------------------------------------- @@ -909,10 +902,12 @@ inline void emitWmmaOffsetForCTA(const AMDWmmaEncodingAttr &wmmaLayout, auto rank = shapePerCta.size(); assert(rank == 2 || rank == 3); SmallVector elemOffset(rank, 0); + auto elemStride = wmmaLayout.getVersion() == 1 ? 2 : 1; if (rank == 3) elemOffset[0] = ctaBatchOffset; for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) { - elemOffset[rank - 2] = ctaOffsetX * shapePerCta[rank - 2] + 2 * elem; + elemOffset[rank - 2] = + ctaOffsetX * shapePerCta[rank - 2] + elemStride * elem; elemOffset[rank - 1] = ctaOffsetY * shapePerCta[rank - 1]; offsets.push_back(elemOffset); } @@ -958,8 +953,17 @@ emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter, SmallVector multiDimBase(rank); - multiDimBase[rank - 2] = - add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0); + auto ver = wmmaLayout.getVersion(); + if (ver == 1) { + multiDimBase[rank - 2] = + add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0); + } else { + assert(ver == 2); + multiDimBase[rank - 2] = + add(mul(udiv(threadIdPerWarp, i32_val(mnkDim[2])), + i32_val(wmmaLayout.getSizePerThread()[rank - 2])), + offWarp0); + } multiDimBase[rank - 1] = add(laneId, offWarp1); // TODO: It is assumed when rank = 3, warpsPerCTA is set to @@ -1109,8 +1113,6 @@ emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter, } else if (auto mfmaLayout = mlir::dyn_cast(layout)) { result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type); } else if (auto wmmaLayout = mlir::dyn_cast(layout)) { - // TODO: support 2nd gen of WMMA - assert(wmmaLayout.getVersion() == 1); result = emitBaseIndexForWmmaLayout(loc, rewriter, wmmaLayout, type); } else if (auto sliceLayout = mlir::dyn_cast(layout)) { auto parentLayout = sliceLayout.getParent(); diff --git a/include/triton/Dialect/Triton/IR/Utility.h b/include/triton/Dialect/Triton/IR/Utility.h index 0ef59714733d..1ff63697ec0d 100644 --- a/include/triton/Dialect/Triton/IR/Utility.h +++ b/include/triton/Dialect/Triton/IR/Utility.h @@ -31,7 +31,11 @@ template Int ceil(Int m, Int n) { return (m + n - 1) / n; } /// Get the highest power of 2 divisor of an integer. template T highestPowOf2Divisor(T n) { - if (n == 0) { + // When n is 0 or min, return the highest power of 2. The min case is handled + // separately to avoid underflow when T is a signed integer. Technically + // in that case the correct divisor is -n, but this value is outside the + // range of possible values, so we take the next best alternative. + if (n == 0 || n == std::numeric_limits::min()) { return (static_cast(1) << (sizeof(T) * 8 - 2)); } return (n & (~(n - 1))); diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index 34f0d5aabc0f..a813013161f6 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -577,7 +577,23 @@ class LinearLayout { // divideLeft and divideRight are the inverses of operator*. // - // If c = a * b, then a = c.divideRight(b) and b = c.divideLeft(a). + // Consider `a = c.divideRight(b)`, where `a` is a linear layout with + // `in-dims(a) == in-dims(b)` and `out-dims(a) == out-dims(c)`. We may remove + // some empty dimensions from `a` to form `a'` and still have `a' * b == c`. + // Therefore, there are multiple possible values that we could return for + // `(a * b).divideRight(b)` which would satisfy + // `((a * b).divideRight(b)) * b == a * b`. + // + // In the following example, we have `a * b == a' * b` when "in1" is an empty + // dimension that maps everything to 0: + // + // a = L("in1", "in2") -> ("out1", "out2") + // a' = L("in1") -> ("out1") + // b = L("in2") -> ("out2") + // + // divideLeft and divideRight resolve this ambiguity by always returning the + // "canonical" quotient, namely the one with the fewest possible size-zero + // input and output dimensions. // // TODO(jlebar): Implement divideLeft. // std::optional divideLeft(const LinearLayout &divisor); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index be68f416f4e9..933f062d8191 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -425,6 +425,7 @@ bool supportMFMATypes(Type a, Type b) { if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth()) return false; + auto F8E5M2 = TypeID::get(); auto F8E4M3FNUZ = TypeID::get(); auto F8E5M2FNUZ = TypeID::get(); auto F16 = TypeID::get(); @@ -435,6 +436,7 @@ bool supportMFMATypes(Type a, Type b) { {F32, F32}, {F16, F16}, {BF16, BF16}, + {F8E5M2, F8E5M2}, {F8E4M3FNUZ, F8E4M3FNUZ}, {F8E4M3FNUZ, F8E5M2FNUZ}, {F8E5M2FNUZ, F8E4M3FNUZ}, diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index a5208d77976d..b07c21a404fc 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -25,6 +25,10 @@ using ::mlir::LLVM::linearize; using namespace mlir::triton::gpu; +// XXX(Keren): A temporary knob to control the use of legacy MMA conversion +// because LinearLayout seems to have some performance issues. +constexpr bool useLegacyMMAConversion = false; + struct ConvertLayoutOpConversion : public ConvertOpToLLVMPattern { public: @@ -341,8 +345,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion const LinearLayout &dstLayout, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - // TODO(jlebar): Implement me. - return failure(); + // TODO(Keren): implement warp shuffle instead of using the general approach + // that uses shared memory + return transferWithinBlockOrGroup(op, srcLayout, dstLayout, adaptor, + rewriter); } LogicalResult @@ -378,6 +384,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion /*accumNumReplicates=*/1)) { return false; } + if (useLegacyMMAConversion) { + return false; + } return true; } if (isa(layout)) { diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 458a1ed9d9b6..893d15876dc0 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -814,8 +814,6 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, emitMfmaOffsetForCTA(mfmaLayout, offsets, 0, multiDimCTAInRepId[0], multiDimCTAInRepId[1]); } else if (auto wmmaLayout = dyn_cast(layout)) { - // TODO: support 2nd gen of WMMA - assert(wmmaLayout.getVersion() == 1); emitWmmaOffsetForCTA(wmmaLayout, offsets, 0, multiDimCTAInRepId[0], multiDimCTAInRepId[1]); } diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index a65b9e64e2a5..a4f30fc503ba 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -565,17 +565,35 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef shape) const { // For wmma with 16x16 output, each of the 32 threads holds 8 elements. // - // For the register (i.e., element) dimension, these 8 elements are along - // the matrix C's M dimension, with 1 consecutive elements spanning 1 row - // and then the next 1 row being a gap. + // The first version of WMMA layout has following specific: + // for the register (i.e., element) dimension, these 8 elements are + // along the matrix C's M dimension, with 1 consecutive elements + // spanning 1 row and then the next 1 row being a gap. // // For the lane (i.e., thread) dimension, these threads are along the // matrix C's N dimension, with 16 consecutive threads covering a whole // row and the next 16 threads start at the next row. - LinearLayout tileLayout( - {{kRegister, {/*gap*/ {0, 2}, {0, 4}, {0, 8}}}, - {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 1}}}}, - {outDimNames[order[0]], outDimNames[order[1]]}); + // + // The second version of wmma layout is less tricky: + // for the register dimension 8 elements are along the matrix C's M + // dimension. First 16 lanes take 0-8 elems along M, second 16 take 8-15. + // We have 16 pair of threads in each warp, one pair covers the whole + // column. + // + // Please also check explaining comments in TritonGPUAttrDefs.td at the + // AMDWmmaEncodingAttr section. + unsigned ver = getVersion(); + assert(ver == 1 || ver == 2); + LinearLayout tileLayout = + ver == 1 + ? LinearLayout( + {{kRegister, {/*gap*/ {0, 2}, {0, 4}, {0, 8}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 1}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}) + : LinearLayout( + {{kRegister, {{0, 1}, {0, 2}, {0, 4}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 8}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); if (hasBatchDim) { assert(order[2] == 0); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 07ef6f3f40a3..5cc537d5fcb3 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -1070,6 +1070,13 @@ bool mlir::triton::preProcessLoopAndGetSchedule( coarseSchedule.dump(); }); + tt::CoarseSchedule::Cluster afterPrologue = + schedulePrologueAndEpilogue(forOp, coarseSchedule, rootUsers, numStages); + LLVM_DEBUG({ + LDBG("Coarse schedule with prologue and epilogue:"); + coarseSchedule.dump(); + }); + SmallVector barriers; // Convert the loads into async loads and create the allocs. SmallVector allocs = @@ -1080,13 +1087,6 @@ bool mlir::triton::preProcessLoopAndGetSchedule( coarseSchedule.dump(); }); - tt::CoarseSchedule::Cluster afterPrologue = - schedulePrologueAndEpilogue(forOp, coarseSchedule, rootUsers, numStages); - LLVM_DEBUG({ - LDBG("Coarse schedule with prologue and epilogue:"); - coarseSchedule.dump(); - }); - scheduleDependencies(forOp, coarseSchedule, numStages); LLVM_DEBUG({ LDBG("Coarse schedule with dependencies:"); @@ -1402,8 +1402,7 @@ static std::optional dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp, transitiveOperand = cast(blockArg.getOwner()->getTerminator()) .getOperand(blockArg.getArgNumber() - 1); - } - if (Operation *def = transitiveOperand.getDefiningOp()) { + } else if (Operation *def = transitiveOperand.getDefiningOp()) { transitiveOperand = def->getOperand(0); } } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp index ee28f481dbc1..41c9b91a31cf 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp @@ -332,16 +332,26 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { if (annotateFn) annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Prologue, i); for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { - setValueMapping(op->getResult(destId), newOp->getResult(destId), - i - stages[op]); + Value source = newOp->getResult(destId); // If the value is a loop carried dependency update the loop argument - // mapping. for (OpOperand &operand : yield->getOpOperands()) { if (operand.get() != op->getResult(destId)) continue; + if (predicates[predicateIdx] && + !forOp.getResult(operand.getOperandNumber()).use_empty()) { + // If the value is used outside the loop, we need to make sure we + // return the correct version of it. + Value prevValue = valueMapping + [forOp.getRegionIterArgs()[operand.getOperandNumber()]] + [i - stages[op]]; + source = rewriter.create( + loc, predicates[predicateIdx], source, prevValue); + } setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], - newOp->getResult(destId), i - stages[op] + 1); + source, i - stages[op] + 1); } + setValueMapping(op->getResult(destId), newOp->getResult(destId), + i - stages[op]); } } } diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 4e5ea94a3396..57e41e55ff4f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -734,8 +734,11 @@ getConvertBackwardSlice(Value root, SetVector &slice, continue; enqueue(result, encoding); } - if (!isFreeConvert(definingOp) && - canFoldIntoConversion(definingOp, encoding)) + if (isFreeConvert(definingOp)) { + enqueue(definingOp->getOperand(0), encoding); + continue; + } + if (canFoldIntoConversion(definingOp, encoding)) continue; if (stopPropagation && stopPropagation(definingOp)) continue; diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index 60938f2b737c..f31bfd56180a 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -165,6 +165,36 @@ void assertDimsSubsetIgnoringOrder(T &&small, U &&big) { } } +// Check that elements common to both aDims and bDims +// appear in the same relative order. +template +void assertCommonDimsSameOrder(T &&aDims, U &&bDims) { + SmallDenseSet aDimsSet(aDims.begin(), aDims.end()); + SmallDenseSet bDimsSet(bDims.begin(), bDims.end()); + + std::vector aCommonDims; + for (StringAttr dim : aDims) { + if (bDimsSet.contains(dim)) { + aCommonDims.push_back(dim); + } + } + + std::vector bCommonDims; + for (StringAttr dim : bDims) { + if (aDimsSet.contains(dim)) { + bCommonDims.push_back(dim); + } + } + + if (aCommonDims != bCommonDims) { + llvm::report_fatal_error("All a/b dimensions common to both layouts " + "must appear in the same relative order, but they " + "don't.\na:" + + Twine(triton::join(aDims, ", ")) + + "\nb: " + triton::join(bDims, ", ")); + } +} + void eraseEmptyInOutDims(BasesT &bases, llvm::MapVector &outDims) { // Erase empty out-dims. @@ -553,40 +583,9 @@ LinearLayout LinearLayout::reshapeOuts( } LinearLayout operator*(LinearLayout inner, LinearLayout outer) { - // Check that elements common to both outerDimsRange and innerDimsRange - // appear in the same relative order. - auto checkCommonDims = [&](auto outerDimsRange, auto innerDimsRange) { - SmallDenseSet outerDims(outerDimsRange.begin(), - outerDimsRange.end()); - SmallDenseSet innerDims(innerDimsRange.begin(), - innerDimsRange.end()); - - std::vector outerCommonDims; - for (StringAttr dim : outerDimsRange) { - if (innerDims.contains(dim)) { - outerCommonDims.push_back(dim); - } - } - - std::vector innerCommonDims; - for (StringAttr dim : innerDimsRange) { - if (outerDims.contains(dim)) { - innerCommonDims.push_back(dim); - } - } - - if (outerCommonDims != innerCommonDims) { - llvm::report_fatal_error( - "Cannot multiply layouts. All in/out dimensions common to both " - "layouts must appear in the same relative order, but they " - "don't.\nOuter:" + - Twine(outer.toString()) + "\nInner:" + inner.toString()); - } - }; - // Check that dims common to outer and inner have the same relative order. - checkCommonDims(outer.getInDimNames(), inner.getInDimNames()); - checkCommonDims(outer.getOutDimNames(), inner.getOutDimNames()); + assertCommonDimsSameOrder(inner.getOutDimNames(), outer.getOutDimNames()); + assertCommonDimsSameOrder(inner.getInDimNames(), outer.getInDimNames()); // Get the sizeLog2 of all input and output dimensions we're going to // consider, in order. `inner` is more minor, so its dimensions come @@ -642,6 +641,9 @@ LinearLayout operator*(LinearLayout inner, LinearLayout outer) { std::optional LinearLayout::divideRight(const LinearLayout &divisor) { + assertCommonDimsSameOrder(getOutDimNames(), divisor.getOutDimNames()); + assertCommonDimsSameOrder(getInDimNames(), divisor.getInDimNames()); + // Strip off the top N bases for each input dimension of divisor. This // gives a candidate quotient. Then check if quotient * divisor equals // `this`. @@ -655,35 +657,135 @@ LinearLayout::divideRight(const LinearLayout &divisor) { divisor.getInDimSizeLog2(inDim)); } + // Check if the size of the new out-dims are large enough. + // If yes, we can divide the out-dims. + // If no, we return nullopt to indicate that the division is not possible. llvm::MapVector newOutDims = outDims; - for (const auto [outDim, outDimSize] : divisor.outDims) { - if (newOutDims[outDim] < outDimSize) { + for (const auto [outDimName, outDimSize] : divisor.outDims) { + if (newOutDims[outDimName] < outDimSize) { return std::nullopt; } - newOutDims[outDim] /= outDimSize; + newOutDims[outDimName] /= outDimSize; } - eraseEmptyInOutDims(newBases, newOutDims); - - LDBG("this->divideRight(divisor)=candidate_quotient"); + LDBG("Checking candidate_quotient * divisor == *this"); LDBG("this:" << *this); LDBG("divisor:" << divisor); + LDBG("newBases: " << triton::join(newBases, ", ", [](auto &p) { + return p.first.str() + "=" + std::to_string(p.second.size()); + })); LDBG("newOutDims: " << triton::join(newOutDims, ", ", [](auto &p) { return p.first.str() + "=" + std::to_string(p.second); })); std::optional candidateQuotient = LinearLayout::tryCreate( - std::move(newBases), std::move(newOutDims).takeVector(), + std::move(newBases), std::move(newOutDims.takeVector()), /*requireSurjective=*/false); + LDBG("candidate_quotient:" << candidateQuotient); + LDBG("*candidate_quotient * divisor=" << *candidateQuotient * divisor); if (!candidateQuotient.has_value()) { LDBG("candidate quotient failed invariant checks"); return std::nullopt; } - LDBG("candidate_quotient:" << candidateQuotient); + if (*candidateQuotient * divisor != *this) { + LDBG("candidate quotient failed invariant checks"); + return std::nullopt; + } + + // Now that we have a candidate quotient, we need to eliminate any empty + // dimensions from the candidate quotient but still ensure that + // quotient * divisor == *this. + newBases = candidateQuotient->bases; + newOutDims = candidateQuotient->outDims; + + // We only remove the trailing empty output dimensions from `quotient`. + // + // In the multiplication `quotient * divisor == result`, the output dimensions + // of `quotient` always come before those of `divisor` in `result`. Removing + // any non-trailing empty dimensions from `quotient` would change the + // order of the output dimensions in `result`. + // + // The following loop iterates through the output dimensions of `result` from + // right to left. During the iteration, the following conditions are checked: + // + // 1. If an output dimension exists only in `divisor` and not in `quotient`, + // the loop continues. + // 2. If an output dimension exists only in `quotient` and not in `divisor`, + // we stop the loop. + // 3. If an output dimension exists in both `quotient` and `divisor`, it may + // be removed, but only if it is a size-1 dimension and meets one of the + // following conditions: + // - The dimension immediately following it in `quotient` has already been + // removed. + // - It is the last dimension of `quotient`. + // Otherwise, removing this dimension could alter the structure of `result`. + // + // Consider the quotient l = o / r, where: + // out-dims(o) = ["out0", "out1", "out2", "out3"] + // out-dims(r) = ["out1", "out3"] + // + // Only "out1" is a size-1 dimension. If we remove "out1" from o, the + // resulting output dimensions would be: + // out-dims(l) = ["out0", "out2", "out3"] + // + // Performing the multiplication l * r results in: + // out-dims(l * r) = ["out0", "out2", "out3"] * ["out1", "out3"] = ["out0", + // "out2", "out3", "out1"] + // This outcome does not match the original out-dims(o). + // + // However, if we remove only "out3" from o, we get: + // out-dims(l) = ["out0", "out1", "out2"] + // + // Then, performing the multiplication l * r yields: + // out-dims(l * r) = ["out0", "out1", "out2"] * ["out1", "out3"] = ["out0", + // "out1", "out2", "out3"] + // This result matches the original out-dims(o). + llvm::SmallVector emptyOutDimIndices; + for (const auto [outDimName, outDimSize] : llvm::reverse(outDims)) { + if (newOutDims.contains(outDimName) && !divisor.hasOutDim(outDimName)) { + break; + } + if (newOutDims.contains(outDimName) && divisor.hasOutDim(outDimName) && + candidateQuotient->getOutDimSize(outDimName) == 1) { + auto lastOutDimName = newOutDims.rbegin()->first; + if (outDimName != lastOutDimName) { + break; + } + emptyOutDimIndices.push_back(getOutDimIndex(outDimName)); + newOutDims.erase(outDimName); + } + } + + // Erase the basis elements corresponding to the empty out-dims. + for (auto &[inDim, inDimBases] : newBases) { + for (auto &basis : inDimBases) { + for (int i : emptyOutDimIndices) { + basis.erase(basis.begin() + i); + } + } + } - if (*candidateQuotient * divisor == *this) { - return *candidateQuotient; + // Erase trailing empty in-dims. + for (auto inDimName : llvm::reverse(getInDimNames())) { + if (newBases[inDimName].empty() && divisor.hasInDim(inDimName)) { + newBases.erase(inDimName); + } else { + break; + } } - return std::nullopt; + + LDBG("Eliminated empty dims from candidate_quotient"); + LDBG("newBases: " << triton::join(newBases, ", ", [](auto &p) { + return p.first.str() + "=" + std::to_string(p.second.size()); + })); + LDBG("newOutDims: " << triton::join(newOutDims, ", ", [](auto &p) { + return p.first.str() + "=" + std::to_string(p.second); + })); + auto quotient = LinearLayout::tryCreate(std::move(newBases), + std::move(newOutDims).takeVector(), + /*requireSurjective=*/false); + LDBG("quotient:" << quotient); + assert(quotient.has_value()); + return quotient; } LinearLayout LinearLayout::sublayout(ArrayRef inDimNames, diff --git a/python/src/llvm.cc b/python/src/llvm.cc index 66561fdf6d98..ebb22639630d 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -116,6 +116,8 @@ std::string translateLLVMIRToASM(llvm::Module &module, opt.NoInfsFPMath = false; opt.NoNaNsFPMath = true; opt.TrapUnreachable = true; + opt.MCOptions.AsmVerbose = true; + opt.MCOptions.PreserveAsmComments = true; std::unique_ptr machine{target->createTargetMachine( module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, std::nullopt, diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 64a3a2c4eb81..9e5ff8a2ce37 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3082,7 +3082,7 @@ def convert_fp8_to_fp32(x, device, dtype_str): [(*shape_nw, col_a, col_b, 'none', input_precision, in_dtype, out_dtype, kpack) for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], [128, 128, 64, 4], [64, 128, 128, 4], [32, 128, 64, 2], [64, 64, 32, 4], [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]] - for input_precision in ["ieee" if is_hip() else "tf32"] + for input_precision in ["tf32" if is_cuda() else "ieee"] for col_a in [True, False] for col_b in [True, False] for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32', @@ -3338,7 +3338,7 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_ if out_dtype_str == "float16": pytest.skip(f"{out_dtype_str} has low precision in WMMA dot") else: - input_precision = "tf32" if in_dtype_str == 'float32' else "ieee" + input_precision = "tf32" if is_cuda() and in_dtype_str == 'float32' else "ieee" if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32": if triton.runtime.driver.active.utils.get_device_properties( @@ -5235,10 +5235,6 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s if in_type_str != 'float8e5': pytest.skip('test_fp8_dot_acc for HIP currently broken in upstream.') - ## TODO: Figure out why block size (128, 256, 128) fails on MI300 - if ("gfx94" in get_arch()) and BLOCK_M == 128: - pytest.skip('BLOCK size (128, 256, 128) fails on MI300') - check_type_supported(in_type_str, device) A = numpy_random((M, K), dtype_str=in_type_str) B = numpy_random((K, N), dtype_str=in_type_str) @@ -5469,7 +5465,7 @@ def maxnreg_noinline2(X): def test_maxnreg(device): assert not is_interpreter(), "this test won't work with the interpreter" - if is_hip(): + if not is_cuda(): pytest.skip('maxnreg only works on CUDA') # triton kernel @@ -5551,7 +5547,7 @@ def kernel(input): @pytest.mark.parametrize("dtype_str", ['float32', 'float64']) -def test_math_extern(dtype_str): +def test_math_extern(dtype_str, device): if is_interpreter(): pytest.skip('math_extern does not work in the interpreter mode') @@ -5575,8 +5571,8 @@ def kernel( x = numpy_random(shape, dtype_str=dtype_str, rs=rs) y_ref = np.tanh(x) - x_tri = to_triton(x, device='cuda') - y_tri = to_triton(numpy_random(shape, dtype_str=dtype_str, rs=rs), device='cuda') + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str=dtype_str, rs=rs), device=device) kernel[(1, )](x_tri, y_tri, shape[0], BLOCK_SIZE=shape[0]) # compare np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) diff --git a/python/test/unit/runtime/test_bindings.py b/python/test/unit/runtime/test_bindings.py index b0b99197b0a2..b3aebc9d8526 100644 --- a/python/test/unit/runtime/test_bindings.py +++ b/python/test/unit/runtime/test_bindings.py @@ -75,10 +75,11 @@ def walk_fn(op): backend = triton.compiler.compiler.make_backend(target) options = backend.parse_options(dict()) codegen_fns = dict() + module_map = backend.get_module_map() triton._C.libtriton.ir.load_dialects(context) backend.load_dialects(context) - ttir_module = src.make_ir(options, codegen_fns, context) + ttir_module = src.make_ir(options, codegen_fns, module_map, context) ttir_module.walk(walk_fn) diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py index 871bc6ba294b..8b793dd36095 100644 --- a/python/test/unit/test_perf_warning.py +++ b/python/test/unit/test_perf_warning.py @@ -45,3 +45,39 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, assert "remark: Warning: can't use MMA V3 for the dot op" in captured.err, "expect MMA V3 remark" assert "note: see current operation:" in captured.err os.environ['MLIR_ENABLE_REMARK'] = '0' + + +def test_remark_vectorization(capfd): + os.environ["MLIR_ENABLE_REMARK"] = "1" + + @triton.jit + def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + x0 = xindex % 9 + x2 = (xindex // 3456) % 512 + x1 = (xindex // 9) % 384 + x4 = xindex + tmp0 = tl.load(in_ptr0 + (x2 + (512 * x0)), None, eviction_policy="evict_last") + tmp1 = tmp0 + 520 + tmp2 = tmp0 < 0 + tmp3 = tl.where(tmp2, tmp1, tmp0) + tmp9 = (-4) + tmp3 + tmp12 = tl.full([1], 512, tl.int64) + tmp14 = tmp9 < tmp12 + tmp16 = tl.load(in_ptr3 + (x1), tmp14, eviction_policy="evict_last", other=0.0) + tmp18 = tmp16.to(tl.float32) + tmp19 = tmp18.to(tl.float32) + tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype) + tmp21 = tl.where(tmp14, tmp19, tmp20) + tmp22 = tmp21.to(tl.float32) + tl.store(out_ptr0 + (x4), tmp22, None) + + XBLOCK = 1024 + triton.compile( + triton.compiler.ASTSource(fn=ldst_vec, signature={0: '*i64', 1: '*i64', 2: '*fp16', 3: '*fp32', 4: '*fp16'}, + constants={"XBLOCK": XBLOCK}), options={"num_warps": 1}) + + _, err = capfd.readouterr() + assert ("remark: Warning: vectorization fails" in err), "expect vectorization failure remark" + os.environ["MLIR_ENABLE_REMARK"] = "0" diff --git a/python/triton/backends/compiler.py b/python/triton/backends/compiler.py index 990690045204..890c81461e6f 100644 --- a/python/triton/backends/compiler.py +++ b/python/triton/backends/compiler.py @@ -4,7 +4,8 @@ from abc import ABCMeta, abstractmethod, abstractclassmethod from dataclasses import dataclass -from typing import Union +from typing import Dict, Union +from types import ModuleType @dataclass(frozen=True) @@ -74,3 +75,10 @@ def load_dialects(self, context): Load additional MLIR dialects into the provided `context` """ raise NotImplementedError + + @abstractmethod + def get_module_map(self) -> Dict[str, ModuleType]: + """ + Return a map of interface modules to their device-specific implementations. + """ + raise NotImplementedError diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 96b7346ac554..ee3426bd4c65 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -188,8 +188,8 @@ def visit_Call(self, node: ast.Call) -> bool: class CodeGenerator(ast.NodeVisitor): def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, - codegen_fns, debug=None, module=None, is_kernel=False, function_types: Optional[Dict] = None, - noinline=False, file_name: Optional[str] = None, begin_line=0): + codegen_fns, module_map, debug=None, module=None, is_kernel=False, + function_types: Optional[Dict] = None, noinline=False, file_name: Optional[str] = None, begin_line=0): self.context = context self.builder = ir.builder(context) self.file_name = file_name @@ -201,10 +201,23 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n # Convert custom types not natively supported on HW. # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None) self.builder.codegen_fns = codegen_fns + self.builder.module_map = {} if module_map is None else module_map self.module = self.builder.create_module() if module is None else module self.function_ret_types = {} if function_types is None else function_types self.prototype = prototype - self.gscope = gscope + + self.gscope = {} + for k, v in gscope.items(): + if isinstance(v, ModuleType): + self.gscope[k] = module_map.get(v.__name__, v) + continue + + module_name = getattr(v, "__module__", "") + if module_name in module_map: + self.gscope[k] = getattr(module_map[module_name], k) + else: + self.gscope[k] = v + self.lscope = {} self.attributes = attributes self.constants = constants @@ -1054,7 +1067,8 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, noinline=fn.noinline, file_name=file_name, begin_line=begin_line, - options=self.builder.options, codegen_fns=self.builder.codegen_fns, debug=debug) + options=self.builder.options, codegen_fns=self.builder.codegen_fns, + module_map=self.builder.module_map, debug=debug) try: generator.visit(fn.parse()) except Exception as e: @@ -1257,7 +1271,7 @@ def kernel_suffix(signature, specialization): return suffix -def ast_to_ttir(fn, specialization, context, options, codegen_fns): +def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map): attrs = specialization.attrs # create kernel prototype cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i @@ -1277,7 +1291,7 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns): prototype = language.function_type([], arg_types) generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, jit_fn=fn, attributes=new_attrs, is_kernel=True, file_name=file_name, - begin_line=begin_line, options=options, codegen_fns=codegen_fns) + begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map) generator.visit(fn.parse()) ret = generator.module diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 78c79b4a5aeb..4156dbbd73fc 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -109,8 +109,9 @@ def hash(self): key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}" return hashlib.sha256(key.encode("utf-8")).hexdigest() - def make_ir(self, options, codegen_fns, context): - return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns) + def make_ir(self, options, codegen_fns, module_map, context): + return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns, + module_map=module_map) def parse_options(self): return dict() @@ -132,7 +133,7 @@ def __init__(self, path): def hash(self): return hashlib.sha256(self.src.encode("utf-8")).hexdigest() - def make_ir(self, options, codegen_fns, context): + def make_ir(self, options, codegen_fns, module_map, context): module = ir.parse_mlir_module(self.path, context) module.context = context return module @@ -277,8 +278,9 @@ def compile(src, target=None, options=None): ir.load_dialects(context) backend.load_dialects(context) codegen_fns = backend.get_codegen_implementation() + module_map = backend.get_module_map() try: - module = src.make_ir(options, codegen_fns, context) + module = src.make_ir(options, codegen_fns, module_map, context) except Exception as e: filter_traceback(e) raise @@ -286,9 +288,8 @@ def compile(src, target=None, options=None): for ext, compile_ir in list(stages.items())[first_stage:]: next_module = compile_ir(module, metadata) ir_filename = f"{file_name}.{ext}" - if (fn_override_manager is not None and fn_override_manager.has_file(ir_filename)): - print(f"\nOverriding kernel with file {ir_filename}") - full_name = fn_override_manager.get_file(ir_filename) + if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None): + print(f"\nOverriding kernel with file {full_name}") next_module = parse(full_name, ext, context) metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) if fn_dump_manager is not None: diff --git a/python/triton/language/extra/libdevice.py b/python/triton/language/extra/libdevice.py index 625cf3957e56..76627035de78 100644 --- a/python/triton/language/extra/libdevice.py +++ b/python/triton/language/extra/libdevice.py @@ -1,1213 +1,786 @@ -from .cuda import libdevice as cuda_libdevice -from .hip import libdevice as hip_libdevice -from triton.language import core -from functools import wraps -from typing import TypeVar - -T = TypeVar('T') - - -def dispatch(fn: T) -> T: - """Dispatch a function to a correct implementation.""" - assert callable(fn) - - @wraps(fn) - def wrapper(*args, **kwargs): - _backend = kwargs["_builder"].options.backend_name - if _backend == 'cuda': - _curr_libdevice_module = cuda_libdevice - elif _backend == 'hip': - _curr_libdevice_module = hip_libdevice - else: - raise RuntimeError('unknown backend') - - try: - _impl = getattr(_curr_libdevice_module, fn.__name__) - except AttributeError: - raise RuntimeError(f'`{_backend}` does not provide support for `{fn.__name__}` extra function') - - return _impl(*args, **kwargs) - - return wrapper - - -@core.extern -@dispatch -def clz(arg0, _builder=None): +def clz(arg0): ... -@core.extern -@dispatch -def popc(arg0, _builder=None): +def popc(arg0): ... -@core.extern -@dispatch -def byte_perm(arg0, arg1, arg2, _builder=None): +def byte_perm(arg0, arg1, arg2): ... -@core.extern -@dispatch -def mulhi(arg0, arg1, _builder=None): +def mulhi(arg0, arg1): ... -@core.extern -@dispatch -def mul24(arg0, arg1, _builder=None): +def mul24(arg0, arg1): ... -@core.extern -@dispatch -def brev(arg0, _builder=None): +def brev(arg0): ... -@core.extern -@dispatch -def sad(arg0, arg1, arg2, _builder=None): +def sad(arg0, arg1, arg2): ... -@core.extern -@dispatch -def abs(arg0, _builder=None): +def abs(arg0): ... -@core.extern -@dispatch -def floor(arg0, _builder=None): +def floor(arg0): ... -@core.extern -@dispatch -def rcp64h(arg0, _builder=None): +def rcp64h(arg0): ... -@core.extern -@dispatch -def rsqrt(arg0, _builder=None): +def rsqrt(arg0): ... -@core.extern -@dispatch -def ceil(arg0, _builder=None): +def ceil(arg0): ... -@core.extern -@dispatch -def trunc(arg0, _builder=None): +def trunc(arg0): ... -@core.extern -@dispatch -def exp2(arg0, _builder=None): +def exp2(arg0): ... -@core.extern -@dispatch -def saturatef(arg0, _builder=None): +def saturatef(arg0): ... -@core.extern -@dispatch -def fma_rn(arg0, arg1, arg2, _builder=None): +def fma_rn(arg0, arg1, arg2): ... -@core.extern -@dispatch -def fma_rz(arg0, arg1, arg2, _builder=None): +def fma_rz(arg0, arg1, arg2): ... -@core.extern -@dispatch -def fma_rd(arg0, arg1, arg2, _builder=None): +def fma_rd(arg0, arg1, arg2): ... -@core.extern -@dispatch -def fma_ru(arg0, arg1, arg2, _builder=None): +def fma_ru(arg0, arg1, arg2): ... -@core.extern -@dispatch -def fast_dividef(arg0, arg1, _builder=None): +def fast_dividef(arg0, arg1): ... -@core.extern -@dispatch -def div_rn(arg0, arg1, _builder=None): +def div_rn(arg0, arg1): ... -@core.extern -@dispatch -def div_rz(arg0, arg1, _builder=None): +def div_rz(arg0, arg1): ... -@core.extern -@dispatch -def div_rd(arg0, arg1, _builder=None): +def div_rd(arg0, arg1): ... -@core.extern -@dispatch -def div_ru(arg0, arg1, _builder=None): +def div_ru(arg0, arg1): ... -@core.extern -@dispatch -def rcp_rn(arg0, _builder=None): +def rcp_rn(arg0): ... -@core.extern -@dispatch -def rcp_rz(arg0, _builder=None): +def rcp_rz(arg0): ... -@core.extern -@dispatch -def rcp_rd(arg0, _builder=None): +def rcp_rd(arg0): ... -@core.extern -@dispatch -def rcp_ru(arg0, _builder=None): +def rcp_ru(arg0): ... -@core.extern -@dispatch -def sqrt_rn(arg0, _builder=None): +def sqrt_rn(arg0): ... -@core.extern -@dispatch -def sqrt_rz(arg0, _builder=None): +def sqrt_rz(arg0): ... -@core.extern -@dispatch -def sqrt_rd(arg0, _builder=None): +def sqrt_rd(arg0): ... -@core.extern -@dispatch -def sqrt_ru(arg0, _builder=None): +def sqrt_ru(arg0): ... -@core.extern -@dispatch -def sqrt(arg0, _builder=None): +def sqrt(arg0): ... -@core.extern -@dispatch -def add_rn(arg0, arg1, _builder=None): +def add_rn(arg0, arg1): ... -@core.extern -@dispatch -def add_rz(arg0, arg1, _builder=None): +def add_rz(arg0, arg1): ... -@core.extern -@dispatch -def add_rd(arg0, arg1, _builder=None): +def add_rd(arg0, arg1): ... -@core.extern -@dispatch -def add_ru(arg0, arg1, _builder=None): +def add_ru(arg0, arg1): ... -@core.extern -@dispatch -def mul_rn(arg0, arg1, _builder=None): +def mul_rn(arg0, arg1): ... -@core.extern -@dispatch -def mul_rz(arg0, arg1, _builder=None): +def mul_rz(arg0, arg1): ... -@core.extern -@dispatch -def mul_rd(arg0, arg1, _builder=None): +def mul_rd(arg0, arg1): ... -@core.extern -@dispatch -def mul_ru(arg0, arg1, _builder=None): +def mul_ru(arg0, arg1): ... -@core.extern -@dispatch -def double2float_rn(arg0, _builder=None): +def double2float_rn(arg0): ... -@core.extern -@dispatch -def double2float_rz(arg0, _builder=None): +def double2float_rz(arg0): ... -@core.extern -@dispatch -def double2float_rd(arg0, _builder=None): +def double2float_rd(arg0): ... -@core.extern -@dispatch -def double2float_ru(arg0, _builder=None): +def double2float_ru(arg0): ... -@core.extern -@dispatch -def double2int_rn(arg0, _builder=None): +def double2int_rn(arg0): ... -@core.extern -@dispatch -def double2int_rz(arg0, _builder=None): +def double2int_rz(arg0): ... -@core.extern -@dispatch -def double2int_rd(arg0, _builder=None): +def double2int_rd(arg0): ... -@core.extern -@dispatch -def double2int_ru(arg0, _builder=None): +def double2int_ru(arg0): ... -@core.extern -@dispatch -def double2uint_rn(arg0, _builder=None): +def double2uint_rn(arg0): ... -@core.extern -@dispatch -def double2uint_rz(arg0, _builder=None): +def double2uint_rz(arg0): ... -@core.extern -@dispatch -def double2uint_rd(arg0, _builder=None): +def double2uint_rd(arg0): ... -@core.extern -@dispatch -def double2uint_ru(arg0, _builder=None): +def double2uint_ru(arg0): ... -@core.extern -@dispatch -def int2double_rn(arg0, _builder=None): +def int2double_rn(arg0): ... -@core.extern -@dispatch -def uint2double_rn(arg0, _builder=None): +def uint2double_rn(arg0): ... -@core.extern -@dispatch -def float2int_rn(arg0, _builder=None): +def float2int_rn(arg0): ... -@core.extern -@dispatch -def float2int_rz(arg0, _builder=None): +def float2int_rz(arg0): ... -@core.extern -@dispatch -def float2int_rd(arg0, _builder=None): +def float2int_rd(arg0): ... -@core.extern -@dispatch -def float2int_ru(arg0, _builder=None): +def float2int_ru(arg0): ... -@core.extern -@dispatch -def float2uint_rn(arg0, _builder=None): +def float2uint_rn(arg0): ... -@core.extern -@dispatch -def float2uint_rz(arg0, _builder=None): +def float2uint_rz(arg0): ... -@core.extern -@dispatch -def float2uint_rd(arg0, _builder=None): +def float2uint_rd(arg0): ... -@core.extern -@dispatch -def float2uint_ru(arg0, _builder=None): +def float2uint_ru(arg0): ... -@core.extern -@dispatch -def int2float_rn(arg0, _builder=None): +def int2float_rn(arg0): ... -@core.extern -@dispatch -def int2float_rz(arg0, _builder=None): +def int2float_rz(arg0): ... -@core.extern -@dispatch -def int2float_rd(arg0, _builder=None): +def int2float_rd(arg0): ... -@core.extern -@dispatch -def int2float_ru(arg0, _builder=None): +def int2float_ru(arg0): ... -@core.extern -@dispatch -def uint2float_rn(arg0, _builder=None): +def uint2float_rn(arg0): ... -@core.extern -@dispatch -def uint2float_rz(arg0, _builder=None): +def uint2float_rz(arg0): ... -@core.extern -@dispatch -def uint2float_rd(arg0, _builder=None): +def uint2float_rd(arg0): ... -@core.extern -@dispatch -def uint2float_ru(arg0, _builder=None): +def uint2float_ru(arg0): ... -@core.extern -@dispatch -def hiloint2double(arg0, arg1, _builder=None): +def hiloint2double(arg0, arg1): ... -@core.extern -@dispatch -def double2loint(arg0, _builder=None): +def double2loint(arg0): ... -@core.extern -@dispatch -def double2hiint(arg0, _builder=None): +def double2hiint(arg0): ... -@core.extern -@dispatch -def float2ll_rn(arg0, _builder=None): +def float2ll_rn(arg0): ... -@core.extern -@dispatch -def float2ll_rz(arg0, _builder=None): +def float2ll_rz(arg0): ... -@core.extern -@dispatch -def float2ll_rd(arg0, _builder=None): +def float2ll_rd(arg0): ... -@core.extern -@dispatch -def float2ll_ru(arg0, _builder=None): +def float2ll_ru(arg0): ... -@core.extern -@dispatch -def float2ull_rn(arg0, _builder=None): +def float2ull_rn(arg0): ... -@core.extern -@dispatch -def float2ull_rz(arg0, _builder=None): +def float2ull_rz(arg0): ... -@core.extern -@dispatch -def float2ull_rd(arg0, _builder=None): +def float2ull_rd(arg0): ... -@core.extern -@dispatch -def float2ull_ru(arg0, _builder=None): +def float2ull_ru(arg0): ... -@core.extern -@dispatch -def double2ll_rn(arg0, _builder=None): +def double2ll_rn(arg0): ... -@core.extern -@dispatch -def double2ll_rz(arg0, _builder=None): +def double2ll_rz(arg0): ... -@core.extern -@dispatch -def double2ll_rd(arg0, _builder=None): +def double2ll_rd(arg0): ... -@core.extern -@dispatch -def double2ll_ru(arg0, _builder=None): +def double2ll_ru(arg0): ... -@core.extern -@dispatch -def double2ull_rn(arg0, _builder=None): +def double2ull_rn(arg0): ... -@core.extern -@dispatch -def double2ull_rz(arg0, _builder=None): +def double2ull_rz(arg0): ... -@core.extern -@dispatch -def double2ull_rd(arg0, _builder=None): +def double2ull_rd(arg0): ... -@core.extern -@dispatch -def double2ull_ru(arg0, _builder=None): +def double2ull_ru(arg0): ... -@core.extern -@dispatch -def ll2float_rn(arg0, _builder=None): +def ll2float_rn(arg0): ... -@core.extern -@dispatch -def ll2float_rz(arg0, _builder=None): +def ll2float_rz(arg0): ... -@core.extern -@dispatch -def ll2float_rd(arg0, _builder=None): +def ll2float_rd(arg0): ... -@core.extern -@dispatch -def ll2float_ru(arg0, _builder=None): +def ll2float_ru(arg0): ... -@core.extern -@dispatch -def ull2float_rn(arg0, _builder=None): +def ull2float_rn(arg0): ... -@core.extern -@dispatch -def ull2float_rz(arg0, _builder=None): +def ull2float_rz(arg0): ... -@core.extern -@dispatch -def ull2float_rd(arg0, _builder=None): +def ull2float_rd(arg0): ... -@core.extern -@dispatch -def ull2float_ru(arg0, _builder=None): +def ull2float_ru(arg0): ... -@core.extern -@dispatch -def ll2double_rn(arg0, _builder=None): +def ll2double_rn(arg0): ... -@core.extern -@dispatch -def ll2double_rz(arg0, _builder=None): +def ll2double_rz(arg0): ... -@core.extern -@dispatch -def ll2double_rd(arg0, _builder=None): +def ll2double_rd(arg0): ... -@core.extern -@dispatch -def ll2double_ru(arg0, _builder=None): +def ll2double_ru(arg0): ... -@core.extern -@dispatch -def ull2double_rn(arg0, _builder=None): +def ull2double_rn(arg0): ... -@core.extern -@dispatch -def ull2double_rz(arg0, _builder=None): +def ull2double_rz(arg0): ... -@core.extern -@dispatch -def ull2double_rd(arg0, _builder=None): +def ull2double_rd(arg0): ... -@core.extern -@dispatch -def ull2double_ru(arg0, _builder=None): +def ull2double_ru(arg0): ... -@core.extern -@dispatch -def int_as_float(arg0, _builder=None): +def int_as_float(arg0): ... -@core.extern -@dispatch -def float_as_int(arg0, _builder=None): +def float_as_int(arg0): ... -@core.extern -@dispatch -def uint_as_float(arg0, _builder=None): +def uint_as_float(arg0): ... -@core.extern -@dispatch -def float_as_uint(arg0, _builder=None): +def float_as_uint(arg0): ... -@core.extern -@dispatch -def longlong_as_double(arg0, _builder=None): +def longlong_as_double(arg0): ... -@core.extern -@dispatch -def double_as_longlong(arg0, _builder=None): +def double_as_longlong(arg0): ... -@core.extern -@dispatch -def fast_sinf(arg0, _builder=None): +def fast_sinf(arg0): ... -@core.extern -@dispatch -def fast_cosf(arg0, _builder=None): +def fast_cosf(arg0): ... -@core.extern -@dispatch -def fast_log2f(arg0, _builder=None): +def fast_log2f(arg0): ... -@core.extern -@dispatch -def fast_logf(arg0, _builder=None): +def fast_logf(arg0): ... -@core.extern -@dispatch -def fast_expf(arg0, _builder=None): +def fast_expf(arg0): ... -@core.extern -@dispatch -def fast_tanf(arg0, _builder=None): +def fast_tanf(arg0): ... -@core.extern -@dispatch -def fast_exp10f(arg0, _builder=None): +def fast_exp10f(arg0): ... -@core.extern -@dispatch -def fast_log10f(arg0, _builder=None): +def fast_log10f(arg0): ... -@core.extern -@dispatch -def fast_powf(arg0, arg1, _builder=None): +def fast_powf(arg0, arg1): ... -@core.extern -@dispatch -def hadd(arg0, arg1, _builder=None): +def hadd(arg0, arg1): ... -@core.extern -@dispatch -def rhadd(arg0, arg1, _builder=None): +def rhadd(arg0, arg1): ... -@core.extern -@dispatch -def sub_rn(arg0, arg1, _builder=None): +def sub_rn(arg0, arg1): ... -@core.extern -@dispatch -def sub_rz(arg0, arg1, _builder=None): +def sub_rz(arg0, arg1): ... -@core.extern -@dispatch -def sub_rd(arg0, arg1, _builder=None): +def sub_rd(arg0, arg1): ... -@core.extern -@dispatch -def sub_ru(arg0, arg1, _builder=None): +def sub_ru(arg0, arg1): ... -@core.extern -@dispatch -def rsqrt_rn(arg0, _builder=None): +def rsqrt_rn(arg0): ... -@core.extern -@dispatch -def ffs(arg0, _builder=None): +def ffs(arg0): ... -@core.extern -@dispatch -def rint(arg0, _builder=None): +def rint(arg0): ... -@core.extern -@dispatch -def llrint(arg0, _builder=None): +def llrint(arg0): ... -@core.extern -@dispatch -def nearbyint(arg0, _builder=None): +def nearbyint(arg0): ... -@core.extern -@dispatch -def isnan(arg0, _builder=None): +def isnan(arg0): ... -@core.extern -@dispatch -def signbit(arg0, _builder=None): +def signbit(arg0): ... -@core.extern -@dispatch -def copysign(arg0, arg1, _builder=None): +def copysign(arg0, arg1): ... -@core.extern -@dispatch -def finitef(arg0, _builder=None): +def finitef(arg0): ... -@core.extern -@dispatch -def isinf(arg0, _builder=None): +def isinf(arg0): ... -@core.extern -@dispatch -def nextafter(arg0, arg1, _builder=None): +def nextafter(arg0, arg1): ... -@core.extern -@dispatch -def sin(arg0, _builder=None): +def sin(arg0): ... -@core.extern -@dispatch -def cos(arg0, _builder=None): +def cos(arg0): ... -@core.extern -@dispatch -def sinpi(arg0, _builder=None): +def sinpi(arg0): ... -@core.extern -@dispatch -def cospi(arg0, _builder=None): +def cospi(arg0): ... -@core.extern -@dispatch -def tan(arg0, _builder=None): +def tan(arg0): ... -@core.extern -@dispatch -def log2(arg0, _builder=None): +def log2(arg0): ... -@core.extern -@dispatch -def exp(arg0, _builder=None): +def exp(arg0): ... -@core.extern -@dispatch -def exp10(arg0, _builder=None): +def exp10(arg0): ... -@core.extern -@dispatch -def cosh(arg0, _builder=None): +def cosh(arg0): ... -@core.extern -@dispatch -def sinh(arg0, _builder=None): +def sinh(arg0): ... -@core.extern -@dispatch -def tanh(arg0, _builder=None): +def tanh(arg0): ... -@core.extern -@dispatch -def atan2(arg0, arg1, _builder=None): +def atan2(arg0, arg1): ... -@core.extern -@dispatch -def atan(arg0, _builder=None): +def atan(arg0): ... -@core.extern -@dispatch -def asin(arg0, _builder=None): +def asin(arg0): ... -@core.extern -@dispatch -def acos(arg0, _builder=None): +def acos(arg0): ... -@core.extern -@dispatch -def log(arg0, _builder=None): +def log(arg0): ... -@core.extern -@dispatch -def log10(arg0, _builder=None): +def log10(arg0): ... -@core.extern -@dispatch -def log1p(arg0, _builder=None): +def log1p(arg0): ... -@core.extern -@dispatch -def acosh(arg0, _builder=None): +def acosh(arg0): ... -@core.extern -@dispatch -def asinh(arg0, _builder=None): +def asinh(arg0): ... -@core.extern -@dispatch -def atanh(arg0, _builder=None): +def atanh(arg0): ... -@core.extern -@dispatch -def expm1(arg0, _builder=None): +def expm1(arg0): ... -@core.extern -@dispatch -def hypot(arg0, arg1, _builder=None): +def hypot(arg0, arg1): ... -@core.extern -@dispatch -def rhypot(arg0, arg1, _builder=None): +def rhypot(arg0, arg1): ... -@core.extern -@dispatch -def norm3d(arg0, arg1, arg2, _builder=None): +def norm3d(arg0, arg1, arg2): ... -@core.extern -@dispatch -def rnorm3d(arg0, arg1, arg2, _builder=None): +def rnorm3d(arg0, arg1, arg2): ... -@core.extern -@dispatch -def norm4d(arg0, arg1, arg2, arg3, _builder=None): +def norm4d(arg0, arg1, arg2, arg3): ... -@core.extern -@dispatch -def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): +def rnorm4d(arg0, arg1, arg2, arg3): ... -@core.extern -@dispatch -def cbrt(arg0, _builder=None): +def cbrt(arg0): ... -@core.extern -@dispatch -def rcbrt(arg0, _builder=None): +def rcbrt(arg0): ... -@core.extern -@dispatch -def j0(arg0, _builder=None): +def j0(arg0): ... -@core.extern -@dispatch -def j1(arg0, _builder=None): +def j1(arg0): ... -@core.extern -@dispatch -def y0(arg0, _builder=None): +def y0(arg0): ... -@core.extern -@dispatch -def y1(arg0, _builder=None): +def y1(arg0): ... -@core.extern -@dispatch -def yn(arg0, arg1, _builder=None): +def yn(arg0, arg1): ... -@core.extern -@dispatch -def jn(arg0, arg1, _builder=None): +def jn(arg0, arg1): ... -@core.extern -@dispatch -def cyl_bessel_i0(arg0, _builder=None): +def cyl_bessel_i0(arg0): ... -@core.extern -@dispatch -def cyl_bessel_i1(arg0, _builder=None): +def cyl_bessel_i1(arg0): ... -@core.extern -@dispatch -def erf(arg0, _builder=None): +def erf(arg0): ... -@core.extern -@dispatch -def erfinv(arg0, _builder=None): +def erfinv(arg0): ... -@core.extern -@dispatch -def erfc(arg0, _builder=None): +def erfc(arg0): ... -@core.extern -@dispatch -def erfcx(arg0, _builder=None): +def erfcx(arg0): ... -@core.extern -@dispatch -def erfcinv(arg0, _builder=None): +def erfcinv(arg0): ... -@core.extern -@dispatch -def normcdfinv(arg0, _builder=None): +def normcdfinv(arg0): ... -@core.extern -@dispatch -def normcdf(arg0, _builder=None): +def normcdf(arg0): ... -@core.extern -@dispatch -def lgamma(arg0, _builder=None): +def lgamma(arg0): ... -@core.extern -@dispatch -def ldexp(arg0, arg1, _builder=None): +def ldexp(arg0, arg1): ... -@core.extern -@dispatch -def scalbn(arg0, arg1, _builder=None): +def scalbn(arg0, arg1): ... -@core.extern -@dispatch -def fmod(arg0, arg1, _builder=None): +def fmod(arg0, arg1): ... -@core.extern -@dispatch -def remainder(arg0, arg1, _builder=None): +def remainder(arg0, arg1): ... -@core.extern -@dispatch -def fma(arg0, arg1, arg2, _builder=None): +def fma(arg0, arg1, arg2): ... -@core.extern -@dispatch -def pow(arg0, arg1, _builder=None): +def pow(arg0, arg1): ... -@core.extern -@dispatch -def tgamma(arg0, _builder=None): +def tgamma(arg0): ... -@core.extern -@dispatch -def round(arg0, _builder=None): +def round(arg0): ... -@core.extern -@dispatch -def llround(arg0, _builder=None): +def llround(arg0): ... -@core.extern -@dispatch -def fdim(arg0, arg1, _builder=None): +def fdim(arg0, arg1): ... -@core.extern -@dispatch -def ilogb(arg0, _builder=None): +def ilogb(arg0): ... -@core.extern -@dispatch -def logb(arg0, _builder=None): +def logb(arg0): ... -@core.extern -@dispatch -def isfinited(arg0, _builder=None): +def isfinited(arg0): ... diff --git a/python/triton/language/standard.py b/python/triton/language/standard.py index 6da420896787..4b2f9141fd2f 100644 --- a/python/triton/language/standard.py +++ b/python/triton/language/standard.py @@ -433,7 +433,7 @@ def flip(x, dim=None): def interleave(a, b): """ Interleaves the values of two tensors along their last dimension. The two tensors must have the same shape. - Equivalent to `tl.join(a, b).reshape(a.shape[-1:] + [2 * a.shape[-1]])` + Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])` :param a: The first input tensor. :type a: Tensor @@ -442,7 +442,6 @@ def interleave(a, b): """ c = core.join(a, b) - assert isinstance(c.shape, list) if len(c.shape) == 1: # We must have interleaved two scalars. return c diff --git a/python/triton/runtime/cache.py b/python/triton/runtime/cache.py index d5e777ee9a6c..82b2fea37e9b 100644 --- a/python/triton/runtime/cache.py +++ b/python/triton/runtime/cache.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from pathlib import Path from typing import Dict, List, Optional +import base64 import hashlib @@ -255,6 +256,11 @@ def put_group(self, filename: str, group: Dict[str, str]): __cache_cls_nme = "DEFAULT" +def _base64(key): + # Assume key is a hex string. + return base64.urlsafe_b64encode(bytes.fromhex(key)).decode("utf-8").rstrip("=") + + def get_cache_manager(key) -> CacheManager: import os @@ -268,15 +274,15 @@ def get_cache_manager(key) -> CacheManager: __cache_cls = getattr(module, clz_nme) __cache_cls_nme = user_cache_manager - return __cache_cls(key) + return __cache_cls(_base64(key)) def get_override_manager(key) -> CacheManager: - return __cache_cls(key, override=True) + return __cache_cls(_base64(key), override=True) def get_dump_manager(key) -> CacheManager: - return __cache_cls(key, dump=True) + return __cache_cls(_base64(key), dump=True) def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): @@ -286,4 +292,4 @@ def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): for kw in kwargs: key = f"{key}-{kwargs.get(kw)}" key = hashlib.sha256(key.encode("utf-8")).hexdigest() - return key + return _base64(key) diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir index 4161c08ad573..d2664b959deb 100644 --- a/test/Analysis/test-alignment.mlir +++ b/test/Analysis/test-alignment.mlir @@ -863,3 +863,14 @@ tt.func public @chained_for(%8: tensor<128x64x!tt.ptr> {tt.divisibility = } tt.return } + +// ----- + +// CHECK-LABEL: @int_min_does_not_underflow_in_analysis +module { + tt.func @int_min_does_not_underflow_in_analysis() -> i64 { + // CHECK: divisibility = [4611686018427387904] + %int_min = arith.constant -9223372036854775808 : i64 + tt.return %int_min : i64 + } +} diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index eed4269fa5e2..afba22f6a29d 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -17,6 +17,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: basic_load tt.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm + // CHECK-SAME: mov.u32 $0, $1; + // CHECK-SAME: @$3 ld.global.b32 { $0 }, [ $2 + 0 ];", "=r,r,l,b" // CHECK: llvm.inline_asm %1 = tt.load %a_ptr_init, %cst, %cst_0 : tensor<256x!tt.ptr, #blocked0> tt.return @@ -707,39 +709,39 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: convert_layout_blocked_blocked tt.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) { // CHECK: llvm.mlir.addressof @global_smem - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> + // CHECK: llvm.inline_asm + // CHECK: st.shared + // CHECK: llvm.inline_asm + // CHECK: st.shared + // CHECK: llvm.inline_asm + // CHECK: st.shared + // CHECK: llvm.inline_asm + // CHECK: st.shared + // CHECK: llvm.inline_asm + // CHECK: st.shared + // CHECK: llvm.inline_asm + // CHECK: st.shared + // CHECK: llvm.inline_asm + // CHECK: st.shared + // CHECK: llvm.inline_asm + // CHECK: st.shared // CHECK: nvvm.barrier0 - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> + // CHECK: llvm.inline_asm + // CHECK: ld.shared + // CHECK: llvm.inline_asm + // CHECK: ld.shared + // CHECK: llvm.inline_asm + // CHECK: ld.shared + // CHECK: llvm.inline_asm + // CHECK: ld.shared + // CHECK: llvm.inline_asm + // CHECK: ld.shared + // CHECK: llvm.inline_asm + // CHECK: ld.shared + // CHECK: llvm.inline_asm + // CHECK: ld.shared + // CHECK: llvm.inline_asm + // CHECK: ld.shared %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> tt.return } @@ -754,15 +756,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: convert_layout_blocked_blocked_vec tt.func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) { // CHECK: llvm.mlir.addressof @global_smem - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> + // CHECK: llvm.inline_asm + // CHECK: st.shared + // CHECK: llvm.inline_asm + // CHECK: st.shared // CHECK: nvvm.barrier0 - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> + // CHECK: llvm.inline_asm + // CHECK: ld.shared + // CHECK: llvm.inline_asm + // CHECK: ld.shared %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> tt.return } @@ -777,21 +779,21 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: convert_layout_blocked_blocked_multi_rep tt.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) { // CHECK: llvm.mlir.addressof @global_smem - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> + // CHECK: llvm.inline_asm + // CHECK: st.shared // CHECK: nvvm.barrier0 - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> + // CHECK: llvm.inline_asm + // CHECK: ld.shared + // CHECK: llvm.inline_asm + // CHECK: ld.shared // CHECK: nvvm.barrier0 - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> + // CHECK: llvm.inline_asm + // CHECK: st.shared // CHECK: nvvm.barrier0 - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> + // CHECK: llvm.inline_asm + // CHECK: ld.shared + // CHECK: llvm.inline_asm + // CHECK: ld.shared %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> tt.return } diff --git a/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir new file mode 100644 index 000000000000..7854a4eed7a5 --- /dev/null +++ b/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir @@ -0,0 +1,20 @@ +// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx940 matrix-instruction-size=0' | FileCheck %s + +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}> +// CHECK-LABEL: mfma_dot_fp8e5m2 +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @mfma_dot_fp8e5m2( + %arg0: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<64x256xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<128x256x!tt.ptr, #blocked> ) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked> + // CHECK: %[[A0:.+]] = triton_gpu.convert_layout %arg0 : {{.*}} -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + // CHECK: %[[A1:.+]] = tt.fp_to_fp %[[A0]] : {{.*}} -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + // CHECK: %[[B0:.+]] = triton_gpu.convert_layout %arg1 : {{.*}} -> tensor<64x256xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + // CHECK: %[[B1:.+]] = tt.fp_to_fp %[[B0]] : tensor<64x256xf8E5M2, {{.*}} -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + // CHECK: tt.dot %[[A1]], %[[B1]] + %1 = tt.dot %arg0, %arg1, %cst : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + tt.store %arg2, %1 : tensor<128x256x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index bf966ae2af74..83b3fed52ac8 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -1,7 +1,8 @@ from triton.backends.compiler import BaseBackend, GPUTarget from triton._C.libtriton import ir, passes, llvm, amd from dataclasses import dataclass -from typing import Any, Tuple +from typing import Any, Dict, Tuple +from types import ModuleType import hashlib import tempfile import os @@ -95,6 +96,10 @@ def get_codegen_implementation(self): codegen_fns = {"min_dot_size": min_dot_size(self.target)} return codegen_fns + def get_module_map(self) -> Dict[str, ModuleType]: + from triton.language.extra.hip import libdevice + return {"triton.language.extra.libdevice": libdevice} + def load_dialects(self, ctx): amd.load_dialects(ctx) diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index 3825273265f0..c1ff6e1d65f3 100644 --- a/third_party/amd/backend/driver.py +++ b/third_party/amd/backend/driver.py @@ -13,6 +13,54 @@ include_dir = [os.path.join(dirname, "include")] +def _find_already_mmapped_dylib_on_linux(lib_name): + import platform + if platform.system() != 'Linux': + return None + + # Use dl_iterate_phdr to walk through the list of shared libraries at runtime. + # See https://www.man7.org/linux/man-pages/man3/dl_iterate_phdr.3.html for details. + + import ctypes + from ctypes import c_char, c_int, c_size_t, c_void_p, c_char_p, POINTER + + class DlPhdrInfo(ctypes.Structure): + _fields_ = [ + ('dlpi_addr', c_void_p), + ('dlpi_name', c_char_p), + # We don't care about the remaining fields. + ] + + # callback_t must use POINTER(c_char) to avoid copying. + callback_t = ctypes.CFUNCTYPE(c_int, POINTER(DlPhdrInfo), POINTER(c_size_t), POINTER(c_char)) + + # Load libc and get the dl_iterate_phdr symbol. + try: + dl_iterate_phdr = ctypes.CDLL('libc.so.6').dl_iterate_phdr + except: + return None + # argtypes must use c_char_p to accept create_string_buffer. + dl_iterate_phdr.argtypes = [callback_t, c_char_p] + dl_iterate_phdr.restype = c_int + + max_path_length = 4096 + path = ctypes.create_string_buffer(max_path_length + 1) + + # Define callback to get the loaded dylib path. + def callback(info, size, data): + dlpi_name = info.contents.dlpi_name + p = Path(os.fsdecode(dlpi_name)) + if lib_name in p.name: + # Found the dylib; get its path. + ctypes.memmove(data, dlpi_name, min(max_path_length, len(dlpi_name))) + return 1 + return 0 + + if dl_iterate_phdr(callback_t(callback), path): + return os.fsdecode(ctypes.string_at(path)) + return None + + @functools.lru_cache() def _get_path_to_hip_runtime_dylib(): lib_name = "libamdhip64.so" @@ -24,6 +72,13 @@ def _get_path_to_hip_runtime_dylib(): return env_libhip_path raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}") + # If the shared object is already mmapped to address space, use it. + mmapped_path = _find_already_mmapped_dylib_on_linux(lib_name) + if mmapped_path: + if os.path.exists(mmapped_path): + return mmapped_path + raise RuntimeError(f"memory mapped '{mmapped_path}' in process does not point to a valid {lib_name}") + paths = [] import site diff --git a/third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h b/third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h index 4127d85dcd68..121bb617265f 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h @@ -86,6 +86,8 @@ class MfmaInsn { unsigned getNDim(); StringRef getInsnName(); unsigned getKBase(); + Type getElementTypeA(); + Type getElementTypeB(); }; } // namespace mlir diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp index 79fa319ba978..23dd5d37d520 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -77,8 +77,6 @@ createNewConvertOps(ModuleOp &mod, OpBuilder &builder, srcType.getElementType(), newMfmaEnc); } else if (auto srcWmma = dyn_cast( srcType.getEncoding())) { - // TODO: support 2nd gen of WMMA - assert(srcWmma.getVersion() == 1); auto newWmmaEnc = triton::gpu::AMDWmmaEncodingAttr::get( mod.getContext(), srcWmma.getVersion(), {warpsPerCtaX, warpsPerCtaY}, srcWmma.getCTALayout()); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index f66cb449133b..bf976a8138dc 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -304,10 +304,8 @@ class BlockedToMFMA : public RewritePattern { /// @brief Choose MFMA instruction parameters /// @param dot target dot operation - /// @return pair {mDim, nDim, kDim, kBase} sizes of one MFMA instruction - /// arguments - std::tuple - chooseMfmaDimensions(tt::DotOp dot) const { + /// @return MfmaInsn or failure + FailureOr chooseMfmaInstruction(tt::DotOp dot) const { // number of matrix elements along k dim per one MFMA intruction unsigned kDim = 0; auto opType = cast(dot.getA().getType()); @@ -359,13 +357,10 @@ class BlockedToMFMA : public RewritePattern { llvm::report_fatal_error("No match found in MFMA database\n"); kDim = maybeMfmaInsn->getKDim(); - unsigned kBase = maybeMfmaInsn->getKBase(); - assert(kDim != 0); - assert(M % mDim == 0 && N % nDim == 0); assert(opType.getShape()[rank - 1] % kDim == 0); - return {mDim, nDim, kDim, kBase}; + return maybeMfmaInsn; } LogicalResult matchAndRewrite(Operation *op, @@ -396,7 +391,11 @@ class BlockedToMFMA : public RewritePattern { ttg::AMDMfmaEncodingAttr mfmaEnc; - auto [mDim, nDim, kDim, kBase] = chooseMfmaDimensions(dotOp); + auto mfmaInstr = chooseMfmaInstruction(dotOp); + auto mDim = mfmaInstr.value().getMDim(); + auto nDim = mfmaInstr.value().getNDim(); + auto kDim = mfmaInstr.value().getKDim(); + auto kBase = mfmaInstr.value().getKBase(); auto warpsPerTile = warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim}); @@ -457,14 +456,14 @@ class BlockedToMFMA : public RewritePattern { if (!isSecondDot(dotOp)) kWidth *= kPack; - auto newAType = RankedTensorType::get( - oldAType.getShape(), oldAType.getElementType(), - ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth)); - auto newBType = RankedTensorType::get( - oldBType.getShape(), oldBType.getElementType(), - ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth)); - a = rewriter.create(a.getLoc(), newAType, a); - b = rewriter.create(b.getLoc(), newBType, b); + auto newAEncoding = + ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth); + auto newBEncoding = + ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth); + a = convertAndCastTensor(rewriter, a, newAEncoding, + mfmaInstr.value().getElementTypeA()); + b = convertAndCastTensor(rewriter, b, newBEncoding, + mfmaInstr.value().getElementTypeB()); auto newDot = rewriter.create( dotOp.getLoc(), newAcc.getType(), a, b, newAcc, dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc()); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp index 9207d155880a..d3b2b70f858c 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp @@ -2,7 +2,8 @@ namespace mlir { -static MfmaTypeId convertTypesToId(mlir::Type dataTypeA, mlir::Type dataTypeB) { +static MfmaTypeId chooseAppropriateMfmaId(mlir::Type dataTypeA, + mlir::Type dataTypeB) { if (dataTypeA.isF32() && dataTypeB.isF32()) { return MfmaTypeId::Fp32TyId; } @@ -27,6 +28,9 @@ static MfmaTypeId convertTypesToId(mlir::Type dataTypeA, mlir::Type dataTypeB) { if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) { return MfmaTypeId::Bf8Bf8TyId; } + if (dataTypeA.isFloat8E5M2() && dataTypeB.isFloat8E5M2()) { + return MfmaTypeId::Fp16TyId; + } llvm_unreachable("Unsupported input argument type."); } @@ -205,16 +209,48 @@ auto getMfmaInsnGroupAttrMap = []() -> const MfmaInsnGroupMap & { return MfmaInsnMap; }; +std::pair TypesFromMfmaId(mlir::MLIRContext *ctx, + MfmaTypeId id) { + auto f8e5m2 = Float8E5M2Type::get(ctx); + auto f8e4m3fnuz = Float8E4M3FNUZType::get(ctx); + auto f8e5m2fnuz = Float8E5M2FNUZType::get(ctx); + auto f16 = Float16Type::get(ctx); + auto bf16 = BFloat16Type::get(ctx); + auto f32 = Float32Type::get(ctx); + auto i8 = IntegerType::get(ctx, 8, IntegerType::Signed); + switch (id) { + case MfmaTypeId::Fp32TyId: + return {f32, f32}; + case MfmaTypeId::Fp16TyId: + return {f16, f16}; + case MfmaTypeId::Bf16TyId: + return {bf16, bf16}; + case MfmaTypeId::I8TyId: + return {i8, i8}; + case MfmaTypeId::Fp8Fp8TyId: + return {f8e4m3fnuz, f8e4m3fnuz}; + case MfmaTypeId::Fp8Bf8TyId: + return {f8e4m3fnuz, f8e5m2fnuz}; + case MfmaTypeId::Bf8Fp8TyId: + return {f8e5m2fnuz, f8e4m3fnuz}; + case MfmaTypeId::Bf8Bf8TyId: + return {f8e5m2fnuz, f8e5m2fnuz}; + } + assert(false && "unsupported MfmaTypeId"); +} + FailureOr MfmaInsn::selectMfma(unsigned mDim, unsigned nDim, Type elementTypeA, Type elementTypeB, int mfmaVersion) { auto mfmaInsnAttrMap = getMfmaInsnGroupAttrMap(); - MfmaInsnGroupSelectKey key = { - mDim, nDim, convertTypesToId(elementTypeA, elementTypeB), mfmaVersion}; + MfmaTypeId mfmaId = chooseAppropriateMfmaId(elementTypeA, elementTypeB); + MfmaInsnGroupSelectKey key = {mDim, nDim, mfmaId, mfmaVersion}; auto it = mfmaInsnAttrMap.find(key); if (it == mfmaInsnAttrMap.end()) return failure(); - return MfmaInsn(elementTypeA, elementTypeB, (*it).second); + auto [instrElementTypeA, instrElementTypeB] = + TypesFromMfmaId(elementTypeA.getContext(), mfmaId); + return MfmaInsn(instrElementTypeA, instrElementTypeB, it->second); } MfmaInsn::MfmaInsn(Type elementTypeA, Type elementTypeB, @@ -226,4 +262,6 @@ unsigned MfmaInsn::getMDim() { return attr.m; } unsigned MfmaInsn::getNDim() { return attr.n; } StringRef MfmaInsn::getInsnName() { return attr.insn; } unsigned MfmaInsn::getKBase() { return attr.kBase; } +Type MfmaInsn::getElementTypeA() { return elementTypeA; } +Type MfmaInsn::getElementTypeB() { return elementTypeB; } } // namespace mlir diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 5dd75e530fec..ea1d79f9ba93 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -3,7 +3,8 @@ from dataclasses import dataclass import functools -from typing import Any, Tuple, Optional +from typing import Any, Dict, Tuple, Optional +from types import ModuleType import hashlib import re import tempfile @@ -48,12 +49,15 @@ def ptx_get_version(cuda_version) -> int: assert isinstance(cuda_version, str) major, minor = map(int, cuda_version.split('.')) if major == 12: - return 80 + minor + if minor < 6: + return 80 + minor + elif minor == 6: + return 85 if major == 11: return 70 + minor if major == 10: return 63 + minor - raise RuntimeError("Triton only support CUDA 10.0 or higher") + raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version) @functools.lru_cache() @@ -155,6 +159,10 @@ def get_codegen_implementation(self): } return codegen_fns + def get_module_map(self) -> Dict[str, ModuleType]: + from triton.language.extra.cuda import libdevice + return {"triton.language.extra.libdevice": libdevice} + def load_dialects(self, ctx): nvidia.load_dialects(ctx) @@ -229,6 +237,11 @@ def make_llir(src, metadata, options, capability): # TritonGPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() + # Set up Diagnostic + if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1": + srcMgr = llvm.source_mgr() + diag = ir.source_mgr_diag(srcMgr, mod.context) + mod.context.printOpOnDiagnostic(True) nvidia.passes.ttgpuir.add_decompose_unsupported_conversions(pm) passes.ttgpuir.add_combine_tensor_select_and_if(pm) passes.convert.add_scf_to_cf(pm) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp index 2fabb598e99d..91e9a4bbf888 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -234,6 +234,21 @@ static const std::string S8_to_Bf16 = "prmt.b32 $0, f0, f1, 0x7632; \n" // f32->bf16 + pack "prmt.b32 $1, f2, f3, 0x7632; \n" // "}"; +// Conversions have low throughput, rely on bit tricks instead of cvt +// instruction on Hopper and later GPUs. +static const std::string S8_to_Bf16_sm90 = + "{ \n" + ".reg .b32 l<3>; \n" + ".reg .b32 h<3>; \n" + "prmt.b32 l0, $2, 0x43, 0x4140; \n" // Unpack to shifted bf16. + "prmt.b32 h0, $2, 0x43, 0x4342; \n" + "and.b32 l1, l0, 0xff7fff7f; \n" // Zero the least exp bit. + "and.b32 h1, h0, 0xff7fff7f; \n" + "and.b32 l2, l0, 0xff80ff80; \n" // Zero the mantissa. + "and.b32 h2, h0, 0xff80ff80; \n" + "sub.bf16x2 $0, l1, l2; \n" // Subtract the offset. + "sub.bf16x2 $1, h1, h2; \n" + "}"; typedef std::function(Location, ConversionPatternRewriter &, const SmallVector &)> @@ -646,9 +661,15 @@ struct FSubOpConversion struct SIToFPOpConversion : ElementwiseOpConversionBase { using Base = ElementwiseOpConversionBase; - using Base::Base; using Adaptor = typename Base::OpAdaptor; + explicit SIToFPOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + int computeCapability, + PatternBenefit benefit = patternBenefitDefault) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + computeCapability(computeCapability) {} + SmallVector createDestOps(arith::SIToFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, @@ -657,7 +678,8 @@ struct SIToFPOpConversion Type outElemTy = getElementType(op.getOut()); if (outElemTy.isBF16() && inElemTy.isInteger(8) && operands.size() >= 4) { auto cvtFunc = makeConverterFromPtx( - S8_to_Bf16, getTypeConverter()->convertType(inElemTy), + computeCapability >= 90 ? S8_to_Bf16_sm90 : S8_to_Bf16, + getTypeConverter()->convertType(inElemTy), getTypeConverter()->convertType(outElemTy)); SmallVector inVals = {operands[0][0], operands[1][0], operands[2][0], operands[3][0]}; @@ -668,6 +690,9 @@ struct SIToFPOpConversion return {rewriter.create(loc, elemTy, operands[0][0])}; } } + +private: + int computeCapability; }; struct FPToSIOpConversion @@ -920,8 +945,9 @@ void mlir::triton::NVIDIA::populateElementwiseOpToLLVMPatterns( patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); - patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, + computeCapability, benefit); patterns.add(typeConverter, axisInfoAnalysis, computeCapability, benefit); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 9ee532992d01..27d7dd69b8c0 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -166,6 +166,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, typeConverter->convertType(getElementTypeOrSelf(op.getType())); unsigned vec = getVectorSize(ptr); unsigned numElems = getTotalElemsPerThread(ptr.getType()); + unsigned vecOrig = vec; if (llMask) { LLVM_DEBUG(DBGS() << "vec = " << vec << " mask_alignment = " << getMaskAlignment(mask)); @@ -173,6 +174,13 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, LLVM_DEBUG(llvm::dbgs() << " vec = " << vec << '\n'); } + if (vec == 1 && numElems > 1) { + int maskValue = !llMask ? -1 : getMaskAlignment(mask); + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig + << " numElems = " << numElems << " mask is " << maskValue + << "\n"; + } // Get the LLVM values for pointers auto ptrElems = unpackLLElements(loc, llPtr, rewriter); assert(ptrElems.size() == numElems); @@ -237,40 +245,14 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, // prepare asm operands auto *dstsOpr = ptxBuilder.newListOperand(); + // If there is a `other` value, use it to init. + bool init = other == nullptr; for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) { auto *opr = ptxBuilder.newOperand(writeConstraint, - /*init=*/true); // =r operations + init); // =r operations dstsOpr->listAppend(opr); } - auto *addrOpr = - ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); - - // Define the instruction opcode - auto &ld = ptxBuilder.create<>("ld") - ->o("volatile", op.getIsVolatile()) - .global() - .o("ca", op.getCache() == triton::CacheModifier::CA) - .o("cg", op.getCache() == triton::CacheModifier::CG) - .o("L1::evict_first", - op.getEvict() == triton::EvictionPolicy::EVICT_FIRST) - .o("L1::evict_last", - op.getEvict() == triton::EvictionPolicy::EVICT_LAST) - .o("L1::cache_hint", hasL2EvictPolicy) - .v(nWords) - .b(width); - - PTXBuilder::Operand *evictOpr{}; - - // Here lack a mlir::Value to bind to this operation, so disabled. - // if (has_l2_evict_policy) - // evictOpr = ptxBuilder.newOperand(l2Evict, "l"); - - if (!evictOpr) - ld(dstsOpr, addrOpr).predicate(pred, "b"); - else - ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b"); - if (other) { for (size_t ii = 0; ii < nWords; ++ii) { // PTX doesn't support mov.u8, so we need to use mov.u16 @@ -298,10 +280,38 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, } else opr = ptxBuilder.newOperand(v, readConstraint); - mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b"); + mov(dstsOpr->listGet(ii), opr); } } + auto *addrOpr = + ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); + + // Define the instruction opcode + auto &ld = ptxBuilder.create<>("ld") + ->o("volatile", op.getIsVolatile()) + .global() + .o("ca", op.getCache() == triton::CacheModifier::CA) + .o("cg", op.getCache() == triton::CacheModifier::CG) + .o("L1::evict_first", + op.getEvict() == triton::EvictionPolicy::EVICT_FIRST) + .o("L1::evict_last", + op.getEvict() == triton::EvictionPolicy::EVICT_LAST) + .o("L1::cache_hint", hasL2EvictPolicy) + .v(nWords) + .b(width); + + PTXBuilder::Operand *evictOpr{}; + + // Here lack a mlir::Value to bind to this operation, so disabled. + // if (has_l2_evict_policy) + // evictOpr = ptxBuilder.newOperand(l2Evict, "l"); + + if (!evictOpr) + ld(dstsOpr, addrOpr).predicate(pred, "b"); + else + ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b"); + // Create inline ASM signature SmallVector retTys(nWords, IntegerType::get(getContext(), width)); Type retTy = retTys.size() > 1 @@ -378,6 +388,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, assert(ptrElems.size() == valueElems.size()); // Determine the vectorization size + unsigned vecOrig = vec; SmallVector maskElems; if (llMask) { Value mask = op.getMask(); @@ -388,6 +399,14 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, vec = std::min(vec, maskAlign); } + if (vec == 1 && elemsPerThread > 1) { + int mask = !llMask ? -1 : getMaskAlignment(op.getMask()); + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig + << " elemsPerThread = " << elemsPerThread << " mask is " + << mask << "\n"; + } + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); const size_t dtsize = std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); @@ -516,12 +535,18 @@ struct AtomicCASOpConversion auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType()); // vec = 1 for scalar auto vec = getVectorSize(op.getPtr()); + auto vecOrig = vec; // tensor if (tensorTy) { auto valTy = cast(op.getVal().getType()); vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); } + if (vec == 1 && elemsPerThread > 1) + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig + << " elemsPerThread = " << elemsPerThread << "\n"; + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); auto vecTy = vec_ty(valueElemTy, vec); SmallVector resultVals(elemsPerThread); @@ -639,6 +664,7 @@ struct AtomicRMWOpConversion auto elemsPerThread = getTotalElemsPerThread(val.getType()); // vec = 1, numElements = 1 for scalar auto vec = getVectorSize(ptr); + auto vecOrig = vec; int numElems = 1; // tensor if (tensorTy) { @@ -647,6 +673,12 @@ struct AtomicRMWOpConversion // mask numElems = tensorTy.getNumElements(); } + + if (vec == 1 && numElems > 1) + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig + << " numElems = " << numElems; + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); auto vecTy = vec_ty(valueElemTy, vec); diff --git a/third_party/proton/proton/viewer.py b/third_party/proton/proton/viewer.py index cf6fc1288bc7..f77a65007fc7 100644 --- a/third_party/proton/proton/viewer.py +++ b/third_party/proton/proton/viewer.py @@ -2,9 +2,12 @@ from collections import namedtuple import json import pandas as pd -import hatchet as ht +try: + import hatchet as ht + from hatchet.query import NegationQuery +except ImportError: + raise ImportError("Failed to import hatchet. `pip install llnl-hatchet` to get the correct version.") import numpy as np -from hatchet.query import NegationQuery from triton.profiler.hook import COMPUTE_METADATA_SCOPE_NAME, TritonHook diff --git a/unittest/Tools/LinearLayoutTest.cpp b/unittest/Tools/LinearLayoutTest.cpp index e2e8c5f92127..8be680562cae 100644 --- a/unittest/Tools/LinearLayoutTest.cpp +++ b/unittest/Tools/LinearLayoutTest.cpp @@ -600,6 +600,20 @@ TEST_F(LinearLayoutTest, DivideRight_EliminateInDim) { LinearLayout l3({{S("in2"), {{0, 1}, {1, 0}}}}, {S("out1"), S("out2")}); ASSERT_EQ(l3 * l2, l1); EXPECT_EQ(l1.divideRight(l2), l3); + + LinearLayout l4({{S("in1"), {{0, 1}, {0, 2}}}, {S("in2"), {}}}, + {S("out1"), S("out2")}); + LinearLayout l5({{S("in1"), {{0, 1}, {0, 2}}}}, {S("out1"), S("out2")}); + LinearLayout l6({{S("in2"), {}}}, {S("out1"), S("out2")}); + ASSERT_EQ(l5 * l6, l4); + EXPECT_EQ(l4.divideRight(l6), l5); + + LinearLayout l7({{S("in1"), {}}, {S("in2"), {{0, 1}}}, {S("in3"), {}}}, + {S("out1"), S("out2")}); + LinearLayout l8({{S("in2"), {{0, 1}}}}, {S("out1"), S("out2")}); + LinearLayout l9({{S("in1"), {}}, {S("in2"), {}}, {S("in3"), {}}}, {}); + ASSERT_EQ(l9 * l8, l7); + EXPECT_EQ(l7.divideRight(l8), l9); } TEST_F(LinearLayoutTest, DivideRight_EliminateOutDim) { @@ -613,6 +627,18 @@ TEST_F(LinearLayoutTest, DivideRight_EliminateOutDim) { LinearLayout l3({{S("in2"), {{1}, {1}}}}, {S("out1")}); ASSERT_EQ(l3 * l2, l1); EXPECT_EQ(l1.divideRight(l2), l3); + + LinearLayout l4( + { + {S("in1"), {{0, 1}, {0, 2}}}, + }, + {S("out1"), S("out2")}); + LinearLayout l5({{S("in1"), {{1}, {2}}}}, {S("out2")}); + using BasesArray = + ArrayRef>>>; + LinearLayout l6(BasesArray{}, {S("out1")}); + ASSERT_EQ(l6 * l5, l4); + EXPECT_EQ(l4.divideRight(l5), l6); } TEST_F(LinearLayoutTest, DivideRight_Assertion) {