Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1004,11 +1004,6 @@ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [
"triton::PaddingOption":$padding)>
];

let extraClassDeclaration = [{
ArrayRef<int64_t> getTensorShape() {
return getType().getBlockType().getShape();
}
}];
}

// The following ops, including `call`, `func`, and `return` are copied and modified from
Expand Down
41 changes: 32 additions & 9 deletions include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,48 @@ def TT_TensorDescInterface : TypeInterface<"TensorDescInterface"> {

let methods = [
InterfaceMethod<
/*desc=*/"Returns the block type of the tensor descriptor",
/*desc=*/"Returns the shape of the descriptor block",
/*retType=*/"llvm::ArrayRef<int64_t>",
/*methodName=*/"getShape",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/"Returns the element type of the descriptor block",
/*retType=*/"mlir::Type",
/*methodName=*/"getElementType",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/"Returns the optional shared memory layout encoding",
/*retType=*/"mlir::Attribute",
/*methodName=*/"getSharedLayout",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/"Returns a block tensor type constructed from shape and element type",
/*retType=*/"mlir::RankedTensorType",
/*methodName=*/"getBlockType",
/*args=*/(ins)
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImpl=*/[{
return mlir::RankedTensorType::get($_type.getShape(),
$_type.getElementType());
}]
>,
InterfaceMethod<
/*desc=*/"Returns the block type with signless integer element type",
/*desc=*/"Returns a block tensor type constructed with signless integer element type",
/*retType=*/"mlir::RankedTensorType",
/*methodName=*/"getSignlessBlockType",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImpl=*/[{
auto resTy = $_type.getBlockType();
if (auto intTy = llvm::dyn_cast<mlir::IntegerType>(resTy.getElementType())) {
auto width = resTy.getElementTypeBitWidth();
auto signlessTy = mlir::IntegerType::get($_type.getContext(), width);
resTy = resTy.clone(signlessTy);
auto shape = $_type.getShape();
auto elemTy = $_type.getElementType();
if (auto intTy = llvm::dyn_cast<mlir::IntegerType>(elemTy)) {
auto width = intTy.getWidth();
elemTy = mlir::IntegerType::get($_type.getContext(), width);
}
return resTy;
return mlir::RankedTensorType::get(shape, elemTy);
}]
>,
];
Expand Down
47 changes: 39 additions & 8 deletions include/triton/Dialect/Triton/IR/TritonTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -106,27 +106,58 @@ def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", [TT_TensorDesc
A portable abstraction for TMA descriptors.
This is the base tensor descriptor type for tiled tensor memory access.

Shape and elementType describe the block dimensions and data type.
The optional sharedLayout attribute carries the shared memory encoding
(e.g. swizzle pattern) that is assigned during lowering.

For specialized access patterns like im2col, see TensorDescIm2ColType
in the TritonNvidiaGPU dialect.
}];

let parameters = (ins
"RankedTensorType":$blockType
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType,
OptionalParameter<"Attribute">:$sharedLayout
);

let assemblyFormat = "`<` $blockType `>`";

let builders = [
// Builder from shape + elementType + sharedLayout
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<int64_t>":$shape,
"Type":$elementType,
"Attribute":$sharedLayout
), [{
return $_get(elementType.getContext(), shape, elementType, sharedLayout);
}]>,
// Builder with signedness
TypeBuilder<(ins "RankedTensorType":$blockType, "bool":$isSigned), [{
if (auto intTy = llvm::dyn_cast<IntegerType>(blockType.getElementType())) {
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<int64_t>":$shape,
"Type":$elementType,
"bool":$isSigned
), [{
if (auto intTy = llvm::dyn_cast<IntegerType>(elementType)) {
auto sem = isSigned ? IntegerType::Signed : IntegerType::Unsigned;
auto elemTy = IntegerType::get($_ctxt, intTy.getWidth(), sem);
blockType = blockType.clone(elemTy);
elementType = IntegerType::get(elementType.getContext(), intTy.getWidth(), sem);
}
return Base::get($_ctxt, blockType);
return $_get(elementType.getContext(), shape, elementType, Attribute{});
}]>,
// Builder with signedness and shared layout
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<int64_t>":$shape,
"Type":$elementType,
"Attribute":$sharedLayout,
"bool":$isSigned
), [{
if (auto intTy = llvm::dyn_cast<IntegerType>(elementType)) {
auto sem = isSigned ? IntegerType::Signed : IntegerType::Unsigned;
elementType = IntegerType::get(elementType.getContext(), intTy.getWidth(), sem);
}
return $_get(elementType.getContext(), shape, elementType, sharedLayout);
}]>,
];

