Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions test/TritonGPU/amd/amd-canonicalize-pointers-different-bases.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// RUN: triton-opt %s -split-input-file -tritonamdgpu-canonicalize-pointers -canonicalize | FileCheck %s

// CHECK-LABEL: tt.func @scf_if_different_bases
// CHECK: [[BASE:%.*]] = arith.select %arg2, %arg0, %arg1 : !tt.ptr<f32>
// CHECK: [[OFFSET:%.*]] = arith.select %arg2, %c16_i32, %c32_i32 : i32
// CHECK: [[PTR:%.*]] = tt.addptr [[BASE]], [[OFFSET]]
// CHECK: tt.load [[PTR]]
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func @scf_if_different_bases(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg2: i1) -> f32 {
%c16_i32 = arith.constant 16 : i32
%c32_i32 = arith.constant 32 : i32
%0 = scf.if %arg2 -> (!tt.ptr<f32>) {
%2 = tt.addptr %arg0, %c16_i32 : !tt.ptr<f32>, i32
scf.yield %2 : !tt.ptr<f32>
} else {
%2 = tt.addptr %arg1, %c32_i32 : !tt.ptr<f32>, i32
scf.yield %2 : !tt.ptr<f32>
}
%1 = tt.load %0 : !tt.ptr<f32>
tt.return %1 : f32
}
}

// -----

// CHECK-LABEL: tt.func @select_different_bases
// CHECK: [[BASE:%.*]] = arith.select %arg2, %arg0, %arg1 : !tt.ptr<f32>
// CHECK: [[OFFSET:%.*]] = arith.select %arg2, %c16_i32, %c32_i32 : i32
// CHECK: [[PTR:%.*]] = tt.addptr [[BASE]], [[OFFSET]]
// CHECK: tt.load [[PTR]]
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func @select_different_bases(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg2: i1) -> f32 {
%c16_i32 = arith.constant 16 : i32
%c32_i32 = arith.constant 32 : i32
%2 = tt.addptr %arg0, %c16_i32 : !tt.ptr<f32>, i32
%3 = tt.addptr %arg1, %c32_i32 : !tt.ptr<f32>, i32
%4 = arith.select %arg2, %2, %3 : !tt.ptr<f32>
%5 = tt.load %4 : !tt.ptr<f32>
tt.return %5 : f32
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -462,18 +462,29 @@ struct FatPointers {

friend bool operator==(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) {
return lhs.canNarrow == rhs.canNarrow &&
lhs.attributes == rhs.attributes &&
lhs.smallTensorBase == rhs.smallTensorBase;
lhs.isSmallTensor == rhs.isSmallTensor &&
lhs.attributes.getArrayRef() == rhs.attributes.getArrayRef();
}

friend bool operator!=(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) {
return !(lhs == rhs);
}

llvm::DenseMap<StringRef, Attribute> attributes;
// If the fat-pointer points to somewhere in a small-tensor, keep track the
// base of the tensor.
Value smallTensorBase;
static FatPtrAttrs intersect(const FatPtrAttrs &lhs,
const FatPtrAttrs &rhs) {
FatPtrAttrs result;
result.canNarrow = lhs.canNarrow && rhs.canNarrow;
result.isSmallTensor = lhs.isSmallTensor && rhs.isSmallTensor;
for (const auto &attr : lhs.attributes) {
auto it = rhs.attributes.find(attr.first);
if (it != rhs.attributes.end() && it->second == attr.second)
result.attributes[attr.first] = attr.second;
}
return result;
}

llvm::SmallMapVector<StringRef, Attribute, 2> attributes;
bool isSmallTensor = false;
bool canNarrow = false;
};

Expand Down Expand Up @@ -563,7 +574,7 @@ Value createTensorPointer(RewriterBase &rewriter, Value basePtr, Value offset,
auto addPtrOp =
tt::AddPtrOp::create(rewriter, loc, basePtr.getType(), basePtr, offset);
for (const auto &attribute : fatPtrAttrs.attributes)
addPtrOp->setAttr(attribute.getFirst(), attribute.getSecond());
addPtrOp->setAttr(attribute.first, attribute.second);
return addPtrOp.getResult();
}

Expand All @@ -585,7 +596,7 @@ Value createTensorPointer(RewriterBase &rewriter, Value basePtr, Value offset,
tt::AddPtrOp::create(rewriter, loc, tensorPtrType, tensorPtr, offset);

for (const auto &attribute : fatPtrAttrs.attributes)
addPtrOp->setAttr(attribute.getFirst(), attribute.getSecond());
addPtrOp->setAttr(attribute.first, attribute.second);
return addPtrOp.getResult();
}

Expand Down Expand Up @@ -745,7 +756,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
RewriterBase::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(addPtrOp);

if (fatPtrs.at({fatPtrBase, fatPtrOffset}).smallTensorBase)
if (fatPtrs.at({fatPtrBase, fatPtrOffset}).isSmallTensor)
return rewriteSmallTensorPtr(addPtrOp, adaptor, rewriter);

// Query all discardable attributes that we want to preserve
Expand Down Expand Up @@ -861,7 +872,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
const auto &oldAttr = fatPtrs.at({fatPtrBase, fatPtrOffset});

LDBG("smal-tensor addPtr: " << addPtrOp);
LDBG(" - with tensor-base: " << oldAttr.smallTensorBase);
LDBG(" - isSmallTensor: " << oldAttr.isSmallTensor);
LDBG(" - with originl offset: " << origOffset);
LDBG(" - fatPtr base: " << fatPtrBase);
LDBG(" - fatPtr offst: " << fatPtrOffset);
Expand Down Expand Up @@ -1351,17 +1362,6 @@ class ConvertArithSelectOp
// select of base and offset
ValueRange fatPtrFalse = adaptor.getFalseValue();
ValueRange fatPtrTrue = adaptor.getTrueValue();
// Simple case of a scalar select: update the base pointer
if (!isa<RankedTensorType>(selectOp.getType())) {
auto newSelectOp = arith::SelectOp::create(
rewriter, selectOp.getLoc(), selectOp.getType(),
selectOp.getCondition(), fatPtrTrue[0], selectOp.getFalseValue());
rewriter.replaceOpWithMultiple(selectOp, {{newSelectOp, fatPtrTrue[1]}});
fatPtrs[{newSelectOp, /*fatPtrOffset*/ fatPtrTrue[1]}] =
fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]});
return success();
}

