Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,6 @@ constexpr static char AttrNumWarpsName[] = "ttg.num-warps";
constexpr static char AttrNumCTAsName[] = "ttg.num-ctas";
constexpr static char AttrTargetName[] = "ttg.target";
constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp";
// Attribute names used to annotate operations for warp specialization.
// FIXME: rename to match above
// Per-op list of partition ids; read as a DenseI32ArrayAttr.
constexpr static char kPartitionAttrName[] = "ttg.partition";
// Per-result partition id lists; read as an ArrayAttr of DenseI32ArrayAttr.
constexpr static char kPartitionOutputsAttrName[] = "ttg.partition.outputs";
// Marks a loop as partitioned; the value's layout is not shown here.
constexpr static char kPartitionStagesAttrName[] = "ttg.partition.stages";
// Integer tag attached to an op; read as an IntegerAttr.
constexpr static char kWarpSpecializeTagAttrName[] = "ttg.warp_specialize.tag";

// Find the contextual number of warps on which this operation is executed.
int lookupNumWarps(Operation *op);
Expand Down Expand Up @@ -335,13 +330,6 @@ LogicalResult verifyMemoryOpTypes(Operation *op, ShapedType srcTy,
ShapedType dstTy);
// Verify a memory allocation operation.
LogicalResult verifyAllocOp(Operation *op, Value src, MemDescType dstTy);

/// Returns the ids stored in the `ttg.partition` attribute of `op`, in
/// ascending order. `op` is expected to carry that attribute.
SetVector<int> getPartitionIds(Operation *op);
/// Returns one partition id set per entry of the `ttg.partition.outputs`
/// attribute of `op`, or an empty vector when `op` has no results.
SmallVector<SetVector<int>, 4> getPartitionOutputs(Operation *op);
/// Returns the partition ids that apply to a single use. For an operand of
/// `scf.yield` this is the output set of the parent op at the same index; for
/// a non-control operand of `scf.for` it is the loop's output set at the
/// matching index; otherwise it is the partition ids of the owning op.
SetVector<int> getPartitionIds(OpOperand *use);
/// True when `op` is non-null and has the `ttg.partition` attribute.
bool hasPartition(Operation *op);
/// True when `op` is non-null and has the `ttg.warp_specialize.tag` attribute.
bool hasWarpSpecializeTag(Operation *op);
/// Returns the integer value of the `ttg.warp_specialize.tag` attribute, or
/// std::nullopt when `op` is null or does not carry the attribute.
std::optional<int> getWarpSpecializeTag(Operation *op);
/// Returns the size in bytes of a scalar type when stored in shared memory.
size_t getSharedMemorySize(Type type);

Expand Down
181 changes: 0 additions & 181 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <numeric>
#include <utility>

#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
Expand Down Expand Up @@ -4051,135 +4050,6 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
<< " which is expected only on `module` or `tt.func` ops";
}

// Verify that all ops in a tt.warp_specialize op have partition ids
if (attr.getName() == "tt.warp_specialize") {
if (!isa<scf::ForOp>(op)) {
return op->emitOpError("has unexpected attribute ")
<< attr.getName() << " which is expected only on `scf.for` ops";
}
Operation *failedOp = nullptr;
op->walk([&](Operation *childOp) {
if (!childOp->hasAttr(kPartitionAttrName)) {
failedOp = childOp;
WalkResult::interrupt();
}
});
if (failedOp) {
return failedOp->emitOpError("does not have expected attribute ")
<< kPartitionAttrName
<< " which is expected on all child ops of an op with "
"attribute `tt.warp_specialize`";
}
}

// Verify that partition id lists are non-empty, sorted and have no duplicates
auto verifyPartitionIds =
[&](const ArrayRef<int> &partitionIds) -> LogicalResult {
SetVector<int> idSet;
for (auto id : partitionIds) {
if (idSet.contains(id))
return op->emitOpError("has duplicated partition ids in attribute ")
<< attr.getName();
idSet.insert(id);
}
if (idSet.empty())
return op->emitOpError("has no partition ids in attribute ")
<< attr.getName();
auto ids = idSet.takeVector();
SmallVector<int> sortedIds(ids.begin(), ids.end());
std::sort(sortedIds.begin(), sortedIds.end());
if (ids != sortedIds)
return op->emitOpError("partition ids not in sorted order in attribute ")
<< attr.getName();
return success();
};

if (attr.getName() == kPartitionAttrName) {
auto result = verifyPartitionIds(
cast<DenseI32ArrayAttr>(attr.getValue()).asArrayRef());
if (failed(result))
return result;
}
if (attr.getName() == kPartitionOutputsAttrName) {
auto arrayAttr = cast<ArrayAttr>(attr.getValue());
for (auto idx = 0; idx < arrayAttr.size(); idx++) {
auto result = verifyPartitionIds(
cast<DenseI32ArrayAttr>(arrayAttr[idx]).asArrayRef());
if (failed(result))
return result;
}
}

