From 4f398cabb49b1222e772177eb92f2a4d0d590c1d Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Thu, 16 Apr 2026 12:05:44 -0700 Subject: [PATCH 1/4] Localize warp specialization partition attrs --- include/triton/Dialect/TritonGPU/IR/Dialect.h | 12 - lib/Dialect/TritonGPU/IR/Dialect.cpp | 181 --------------- .../WarpSpecialization/Partition.cpp | 213 ++++++++++++++++++ .../WarpSpecialization/PartitionAttrs.h | 34 +++ .../WarpSpecialization/PartitionLoops.cpp | 1 + .../PartitionScheduling.cpp | 7 + .../partition-verifier-locality.mlir | 16 ++ .../NVWS/Transforms/AssignStagePhase.cpp | 1 + .../NVWS/Transforms/HoistTmemStore.cpp | 1 + .../Dialect/NVWS/Transforms/InsertAref.cpp | 1 + .../NVWS/Transforms/InsertTmemAref.cpp | 1 + .../lib/Dialect/NVWS/Transforms/LowerAref.cpp | 1 + 12 files changed, 276 insertions(+), 193 deletions(-) create mode 100644 lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionAttrs.h create mode 100644 test/TritonGPU/partition-verifier-locality.mlir diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 5a13748bb2fd..f9a171a6368b 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -51,11 +51,6 @@ constexpr static char AttrNumWarpsName[] = "ttg.num-warps"; constexpr static char AttrNumCTAsName[] = "ttg.num-ctas"; constexpr static char AttrTargetName[] = "ttg.target"; constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp"; -// FIXME: rename to match above -constexpr static char kPartitionAttrName[] = "ttg.partition"; -constexpr static char kPartitionOutputsAttrName[] = "ttg.partition.outputs"; -constexpr static char kPartitionStagesAttrName[] = "ttg.partition.stages"; -constexpr static char kWarpSpecializeTagAttrName[] = "ttg.warp_specialize.tag"; // Find the contextual number of warps on which this operation is executed. int lookupNumWarps(Operation *op); @@ -335,13 +330,6 @@ LogicalResult verifyMemoryOpTypes(Operation *op, ShapedType srcTy, ShapedType dstTy); // Verify a memory allocation operation. LogicalResult verifyAllocOp(Operation *op, Value src, MemDescType dstTy); - -SetVector getPartitionIds(Operation *op); -SmallVector, 4> getPartitionOutputs(Operation *op); -SetVector getPartitionIds(OpOperand *use); -bool hasPartition(Operation *op); -bool hasWarpSpecializeTag(Operation *op); -std::optional getWarpSpecializeTag(Operation *op); /// Returns the size in bytes of a scalar type when stored in shared memory. size_t getSharedMemorySize(Type type); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 974e09792e48..3da2732e8cd3 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -4,7 +4,6 @@ #include #include -#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" @@ -4051,135 +4050,6 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op, << " which is expected only on `module` or `tt.func` ops"; } - // Verify that all ops in a tt.warp_specialize op have partition ids - if (attr.getName() == "tt.warp_specialize") { - if (!isa(op)) { - return op->emitOpError("has unexpected attribute ") - << attr.getName() << " which is expected only on `scf.for` ops"; - } - Operation *failedOp = nullptr; - op->walk([&](Operation *childOp) { - if (!childOp->hasAttr(kPartitionAttrName)) { - failedOp = childOp; - WalkResult::interrupt(); - } - }); - if (failedOp) { - return failedOp->emitOpError("does not have expected attribute ") - << kPartitionAttrName - << " which is expected on all child ops of an op with " - "attribute `tt.warp_specialize`"; - } - } - - // Verify that partition id lists are non-empty, sorted and have no duplicates - auto verifyPartitionIds = - [&](const ArrayRef &partitionIds) -> LogicalResult { - SetVector idSet; - for (auto id : partitionIds) { - if (idSet.contains(id)) - return op->emitOpError("has duplicated partition ids in attribute ") - << attr.getName(); - idSet.insert(id); - } - if (idSet.empty()) - return op->emitOpError("has no partition ids in attribute ") - << attr.getName(); - auto ids = idSet.takeVector(); - SmallVector sortedIds(ids.begin(), ids.end()); - std::sort(sortedIds.begin(), sortedIds.end()); - if (ids != sortedIds) - return op->emitOpError("partition ids not in sorted order in attribute ") - << attr.getName(); - return success(); - }; - - if (attr.getName() == kPartitionAttrName) { - auto result = verifyPartitionIds( - cast(attr.getValue()).asArrayRef()); - if (failed(result)) - return result; - } - if (attr.getName() == kPartitionOutputsAttrName) { - auto arrayAttr = cast(attr.getValue()); - for (auto idx = 0; idx < arrayAttr.size(); idx++) { - auto result = verifyPartitionIds( - cast(arrayAttr[idx]).asArrayRef()); - if (failed(result)) - return result; - } - } - - // Verify that op partitions include partitions of all child ops - if (attr.getName() == kPartitionAttrName && op->getNumRegions() != 0) { - SetVector expectedIds; - for (auto ®ion : op->getRegions()) { - for (auto &block : region.getBlocks()) { - for (auto &childOp : block.getOperations()) { - if (isa(childOp)) { - // yield ops and ub.poison do not need partition ids - continue; - } - if (!childOp.hasAttr(kPartitionAttrName)) - return childOp.emitOpError("does not have expected attribute ") - << kPartitionAttrName - << " which is expected for ops whose parent has partitions"; - auto ids = getPartitionIds(&childOp); - expectedIds.insert(ids.begin(), ids.end()); - } - } - } - auto partitionIds = getPartitionIds(op); - for (auto id : expectedIds) { - if (!partitionIds.contains(id)) { - return op->emitOpError("partition ids in attr ") - << attr.getName() - << " does not contain partition ids of all child ops"; - } - } - } - - if (attr.getName() == kPartitionOutputsAttrName) { - if (!isa(op)) - return op->emitOpError("has unexpected attribute ") << attr.getName(); - - // Verify that number of output partitions matches number of For/If results - size_t numResults = 0; - if (isa(op)) { - numResults = cast(op).getResults().size(); - } else if (isa(op)) { - numResults = cast(op).getResults().size(); - } else { - numResults = cast(op).getResults().size(); - } - - if (cast(attr.getValue()).size() != numResults) { - return op->emitOpError("does not have expected number of output " - "partition sets in attr ") - << attr.getName() << "; should match number of results"; - } - - // Verify that union of op output partitions is a subset of op partitions - if (!op->hasAttr(kPartitionAttrName)) - return op->emitOpError("does not have expected attribute ") - << kPartitionAttrName << " which is expected for ops with attr " - << kPartitionOutputsAttrName; - auto partitionIds = getPartitionIds(op); - - SetVector outputPartitionIdsUnion; - for (auto outputPartitionIds : getPartitionOutputs(op)) { - outputPartitionIdsUnion.insert(outputPartitionIds.begin(), - outputPartitionIds.end()); - } - if (!std::all_of(outputPartitionIdsUnion.begin(), - outputPartitionIdsUnion.end(), - [&](int id) { return partitionIds.contains(id); })) { - return op->emitOpError("partition ids in attr ") - << kPartitionAttrName - << " must be the union of all partition ids in " << attr.getName(); - } - } - return success(); } @@ -4414,57 +4284,6 @@ SmallVector triton::gpu::getTMABlockShape( mode); } -SetVector triton::gpu::getPartitionIds(Operation *op) { - auto attrs = op->getAttr(kPartitionAttrName); - SmallVector partitionIds; - for (auto id : cast(attrs).asArrayRef()) { - partitionIds.push_back(id); - } - std::sort(partitionIds.begin(), partitionIds.end()); - return SetVector(partitionIds.begin(), partitionIds.end()); -} - -SmallVector, 4> triton::gpu::getPartitionOutputs(Operation *op) { - SmallVector, 4> partitionOutputsIds; - if (op->getNumResults() == 0) { - return partitionOutputsIds; - } - assert(op->hasAttr(kPartitionOutputsAttrName)); - auto arrayAttr = cast(op->getAttr(kPartitionOutputsAttrName)); - for (auto attr : arrayAttr) { - auto ids = cast(attr).asArrayRef(); - partitionOutputsIds.push_back(SetVector(ids.begin(), ids.end())); - } - return partitionOutputsIds; -} - -SetVector triton::gpu::getPartitionIds(OpOperand *use) { - auto owner = use->getOwner(); - if (isa(owner)) { - return getPartitionOutputs(owner->getParentOp())[use->getOperandNumber()]; - } else if (scf::ForOp forOp = dyn_cast(owner)) { - int idx = use->getOperandNumber() - forOp.getNumControlOperands(); - return idx >= 0 ? getPartitionOutputs(owner)[idx] : getPartitionIds(forOp); - } else { - return getPartitionIds(owner); - } -} - -bool triton::gpu::hasPartition(Operation *op) { - return op && op->hasAttr(kPartitionAttrName); -} - -bool triton::gpu::hasWarpSpecializeTag(Operation *op) { - return op && op->hasAttr(kWarpSpecializeTagAttrName); -} - -std::optional triton::gpu::getWarpSpecializeTag(Operation *op) { - if (hasWarpSpecializeTag(op)) { - return cast(op->getAttr(kWarpSpecializeTagAttrName)).getInt(); - } - return std::nullopt; -} - PaddedSharedEncodingAttr triton::gpu::getPaddedEncoding(Attribute encoding) { if (!encoding) return nullptr; diff --git a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp index 69ef8bf55c89..c0676ec9d65c 100644 --- a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp @@ -1,5 +1,9 @@ +#include "PartitionAttrs.h" #include "triton/Dialect/TritonGPU/Transforms/Partition.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinAttributes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "llvm/ADT/SCCIterator.h" @@ -10,6 +14,149 @@ using namespace mlir; using namespace triton; using namespace triton::gpu; +namespace { + +LogicalResult verifyPartitionIdsAttr(Operation *op, StringRef attrName, + Attribute attrValue) { + auto partitionIdsAttr = dyn_cast(attrValue); + if (!partitionIdsAttr) { + return op->emitOpError("has invalid attribute ") + << attrName << "; expected a dense i32 array"; + } + + SetVector idSet; + for (auto id : partitionIdsAttr.asArrayRef()) { + if (idSet.contains(id)) + return op->emitOpError("has duplicated partition ids in attribute ") + << attrName; + idSet.insert(id); + } + if (idSet.empty()) + return op->emitOpError("has no partition ids in attribute ") << attrName; + + auto ids = idSet.takeVector(); + SmallVector sortedIds(ids.begin(), ids.end()); + llvm::sort(sortedIds); + if (ids != sortedIds) { + return op->emitOpError("partition ids not in sorted order in attribute ") + << attrName; + } + return success(); +} + +LogicalResult verifyPartitionAttrs(Operation *op) { + if (op->hasAttr(kWarpSpecializeAttrName)) { + if (!isa(op)) { + return op->emitOpError("has unexpected attribute ") + << kWarpSpecializeAttrName + << " which is expected only on `scf.for` ops"; + } + + Operation *failedOp = nullptr; + op->walk([&](Operation *childOp) { + if (!childOp->hasAttr(kPartitionAttrName)) { + failedOp = childOp; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (failedOp) { + return failedOp->emitOpError("does not have expected attribute ") + << kPartitionAttrName + << " which is expected on all child ops of an op with attribute `" + << kWarpSpecializeAttrName << "`"; + } + } + + if (auto partitionAttr = op->getAttr(kPartitionAttrName)) { + if (failed(verifyPartitionIdsAttr(op, kPartitionAttrName, partitionAttr))) + return failure(); + } + + if (auto outputsAttr = op->getAttr(kPartitionOutputsAttrName)) { + auto arrayAttr = dyn_cast(outputsAttr); + if (!arrayAttr) { + return op->emitOpError("has invalid attribute ") + << kPartitionOutputsAttrName << "; expected an array attribute"; + } + + for (Attribute attr : arrayAttr) { + if (failed( + verifyPartitionIdsAttr(op, kPartitionOutputsAttrName, attr))) { + return failure(); + } + } + } + + if (op->hasAttr(kPartitionAttrName) && op->getNumRegions() != 0) { + SetVector expectedIds; + for (Region ®ion : op->getRegions()) { + for (Block &block : region.getBlocks()) { + for (Operation &childOp : block.getOperations()) { + if (isa(childOp)) + continue; + if (!childOp.hasAttr(kPartitionAttrName)) { + return childOp.emitOpError("does not have expected attribute ") + << kPartitionAttrName + << " which is expected for ops whose parent has partitions"; + } + auto ids = getPartitionIds(&childOp); + expectedIds.insert(ids.begin(), ids.end()); + } + } + } + + auto partitionIds = getPartitionIds(op); + for (auto id : expectedIds) { + if (!partitionIds.contains(id)) { + return op->emitOpError("partition ids in attr ") + << kPartitionAttrName + << " does not contain partition ids of all child ops"; + } + } + } + + if (auto outputsAttr = op->getAttr(kPartitionOutputsAttrName)) { + if (!isa(op)) + return op->emitOpError("has unexpected attribute ") + << kPartitionOutputsAttrName; + + size_t numResults = op->getNumResults(); + auto arrayAttr = cast(outputsAttr); + if (arrayAttr.size() != numResults) { + return op->emitOpError("does not have expected number of output " + "partition sets in attr ") + << kPartitionOutputsAttrName + << "; should match number of results"; + } + + if (!op->hasAttr(kPartitionAttrName)) { + return op->emitOpError("does not have expected attribute ") + << kPartitionAttrName << " which is expected for ops with attr " + << kPartitionOutputsAttrName; + } + + auto partitionIds = getPartitionIds(op); + SetVector outputPartitionIdsUnion; + for (auto outputPartitionIds : getPartitionOutputs(op)) { + outputPartitionIdsUnion.insert(outputPartitionIds.begin(), + outputPartitionIds.end()); + } + if (!llvm::all_of(outputPartitionIdsUnion, [&](int id) { + return partitionIds.contains(id); + })) { + return op->emitOpError("partition ids in attr ") + << kPartitionAttrName + << " must be the union of all partition ids in " + << kPartitionOutputsAttrName; + } + } + + return success(); +} + +} // namespace + //===----------------------------------------------------------------------===// // Partition //===----------------------------------------------------------------------===// @@ -132,6 +279,9 @@ Partition *PartitionSet::getPartition(Operation *op) { } FailureOr PartitionSet::fromLoop(scf::ForOp loop) { + if (failed(verifyPartitionedLoop(loop))) + return failure(); + auto stages = loop->getAttrOfType(kPartitionStagesAttrName); if (!stages) return failure(); @@ -187,6 +337,69 @@ void PartitionSet::dump() const { namespace mlir::triton::gpu { +SetVector getPartitionIds(Operation *op) { + auto attrs = op->getAttr(kPartitionAttrName); + SmallVector partitionIds; + for (auto id : cast(attrs).asArrayRef()) { + partitionIds.push_back(id); + } + llvm::sort(partitionIds); + return SetVector(partitionIds.begin(), partitionIds.end()); +} + +SmallVector, 4> getPartitionOutputs(Operation *op) { + SmallVector, 4> partitionOutputsIds; + if (op->getNumResults() == 0) + return partitionOutputsIds; + + assert(op->hasAttr(kPartitionOutputsAttrName)); + auto arrayAttr = cast(op->getAttr(kPartitionOutputsAttrName)); + for (Attribute attr : arrayAttr) { + auto ids = cast(attr).asArrayRef(); + partitionOutputsIds.push_back(SetVector(ids.begin(), ids.end())); + } + return partitionOutputsIds; +} + +SetVector getPartitionIds(OpOperand *use) { + auto owner = use->getOwner(); + if (isa(owner)) { + return getPartitionOutputs(owner->getParentOp())[use->getOperandNumber()]; + } + if (auto forOp = dyn_cast(owner)) { + int idx = use->getOperandNumber() - forOp.getNumControlOperands(); + return idx >= 0 ? getPartitionOutputs(owner)[idx] : getPartitionIds(forOp); + } + return getPartitionIds(owner); +} + +bool hasPartition(Operation *op) { return op && op->hasAttr(kPartitionAttrName); } + +bool hasWarpSpecializeTag(Operation *op) { + return op && op->hasAttr(kWarpSpecializeTagAttrName); +} + +std::optional getWarpSpecializeTag(Operation *op) { + if (hasWarpSpecializeTag(op)) + return cast(op->getAttr(kWarpSpecializeTagAttrName)).getInt(); + return std::nullopt; +} + +LogicalResult verifyPartitionedLoop(scf::ForOp loop) { + if (failed(verifyPartitionAttrs(loop))) + return failure(); + + LogicalResult result = success(); + loop.walk([&](Operation *op) { + if (failed(verifyPartitionAttrs(op))) { + result = failure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result; +} + void setPartition(Operation *op, ArrayRef partitionIds) { Builder b(op->getContext()); auto sorted = llvm::to_vector(partitionIds); diff --git a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionAttrs.h b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionAttrs.h new file mode 100644 index 000000000000..a2c888cf0ee3 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionAttrs.h @@ -0,0 +1,34 @@ +#ifndef TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_WARPSPECIALIZATION_PARTITIONATTRS_H_ +#define TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_WARPSPECIALIZATION_PARTITIONATTRS_H_ + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/SetVector.h" +#include + +namespace mlir { +class Operation; +class OpOperand; +namespace scf { +class ForOp; +} // namespace scf +} // namespace mlir + +namespace mlir::triton::gpu { + +inline constexpr char kPartitionAttrName[] = "ttg.partition"; +inline constexpr char kPartitionOutputsAttrName[] = "ttg.partition.outputs"; +inline constexpr char kPartitionStagesAttrName[] = "ttg.partition.stages"; +inline constexpr char kWarpSpecializeTagAttrName[] = "ttg.warp_specialize.tag"; + +SetVector getPartitionIds(Operation *op); +SmallVector, 4> getPartitionOutputs(Operation *op); +SetVector getPartitionIds(OpOperand *use); +bool hasPartition(Operation *op); +bool hasWarpSpecializeTag(Operation *op); +std::optional getWarpSpecializeTag(Operation *op); + +LogicalResult verifyPartitionedLoop(scf::ForOp loop); + +} // namespace mlir::triton::gpu + +#endif // TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_WARPSPECIALIZATION_PARTITIONATTRS_H_ diff --git a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp index 5a4cb31fba1d..6fdb16453bd6 100644 --- a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp @@ -8,6 +8,7 @@ #include "mlir/Transforms/RegionUtils.h" #include "nvidia/include/Dialect/NVWS/IR/Dialect.h" #include "nvidia/include/Dialect/NVWS/Transforms/Passes.h" +#include "PartitionAttrs.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Partition.h" #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" diff --git a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp index e6447df41ef2..9885aa270a58 100644 --- a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp @@ -1,5 +1,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "PartitionAttrs.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Partition.h" @@ -1453,6 +1454,12 @@ struct PartitionScheduling analyze(idx, op); if (hasPartition(op)) cloneMultiPartitionDataOps(op); + if (auto loop = dyn_cast(op); + loop && loop->hasAttr(kPartitionStagesAttrName) && + failed(verifyPartitionedLoop(loop))) { + signalPassFailure(); + return; + } idx++; } } diff --git a/test/TritonGPU/partition-verifier-locality.mlir b/test/TritonGPU/partition-verifier-locality.mlir new file mode 100644 index 000000000000..aa883a699f50 --- /dev/null +++ b/test/TritonGPU/partition-verifier-locality.mlir @@ -0,0 +1,16 @@ +// RUN: triton-opt %s -allow-unregistered-dialect -verify-diagnostics -o /dev/null +// RUN: not triton-opt %s -allow-unregistered-dialect -tritongpu-partition-loops -o /dev/null 2>&1 | FileCheck %s + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @partition_attrs_are_verified_only_when_consumed( + %lb: i32, %ub: i32, %step: i32) { + scf.for %i = %lb to %ub step %step : i32 { + %0 = arith.addi %i, %i {ttg.partition = array} : i32 + "use"(%0) : (i32) -> () + } {ttg.partition.stages = [0, 0], ttg.warp_specialize.tag = 0 : i32, + ttg.partition = array} + tt.return + } +} + +// CHECK: error: 'arith.addi' op partition ids not in sorted order in attribute ttg.partition diff --git a/third_party/nvidia/lib/Dialect/NVWS/Transforms/AssignStagePhase.cpp b/third_party/nvidia/lib/Dialect/NVWS/Transforms/AssignStagePhase.cpp index c6635fdb05f5..082d5c96f090 100644 --- a/third_party/nvidia/lib/Dialect/NVWS/Transforms/AssignStagePhase.cpp +++ b/third_party/nvidia/lib/Dialect/NVWS/Transforms/AssignStagePhase.cpp @@ -22,6 +22,7 @@ */ #include "Utilities.h" +#include "lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionAttrs.h" #include "mlir/Analysis/TopologicalSortUtils.h" #include "Utilities.h" diff --git a/third_party/nvidia/lib/Dialect/NVWS/Transforms/HoistTmemStore.cpp b/third_party/nvidia/lib/Dialect/NVWS/Transforms/HoistTmemStore.cpp index 4909facc38d2..7d10429b5e31 100644 --- a/third_party/nvidia/lib/Dialect/NVWS/Transforms/HoistTmemStore.cpp +++ b/third_party/nvidia/lib/Dialect/NVWS/Transforms/HoistTmemStore.cpp @@ -22,6 +22,7 @@ */ #include "mlir/Analysis/SliceAnalysis.h" +#include "lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionAttrs.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Interfaces/InferIntRangeInterface.h" diff --git a/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertAref.cpp b/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertAref.cpp index 240145a74e89..c83826772ac4 100644 --- a/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertAref.cpp +++ b/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertAref.cpp @@ -1,4 +1,5 @@ #include "Utilities.h" +#include "lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionAttrs.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Dominance.h" diff --git a/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertTmemAref.cpp b/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertTmemAref.cpp index 6c08e85b3e1b..9f8f21f2b360 100644 --- a/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertTmemAref.cpp +++ b/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertTmemAref.cpp @@ -1,4 +1,5 @@ #include "Utilities.h" +#include "lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionAttrs.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/UB/IR/UBOps.h" diff --git a/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp b/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp index 7ef4d05ffd6c..74fcaa06d853 100644 --- a/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp +++ b/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp @@ -22,6 +22,7 @@ */ #include "Utilities.h" +#include "lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionAttrs.h" #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" From 1649591f227619f50948d9fad444d3364ed3fc7a Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Thu, 16 Apr 2026 12:35:46 -0700 Subject: [PATCH 2/4] Apply pre-commit formatting --- .../Transforms/WarpSpecialization/Partition.cpp | 17 ++++++++--------- .../WarpSpecialization/PartitionLoops.cpp | 2 +- .../WarpSpecialization/PartitionScheduling.cpp | 2 +- .../Dialect/NVWS/Transforms/HoistTmemStore.cpp | 2 +- 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp index c0676ec9d65c..200b66b706a3 100644 --- a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp @@ -1,5 +1,5 @@ -#include "PartitionAttrs.h" #include "triton/Dialect/TritonGPU/Transforms/Partition.h" +#include "PartitionAttrs.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinAttributes.h" @@ -81,8 +81,7 @@ LogicalResult verifyPartitionAttrs(Operation *op) { } for (Attribute attr : arrayAttr) { - if (failed( - verifyPartitionIdsAttr(op, kPartitionOutputsAttrName, attr))) { + if (failed(verifyPartitionIdsAttr(op, kPartitionOutputsAttrName, attr))) { return failure(); } } @@ -126,8 +125,7 @@ LogicalResult verifyPartitionAttrs(Operation *op) { if (arrayAttr.size() != numResults) { return op->emitOpError("does not have expected number of output " "partition sets in attr ") - << kPartitionOutputsAttrName - << "; should match number of results"; + << kPartitionOutputsAttrName << "; should match number of results"; } if (!op->hasAttr(kPartitionAttrName)) { @@ -142,9 +140,8 @@ LogicalResult verifyPartitionAttrs(Operation *op) { outputPartitionIdsUnion.insert(outputPartitionIds.begin(), outputPartitionIds.end()); } - if (!llvm::all_of(outputPartitionIdsUnion, [&](int id) { - return partitionIds.contains(id); - })) { + if (!llvm::all_of(outputPartitionIdsUnion, + [&](int id) { return partitionIds.contains(id); })) { return op->emitOpError("partition ids in attr ") << kPartitionAttrName << " must be the union of all partition ids in " @@ -373,7 +370,9 @@ SetVector getPartitionIds(OpOperand *use) { return getPartitionIds(owner); } -bool hasPartition(Operation *op) { return op && op->hasAttr(kPartitionAttrName); } +bool hasPartition(Operation *op) { + return op && op->hasAttr(kPartitionAttrName); +} bool hasWarpSpecializeTag(Operation *op) { return op && op->hasAttr(kWarpSpecializeTagAttrName); diff --git a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp index 6fdb16453bd6..bed05a31c614 100644 --- a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp @@ -1,3 +1,4 @@ +#include "PartitionAttrs.h" #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinOps.h" @@ -8,7 +9,6 @@ #include "mlir/Transforms/RegionUtils.h" #include "nvidia/include/Dialect/NVWS/IR/Dialect.h" #include "nvidia/include/Dialect/NVWS/Transforms/Passes.h" -#include "PartitionAttrs.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Partition.h" #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" diff --git a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp index 9885aa270a58..d61157905756 100644 --- a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp @@ -1,6 +1,6 @@ +#include "PartitionAttrs.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" -#include "PartitionAttrs.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Partition.h" diff --git a/third_party/nvidia/lib/Dialect/NVWS/Transforms/HoistTmemStore.cpp b/third_party/nvidia/lib/Dialect/NVWS/Transforms/HoistTmemStore.cpp index 7d10429b5e31..ad61633927cc 100644 --- a/third_party/nvidia/lib/Dialect/NVWS/Transforms/HoistTmemStore.cpp +++ b/third_party/nvidia/lib/Dialect/NVWS/Transforms/HoistTmemStore.cpp @@ -21,8 +21,8 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -#include "mlir/Analysis/SliceAnalysis.h" #include "lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionAttrs.h" +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Interfaces/InferIntRangeInterface.h" From 8d462bbc92563f1000eb4f5eaa6d3819badb4780 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Thu, 16 Apr 2026 12:42:40 -0700 Subject: [PATCH 3/4] Verify and clear WS partition attrs --- .../AutomaticWarpSpecialization.cpp | 53 ++++++++++++++++--- .../automatic-warp-specialization.mlir | 10 ++-- 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp index b2a6b85aa129..592e9278364d 100644 --- a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp @@ -1,3 +1,4 @@ +#include "PartitionAttrs.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" @@ -23,6 +24,25 @@ namespace mlir::triton::gpu { } // namespace mlir::triton::gpu namespace { +struct VerifyWarpSpecializationPartitions + : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + VerifyWarpSpecializationPartitions) + + void runOnOperation() override { + WalkResult result = getOperation().walk([&](scf::ForOp loop) { + if (!loop->hasAttr(kPartitionStagesAttrName)) + return WalkResult::advance(); + if (failed(verifyPartitionedLoop(loop))) { + signalPassFailure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + (void)result; + } +}; + struct AutomaticWarpSpecialization : triton::gpu::impl::TritonGPUAutomaticWarpSpecializationBase< AutomaticWarpSpecialization> { @@ -57,20 +77,38 @@ void multiBufferTMADescriptors(ModuleOp mod, int numStages) { } } +void clearInternalWarpSpecializationAttrs(ModuleOp mod) { + mod.walk([](Operation *op) { + op->removeAttr(kPartitionAttrName); + op->removeAttr(kPartitionOutputsAttrName); + op->removeAttr(kPartitionStagesAttrName); + op->removeAttr(kWarpSpecializeTagAttrName); + }); +} + +std::unique_ptr createVerifyWarpSpecializationPartitionsPass() { + return std::make_unique(); +} + } // namespace void AutomaticWarpSpecialization::runOnOperation() { OpPassManager pm; - pm.addPass(createTritonGPUPartitionScheduling()); - pm.addPass(createNVWSHoistTmemStore()); - pm.addPass(createNVWSInsertAref()); - pm.addPass(createNVWSInsertTmemAref()); + auto addPassWithPartitionVerifier = [&](std::unique_ptr pass) { + pm.addPass(std::move(pass)); + pm.addPass(createVerifyWarpSpecializationPartitionsPass()); + }; + + addPassWithPartitionVerifier(createTritonGPUPartitionScheduling()); + addPassWithPartitionVerifier(createNVWSHoistTmemStore()); + addPassWithPartitionVerifier(createNVWSInsertAref()); + addPassWithPartitionVerifier(createNVWSInsertTmemAref()); // `int-range-optimizations` and SCCP are good at cleaning up loop arithmetic. // FIXME: Re-enable integer range analysis once it is fixed. // pm.addPass(arith::createIntRangeOptimizationsPass()); - pm.addPass(createSCCPPass()); - pm.addPass(createCSEPass()); - pm.addPass(createNVWSLowerAref({numStages})); + addPassWithPartitionVerifier(createSCCPPass()); + addPassWithPartitionVerifier(createCSEPass()); + addPassWithPartitionVerifier(createNVWSLowerAref({numStages})); pm.addPass(createTritonGPUPartitionLoops()); pm.addPass(createNVWSLowerWarpGroup()); pm.addPass(createTritonGPUScheduleLoops()); @@ -80,4 +118,5 @@ void AutomaticWarpSpecialization::runOnOperation() { // Multi-buffer TMA descriptors. We cannot rely on SWP to do it, to support // desc updates in nested loops. multiBufferTMADescriptors(getOperation(), numStages); + clearInternalWarpSpecializationAttrs(getOperation()); } diff --git a/test/TritonGPU/automatic-warp-specialization.mlir b/test/TritonGPU/automatic-warp-specialization.mlir index 5bebd300155f..4d53d54d47ba 100644 --- a/test/TritonGPU/automatic-warp-specialization.mlir +++ b/test/TritonGPU/automatic-warp-specialization.mlir @@ -1,6 +1,10 @@ -// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-hoist-tmem-alloc -tritongpu-assign-latencies -tritongpu-schedule-loops -tritongpu-automatic-warp-specialization=num-stages=2 | FileCheck %s --check-prefix=CHECK --check-prefix=BASE -// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-hoist-tmem-alloc -tritongpu-assign-latencies -tritongpu-schedule-loops -tritongpu-automatic-warp-specialization=num-stages=2 -tritongpu-pipeline | FileCheck %s --check-prefix=CHECK --check-prefix=PIPELINE -// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-hoist-tmem-alloc -tritongpu-assign-latencies -tritongpu-schedule-loops -tritongpu-automatic-warp-specialization=num-stages=2 -tritongpu-pipeline -tritongpu-optimize-partition-warps | FileCheck %s --check-prefix=OPT +// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-hoist-tmem-alloc -tritongpu-assign-latencies -tritongpu-schedule-loops -tritongpu-automatic-warp-specialization=num-stages=2 | FileCheck %s --check-prefix=CHECK --check-prefix=BASE --check-prefix=CLEAN +// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-hoist-tmem-alloc -tritongpu-assign-latencies -tritongpu-schedule-loops -tritongpu-automatic-warp-specialization=num-stages=2 -tritongpu-pipeline | FileCheck %s --check-prefix=CHECK --check-prefix=PIPELINE --check-prefix=CLEAN +// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-hoist-tmem-alloc -tritongpu-assign-latencies -tritongpu-schedule-loops -tritongpu-automatic-warp-specialization=num-stages=2 -tritongpu-pipeline -tritongpu-optimize-partition-warps | FileCheck %s --check-prefix=OPT --check-prefix=CLEAN + +// CLEAN: module +// CLEAN-NOT: ttg.partition +// CLEAN-NOT: ttg.warp_specialize.tag #indices_layout_parent = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> #indices_layout = #ttg.slice<{dim = 0, parent = #indices_layout_parent}> From 89a2cff8225db939cb76a548857db9b7f5e2a8b6 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Thu, 16 Apr 2026 12:59:59 -0700 Subject: [PATCH 4/4] Fix WS partition verifier tests --- .../TritonGPU/Transforms/WarpSpecialization/Partition.cpp | 2 ++ test/TritonGPU/partition-verifier-locality.mlir | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp index 200b66b706a3..81822accf469 100644 --- a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp @@ -54,6 +54,8 @@ LogicalResult verifyPartitionAttrs(Operation *op) { Operation *failedOp = nullptr; op->walk([&](Operation *childOp) { + if (isa(childOp)) + return WalkResult::advance(); if (!childOp->hasAttr(kPartitionAttrName)) { failedOp = childOp; return WalkResult::interrupt(); diff --git a/test/TritonGPU/partition-verifier-locality.mlir b/test/TritonGPU/partition-verifier-locality.mlir index aa883a699f50..93ed198f0cd6 100644 --- a/test/TritonGPU/partition-verifier-locality.mlir +++ b/test/TritonGPU/partition-verifier-locality.mlir @@ -6,7 +6,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { %lb: i32, %ub: i32, %step: i32) { scf.for %i = %lb to %ub step %step : i32 { %0 = arith.addi %i, %i {ttg.partition = array} : i32 - "use"(%0) : (i32) -> () + "use"(%0) {ttg.partition = array} : (i32) -> () } {ttg.partition.stages = [0, 0], ttg.warp_specialize.tag = 0 : i32, ttg.partition = array} tt.return