let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}

#endif
28 changes: 15 additions & 13 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def TTNG_TensorDescIm2ColType : TTNG_TypeDef<"TensorDescIm2Col", "tensordesc_im2
operations.

Parameters:
- blockType: The shape and element type of the data block being accessed
- shape: The block dimensions
- elementType: The element type of the data block
- sharedLayout: Optional shared memory encoding (swizzle pattern, etc.)

This type implements TensorDescInterface, sharing common operations with
the tiled TensorDescType in the base Triton dialect.
Expand All @@ -62,28 +64,28 @@ def TTNG_TensorDescIm2ColType : TTNG_TypeDef<"TensorDescIm2Col", "tensordesc_im2
}];

let parameters = (ins
"RankedTensorType":$blockType
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType,
OptionalParameter<"Attribute">:$sharedLayout
);

let assemblyFormat = [{
`<` $blockType `>`
}];

let builders = [
// Builder with signedness for integer types
TypeBuilder<(ins
"RankedTensorType":$blockType,
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<int64_t>":$shape,
"Type":$elementType,
"Attribute":$sharedLayout,
"bool":$isSigned
), [{
if (auto intTy = llvm::dyn_cast<IntegerType>(blockType.getElementType())) {
if (auto intTy = llvm::dyn_cast<IntegerType>(elementType)) {
auto sem = isSigned ? IntegerType::Signed : IntegerType::Unsigned;
auto elemTy = IntegerType::get($_ctxt, intTy.getWidth(), sem);
blockType = blockType.clone(elemTy);
elementType = IntegerType::get(elementType.getContext(), intTy.getWidth(), sem);
}
return Base::get($_ctxt, blockType);
return $_get(elementType.getContext(), shape, elementType, sharedLayout);
}]>
];

let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,18 @@ inline SmallVector<int64_t> getTMABlockShape(Attribute encoding,
mmaEnc.getFp4Padded(), mmaEnc.getTransposed(), packedSize, mode);
}

inline SmallVector<int64_t>
getTMABlockShape(RankedTensorType ty, bool packedSize, gpu::TMAMode mode) {
inline SmallVector<int64_t> getTMABlockShape(triton::gpu::MemDescType ty,
bool packedSize,
gpu::TMAMode mode) {
auto shapePerCTA = gpu::getShapePerCTA(ty);
return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize, mode);
}

inline SmallVector<int64_t> getTMABlockShape(triton::gpu::MemDescType ty,
inline SmallVector<int64_t> getTMABlockShape(triton::TensorDescInterface ty,
bool packedSize,
gpu::TMAMode mode) {
auto shapePerCTA = gpu::getShapePerCTA(ty);
return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize, mode);
auto shapePerCTA = gpu::getShapePerCTA(ty.getSharedLayout(), ty.getShape());
return getTMABlockShape(ty.getSharedLayout(), shapePerCTA, packedSize, mode);
}

FailureOr<int> getTMASwizzleMode(Location loc, triton::TensorDescInterface ty);
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonInstrumentToLLVM/GSanToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ struct GSanTensorDescInfoOpConversion
"expected byte-addressable element");
}