// Verify that op partitions include partitions of all child ops
if (attr.getName() == kPartitionAttrName && op->getNumRegions() != 0) {
SetVector<int> expectedIds;
for (auto &region : op->getRegions()) {
for (auto &block : region.getBlocks()) {
for (auto &childOp : block.getOperations()) {
if (isa<scf::YieldOp, ub::PoisonOp>(childOp)) {
// yield ops and ub.poison do not need partition ids
continue;
}
if (!childOp.hasAttr(kPartitionAttrName))
return childOp.emitOpError("does not have expected attribute ")
<< kPartitionAttrName
<< " which is expected for ops whose parent has partitions";
auto ids = getPartitionIds(&childOp);
expectedIds.insert(ids.begin(), ids.end());
}
}
}
auto partitionIds = getPartitionIds(op);
for (auto id : expectedIds) {
if (!partitionIds.contains(id)) {
return op->emitOpError("partition ids in attr ")
<< attr.getName()
<< " does not contain partition ids of all child ops";
}
}
}

if (attr.getName() == kPartitionOutputsAttrName) {
if (!isa<scf::ForOp, scf::IfOp, triton::ReduceOp>(op))
return op->emitOpError("has unexpected attribute ") << attr.getName();

// Verify that number of output partitions matches number of For/If results
size_t numResults = 0;
if (isa<scf::ForOp>(op)) {
numResults = cast<scf::ForOp>(op).getResults().size();
} else if (isa<scf::IfOp>(op)) {
numResults = cast<scf::IfOp>(op).getResults().size();
} else {
numResults = cast<triton::ReduceOp>(op).getResults().size();
}

if (cast<ArrayAttr>(attr.getValue()).size() != numResults) {
return op->emitOpError("does not have expected number of output "
"partition sets in attr ")
<< attr.getName() << "; should match number of results";
}

// Verify that union of op output partitions is a subset of op partitions
if (!op->hasAttr(kPartitionAttrName))
return op->emitOpError("does not have expected attribute ")
<< kPartitionAttrName << " which is expected for ops with attr "
<< kPartitionOutputsAttrName;
auto partitionIds = getPartitionIds(op);

SetVector<int> outputPartitionIdsUnion;
for (auto outputPartitionIds : getPartitionOutputs(op)) {
outputPartitionIdsUnion.insert(outputPartitionIds.begin(),
outputPartitionIds.end());
}
if (!std::all_of(outputPartitionIdsUnion.begin(),
outputPartitionIdsUnion.end(),
[&](int id) { return partitionIds.contains(id); })) {
return op->emitOpError("partition ids in attr ")
<< kPartitionAttrName
<< " must be the union of all partition ids in " << attr.getName();
}
}

return success();
}

Expand Down Expand Up @@ -4414,57 +4284,6 @@ SmallVector<int64_t> triton::gpu::getTMABlockShape(
mode);
}

// Collects the partition ids attached to `op` into a set whose iteration
// order is ascending. The attribute is expected to be present; the cast below
// does not tolerate a missing attribute.
SetVector<int> triton::gpu::getPartitionIds(Operation *op) {
  auto idsAttr = cast<DenseI32ArrayAttr>(op->getAttr(kPartitionAttrName));
  auto rawIds = idsAttr.asArrayRef();

  // Sort a private copy first so that insertion order (and therefore the
  // SetVector iteration order) is ascending.
  SmallVector<int> orderedIds(rawIds.begin(), rawIds.end());
  std::sort(orderedIds.begin(), orderedIds.end());

  SetVector<int> result;
  result.insert(orderedIds.begin(), orderedIds.end());
  return result;
}

// Builds one partition id set per entry of the partition-outputs attribute.
// Ops without results yield an empty vector and are not required to carry the
// attribute; ops with results must have it.
SmallVector<SetVector<int>, 4> triton::gpu::getPartitionOutputs(Operation *op) {
  SmallVector<SetVector<int>, 4> outputs;
  if (op->getNumResults() != 0) {
    assert(op->hasAttr(kPartitionOutputsAttrName));
    auto outputsAttr = cast<ArrayAttr>(op->getAttr(kPartitionOutputsAttrName));
    for (auto entry : outputsAttr) {
      // Ids keep the order in which they appear in the attribute.
      auto entryIds = cast<DenseI32ArrayAttr>(entry).asArrayRef();
      outputs.emplace_back(entryIds.begin(), entryIds.end());
    }
  }
  return outputs;
}

// Resolves the partition ids that apply to a particular use of a value.
SetVector<int> triton::gpu::getPartitionIds(OpOperand *use) {
  auto user = use->getOwner();

  // A yielded value belongs to the output partitions of the enclosing op at
  // the same position.
  if (isa<scf::YieldOp>(user))
    return getPartitionOutputs(user->getParentOp())[use->getOperandNumber()];

  if (scf::ForOp loop = dyn_cast<scf::ForOp>(user)) {
    // Operands that follow the control operands map onto the loop's output
    // partition sets; control operands use the loop's own partition ids.
    int outputIdx = use->getOperandNumber() - loop.getNumControlOperands();
    if (outputIdx >= 0)
      return getPartitionOutputs(user)[outputIdx];
    return getPartitionIds(loop);
  }

  return getPartitionIds(user);
}