// Rewrite to select(fatBaseT, fatBaseF) and select(fatOffsetT, fatOffsetF)
auto newBase = arith::SelectOp::create(rewriter, selectOp.getLoc(),
selectOp.getCondition(),
Expand All @@ -1370,12 +1370,10 @@ class ConvertArithSelectOp
selectOp.getCondition(),
fatPtrTrue[1], fatPtrFalse[1]);

assert((fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}) ==
fatPtrs.at({fatPtrFalse[0], fatPtrFalse[1]})) &&
"expected can narrow to be the same for both fatPtrT and fatPtrF");

rewriter.replaceOpWithMultiple(selectOp, {{newBase, newOffset}});
fatPtrs[{newBase, newOffset}] = fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]});
fatPtrs[{newBase, newOffset}] = FatPointers::FatPtrAttrs::intersect(
fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}),
fatPtrs.at({fatPtrFalse[0], fatPtrFalse[1]}));

return success();
}
Expand Down Expand Up @@ -1423,14 +1421,6 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern<scf::IfOp> {
assert(i < ifOp.thenYield().getNumOperands() &&
i + 1 < ifOp.thenYield().getNumOperands() &&
"expected idx to be within bounds of IfOp's results");
Value thenFatPtrBase = ifOp.thenYield().getOperand(i);
Value thenFatPtrOffset = ifOp.thenYield().getOperand(i + 1);
Value elseFatPtrBase = ifOp.elseYield().getOperand(i);
Value elseFatPtrOffset = ifOp.elseYield().getOperand(i + 1);
assert((fatPtrs.at({thenFatPtrBase, thenFatPtrOffset}) ==
fatPtrs.at({elseFatPtrBase, elseFatPtrOffset})) &&
"expected then fat ptr canNarrow and else fat ptr canNarrow "
"to be equal");
}
}
}
Expand All @@ -1456,8 +1446,17 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern<scf::IfOp> {
for (int64_t idx : yieldPtrOffsets) {
Value thenFatPtrBase = newIfOp.thenYield().getOperand(idx);
Value thenFatPtrOffset = newIfOp.thenYield().getOperand(idx + 1);
fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}] =
fatPtrs.at({thenFatPtrBase, thenFatPtrOffset});
const auto &thenAttrs = fatPtrs.at({thenFatPtrBase, thenFatPtrOffset});
if (withElseRegion) {
Value elseFatPtrBase = newIfOp.elseYield().getOperand(idx);
Value elseFatPtrOffset = newIfOp.elseYield().getOperand(idx + 1);
const auto &elseAttrs = fatPtrs.at({elseFatPtrBase, elseFatPtrOffset});
fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}] =
FatPointers::FatPtrAttrs::intersect(thenAttrs, elseAttrs);
} else {
fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}] =
thenAttrs;
}
}

ResultRange results = newIfOp.getResults();
Expand Down Expand Up @@ -1697,7 +1696,7 @@ struct InitFuncPtrArgs : OpRewritePattern<tt::FuncOp> {
rewriter.replaceAllUsesExcept(arg, dummyCast.getResult(0), dummyCast);
fatPtrs[{arg, zeroOffset}].canNarrow = true;
if (bitness != 64)
fatPtrs[{arg, zeroOffset}].smallTensorBase = arg;
fatPtrs[{arg, zeroOffset}].isSmallTensor = true;
}

newOp->setDiscardableAttr(kInitFuncArgsRewritten, rewriter.getUnitAttr());
Expand Down
Loading