unsigned rank = descTy.getBlockType().getRank();
unsigned rank = descTy.getShape().size();
unsigned elemBytes = elemTy.getIntOrFloatBitWidth() / 8;
if (op->getNumResults() != 1 + 2 * rank) {
return rewriter.notifyMatchFailure(
Expand Down
4 changes: 1 addition & 3 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1094,9 +1094,7 @@ void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state,
}
auto elemTy = ptrTy.getPointeeType();
SmallVector<int64_t> blockShape64(blockShape);
auto blockTy = RankedTensorType::get(blockShape64, elemTy);
auto descTy =
TensorDescType::get(builder.getContext(), blockTy, isSignedInteger);
auto descTy = TensorDescType::get(blockShape64, elemTy, isSignedInteger);
auto paddingAttr = PaddingOptionAttr::get(builder.getContext(), padding);
return build(builder, state, descTy, base, shape, strides, paddingAttr);
}
Expand Down
36 changes: 36 additions & 0 deletions lib/Dialect/Triton/IR/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,42 @@ void TritonDialect::registerTypes() {
>();
}

// Parses the textual form of a tensor descriptor type:
//   !tt.tensordesc<128x64xf16>             (no shared layout)
//   !tt.tensordesc<128x64xf16, #shared>    (with shared memory encoding)
// Returns a null Type on syntax errors; diagnostics are emitted by the parser.
Type TensorDescType::parse(AsmParser &parser) {
  if (parser.parseLess())
    return {};

  // Block dimensions, e.g. "128x64x". Dynamic dims ('?') are not allowed.
  SmallVector<int64_t> dims;
  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
    return {};

  Type elemTy;
  if (parser.parseType(elemTy))
    return {};

  // An optional shared memory layout attribute follows a comma.
  Attribute layout;
  if (!parser.parseOptionalComma() && parser.parseAttribute(layout))
    return {};

  if (parser.parseGreater())
    return {};

  return TensorDescType::get(dims, elemTy, layout);
}

// Prints the type body as `<dim0xdim1x...xelemTy[, layout]>`, mirroring the
// grammar accepted by `parse` above.
void TensorDescType::print(AsmPrinter &printer) const {
  printer << "<";
  // Each dimension is followed by the 'x' separator that precedes the
  // element type, matching the standard MLIR dimension-list syntax.
  for (int64_t dim : getShape())
    printer << dim << "x";
  printer << getElementType();
  // The shared layout is optional; omit it entirely when unset.
  if (Attribute layout = getSharedLayout())
    printer << ", " << layout;
  printer << ">";
}