// Reports whether `op` is non-null and carries the partition attribute.
bool triton::gpu::hasPartition(Operation *op) {
  if (!op)
    return false;
  return op->hasAttr(kPartitionAttrName);
}

// Reports whether `op` is non-null and carries the warp-specialize tag
// attribute.
bool triton::gpu::hasWarpSpecializeTag(Operation *op) {
  if (!op)
    return false;
  return op->hasAttr(kWarpSpecializeTagAttrName);
}

// Returns the integer stored in the warp-specialize tag attribute of `op`, or
// std::nullopt when `op` is null or has no such attribute.
std::optional<int> triton::gpu::getWarpSpecializeTag(Operation *op) {
  if (!hasWarpSpecializeTag(op))
    return std::nullopt;
  auto tagAttr = cast<IntegerAttr>(op->getAttr(kWarpSpecializeTagAttrName));
  return tagAttr.getInt();
}

PaddedSharedEncodingAttr triton::gpu::getPaddedEncoding(Attribute encoding) {
if (!encoding)
return nullptr;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "PartitionAttrs.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
Expand All @@ -23,6 +24,25 @@ namespace mlir::triton::gpu {
} // namespace mlir::triton::gpu

namespace {
// Module pass that checks every `scf.for` carrying the partition-stages
// attribute with verifyPartitionedLoop, and fails the pass on the first loop
// that does not verify.
struct VerifyWarpSpecializationPartitions
    : PassWrapper<VerifyWarpSpecializationPartitions, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      VerifyWarpSpecializationPartitions)

  void runOnOperation() override {
    // The walk result is intentionally unused: failure is reported through
    // signalPassFailure().
    (void)getOperation().walk([&](scf::ForOp loop) -> WalkResult {
      // Loops without the stages attribute are not partitioned; skip them.
      bool isPartitioned = loop->hasAttr(kPartitionStagesAttrName);
      if (!isPartitioned || !failed(verifyPartitionedLoop(loop)))
        return WalkResult::advance();
      signalPassFailure();
      return WalkResult::interrupt();
    });
  }
};

struct AutomaticWarpSpecialization
: triton::gpu::impl::TritonGPUAutomaticWarpSpecializationBase<
AutomaticWarpSpecialization> {
Expand Down Expand Up @@ -57,20 +77,38 @@ void multiBufferTMADescriptors(ModuleOp mod, int numStages) {
}
}

// Removes the partition, partition-outputs, partition-stages and
// warp-specialize-tag attributes from every operation nested under `mod`.
void clearInternalWarpSpecializationAttrs(ModuleOp mod) {
  mod.walk([](Operation *op) {
    const StringRef internalAttrNames[] = {
        StringRef(kPartitionAttrName),
        StringRef(kPartitionOutputsAttrName),
        StringRef(kPartitionStagesAttrName),
        StringRef(kWarpSpecializeTagAttrName),
    };
    for (StringRef attrName : internalAttrNames)
      op->removeAttr(attrName);
  });
}

// Factory for the partition verifier pass defined in this file.
std::unique_ptr<Pass> createVerifyWarpSpecializationPartitionsPass() {
  std::unique_ptr<Pass> verifier =
      std::make_unique<VerifyWarpSpecializationPartitions>();
  return verifier;
}

} // namespace

void AutomaticWarpSpecialization::runOnOperation() {
OpPassManager pm;
pm.addPass(createTritonGPUPartitionScheduling());
pm.addPass(createNVWSHoistTmemStore());
pm.addPass(createNVWSInsertAref());
pm.addPass(createNVWSInsertTmemAref());
auto addPassWithPartitionVerifier = [&](std::unique_ptr<Pass> pass) {
pm.addPass(std::move(pass));
pm.addPass(createVerifyWarpSpecializationPartitionsPass());
};

addPassWithPartitionVerifier(createTritonGPUPartitionScheduling());
addPassWithPartitionVerifier(createNVWSHoistTmemStore());
addPassWithPartitionVerifier(createNVWSInsertAref());
addPassWithPartitionVerifier(createNVWSInsertTmemAref());
// `int-range-optimizations` and SCCP are good at cleaning up loop arithmetic.
// FIXME: Re-enable integer range analysis once it is fixed.
// pm.addPass(arith::createIntRangeOptimizationsPass());
pm.addPass(createSCCPPass());
pm.addPass(createCSEPass());
pm.addPass(createNVWSLowerAref({numStages}));
addPassWithPartitionVerifier(createSCCPPass());
addPassWithPartitionVerifier(createCSEPass());
addPassWithPartitionVerifier(createNVWSLowerAref({numStages}));
pm.addPass(createTritonGPUPartitionLoops());
pm.addPass(createNVWSLowerWarpGroup());
pm.addPass(createTritonGPUScheduleLoops());
Expand All @@ -80,4 +118,5 @@ void AutomaticWarpSpecialization::runOnOperation() {
// Multi-buffer TMA descriptors. We cannot rely on SWP to do it, to support
// desc updates in nested loops.
multiBufferTMADescriptors(getOperation(), numStages);
clearInternalWarpSpecializationAttrs(getOperation());
}
Loading
Loading