Type PointerType::parse(AsmParser &parser) {
if (parser.parseLess())
return Type();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ struct Descriptor {
};

Descriptor unpackDescriptor(TensorDescType type, ValueRange pack) {
int rank = type.getBlockType().getRank();
int rank = type.getShape().size();
assert(pack.size() == 1 + 2 * static_cast<size_t>(rank) + 2 &&
"Expected tensor descriptors to consist of a pointer, "
"followed by 'rank' shape values and 'rank' stride values, "
Expand Down Expand Up @@ -328,7 +328,7 @@ struct RewriteLoadPattern : OpConversionPattern<triton::DescriptorLoadOp> {
matchAndRewrite(triton::DescriptorLoadOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
const auto blockShape = op.getDesc().getType().getBlockType().getShape();
const auto blockShape = op.getDesc().getType().getShape();
auto descTy = op.getDesc().getType();
auto desc = unpackDescriptor(descTy, adaptor.getDesc());
auto offsets = castToI64(rewriter, op.getIndices());
Expand All @@ -340,7 +340,7 @@ struct RewriteLoadPattern : OpConversionPattern<triton::DescriptorLoadOp> {
newLoad->setAttrs(filterSegmentSizes(op->getAttrs()));

Value result = newLoad.getResult();
if (descTy.getBlockType().getElementType().isF32()) {
if (descTy.getElementType().isF32()) {

auto ifOp = scf::IfOp::create(rewriter, loc, result.getType(),
desc.roundF32ToTF32, /*withElse=*/true);
Expand All @@ -367,7 +367,7 @@ struct RewriteStorePattern : OpConversionPattern<triton::DescriptorStoreOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto descTy = op.getDesc().getType();
const auto blockShape = descTy.getBlockType().getShape();
const auto blockShape = descTy.getShape();
auto desc = unpackDescriptor(descTy, adaptor.getDesc());
auto offsets = castToI64(rewriter, op.getIndices());

Expand Down Expand Up @@ -458,7 +458,7 @@ struct RewriteScatterPattern

std::optional<RMWOp> translateReduceKind(DescriptorReduceKind kind,
TensorDescType ty) {
auto scalarTy = ty.getBlockType().getElementType();
auto scalarTy = ty.getElementType();
switch (kind) {
case DescriptorReduceKind::ADD:
return scalarTy.isInteger() ? RMWOp::ADD : RMWOp::FADD;
Expand Down Expand Up @@ -496,15 +496,15 @@ struct RewriteReducePattern : OpConversionPattern<triton::DescriptorReduceOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto descTy = op.getDesc().getType();
const auto blockShape = descTy.getBlockType().getShape();
const auto blockShape = descTy.getShape();
auto desc = unpackDescriptor(descTy, adaptor.getDesc());
auto offsets = castToI64(rewriter, op.getIndices());
auto rmwOp = translateReduceKind(op.getKind(), descTy);
if (!rmwOp) {
std::string msgstring;
llvm::raw_string_ostream msg(msgstring);
msg << "Cannot fallback on descriptor atomic op, unsupported for type "
<< descTy.getBlockType().getElementType();
<< descTy.getElementType();
return op->emitError(msgstring);
}

Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4444,6 +4444,8 @@ std::optional<int> triton::gpu::getWarpSpecializeTag(Operation *op) {
}

PaddedSharedEncodingAttr triton::gpu::getPaddedEncoding(Attribute encoding) {
if (!encoding)
return nullptr;
if (auto padded = dyn_cast<PaddedSharedEncodingAttr>(encoding))
return padded;
if (auto partitioned = dyn_cast<PartitionedSharedEncodingAttr>(encoding))
Expand Down
12 changes: 6 additions & 6 deletions lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ TensorDescType getTensorDescTypeWithEncoding(Operation *op,
Attribute encoding) {
auto sharedEnc = cast<SharedEncodingTrait>(encoding);
encoding = updateEncodingForShape(op, sharedEnc, existingTy);
auto blockTy = existingTy.cloneWithEncoding(encoding);
return TensorDescType::get(existingTy.getContext(), blockTy);
return TensorDescType::get(existingTy.getShape(), existingTy.getElementType(),
encoding);
}

struct UseInfo {
Expand Down Expand Up @@ -283,7 +283,7 @@ AssignDescriptorMemoryLayouts::getUseInfo(Operation *op) {
: load.getType().getEncoding();
info.cgaLayout = getCGALayout(encoding);
auto shape = load.getResult().getType().getShape();
auto rank = load.getDesc().getType().getBlockType().getRank();
auto rank = load.getDesc().getType().getShape().size();
info.shape = expandToRank(shape, rank);
return info;
}
Expand All @@ -294,7 +294,7 @@ AssignDescriptorMemoryLayouts::getUseInfo(Operation *op) {
: gather.getType().getEncoding();
info.cgaLayout = getCGALayout(encoding);
auto shape = gather.getResult().getType().getShape();
auto rank = gather.getDesc().getType().getBlockType().getRank();
auto rank = gather.getDesc().getType().getShape().size();
info.shape = expandToRank(shape, rank);
return info;
}
Expand All @@ -303,7 +303,7 @@ AssignDescriptorMemoryLayouts::getUseInfo(Operation *op) {
auto encoding = store.getSrc().getType().getEncoding();
info.cgaLayout = getCGALayout(encoding);
auto shape = store.getSrc().getType().getShape();
auto rank = store.getDesc().getType().getBlockType().getRank();
auto rank = store.getDesc().getType().getShape().size();
info.shape = expandToRank(shape, rank);
return info;
}
Expand Down Expand Up @@ -353,7 +353,7 @@ void AssignDescriptorMemoryLayouts::runOnFunction(FuncOp &func) {
auto itr = valueToEncodingInfo.find(typedVal);
if (itr != valueToEncodingInfo.end())
info = combineEncodings(*itr->second, info,
typedVal.getType().getBlockType().getRank());
typedVal.getType().getShape().size());
}

auto einfo = internEncoding(encodings, info);
Expand Down
Loading
Loading