Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/llvm-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ jobs:
run: |
# if this step crashes, it can leave behind a stale docker container
docker container prune -f
docker rmi -f $(docker images -q)

docker build --tag llvm-build --build-arg llvm_dir=llvm-project \
-f llvm-build/.github/workflows/llvm-build/Dockerfile .
Expand Down
2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
49af6502c6dcb4a7f7520178bd14df396f78240c
5e5a22caf88ac1ccfa8dc5720295fdeba0ad9372
19 changes: 9 additions & 10 deletions include/triton/Dialect/NVGPU/IR/NVGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType

def I8Ptr_global : LLVM_IntPtrBase<8, 1>;
def I8Ptr_shared : LLVM_IntPtrBase<8, 3>;
def I64Ptr_shared : LLVM_IntPtrBase<64, 3>;
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;

class NVGPU_Op<string mnemonic, list<Trait> traits = []> :
LLVM_OpBase<NVGPU_Dialect, mnemonic, traits>;
Expand All @@ -55,7 +54,7 @@ def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group",
}

def NVGPU_MBarrierInitOp : NVGPU_Op<"mbarrier_init", [MemoryEffects<[MemWrite]>]> {
let arguments = (ins I64Ptr_shared:$mbarrier, I1:$pred, I32Attr:$count);
let arguments = (ins LLVM_PointerShared:$mbarrier, I1:$pred, I32Attr:$count);
let assemblyFormat = "$mbarrier `,` $pred attr-dict `:` type($mbarrier)";
}

Expand All @@ -71,12 +70,12 @@ def MBarrier_ArriveTypeAttr : I32EnumAttr<"MBarriveType",
}

def NVGPU_MBarrierArriveOp : NVGPU_Op<"mbarrier_arrive", []> {
let arguments = (ins I64Ptr_shared:$mbarrier, I1:$pred, Optional<I32>:$ctaId, MBarrier_ArriveTypeAttr:$arriveType, DefaultValuedAttr<I32Attr, "0">:$txCount);
let arguments = (ins LLVM_PointerShared:$mbarrier, I1:$pred, Optional<I32>:$ctaId, MBarrier_ArriveTypeAttr:$arriveType, DefaultValuedAttr<I32Attr, "0">:$txCount);
let assemblyFormat = "$mbarrier `,` $pred (`,` $ctaId^)? attr-dict `:` type($mbarrier)";
}

def NVGPU_MBarrierWaitOp : NVGPU_Op<"mbarrier_wait", []> {
let arguments = (ins I64Ptr_shared:$mbarrier, I1:$phase);
let arguments = (ins LLVM_PointerShared:$mbarrier, I1:$phase);
let assemblyFormat = "$mbarrier `,` $phase attr-dict `:` type(operands)";
}

Expand Down Expand Up @@ -116,13 +115,13 @@ def NVGPU_WGMMADescCreateOp : NVGPU_Op<"wgmma_desc_create", []> {
}

def NVGPU_TMALoadTiledOp : NVGPU_Op<"tma_load_tiled", [AttrSizedOperandSegments]> {
let arguments = (ins I8Ptr_shared:$dst, I64Ptr_shared:$mbarrier, I8Ptr_global:$tmaDesc, I64:$l2Desc,
let arguments = (ins LLVM_PointerShared:$dst, LLVM_PointerShared:$mbarrier, LLVM_PointerGlobal:$tmaDesc, I64:$l2Desc,
I1:$pred, Variadic<I32>:$coords, Optional<I16>:$mcastMask);
let assemblyFormat = "operands attr-dict `:` type(operands)";
}

def NVGPU_TMALoadIm2colOp : NVGPU_Op<"tma_load_im2col", []> {
let arguments = (ins I8Ptr_shared:$dst, I64Ptr_shared:$mbarrier, I8Ptr_global:$tmaDesc, I64:$l2Desc, LLVM_AnyStruct:$im2colOffsets, I1:$pred, Variadic<I32>:$coords, I16Attr:$mcastMask);
let arguments = (ins LLVM_PointerShared:$dst, LLVM_PointerShared:$mbarrier, LLVM_PointerGlobal:$tmaDesc, I64:$l2Desc, LLVM_AnyStruct:$im2colOffsets, I1:$pred, Variadic<I32>:$coords, I16Attr:$mcastMask);
let assemblyFormat = "operands attr-dict `:` type(operands)";
}

Expand Down Expand Up @@ -217,12 +216,12 @@ def NVGPU_ClusterWaitOp : NVGPU_Op<"cluster_wait", []> {
}

def NVGPU_TMAStoreTiledOp : NVGPU_Op<"tma_store_tiled", [MemoryEffects<[MemWrite]>]> {
let arguments = (ins I8Ptr_global:$tmaDesc, I8Ptr_shared:$src, I1:$pred, Variadic<I32>:$coords);
let arguments = (ins LLVM_PointerGlobal:$tmaDesc, LLVM_PointerShared:$src, I1:$pred, Variadic<I32>:$coords);
let assemblyFormat = "operands attr-dict `:` type(operands)";
}

def NVGPU_StoreMatrixOp : NVGPU_Op<"stmatrix", [MemoryEffects<[MemWrite]>]> {
let arguments = (ins I8Ptr_shared:$addr, Variadic<I32>:$datas);
let arguments = (ins LLVM_PointerShared:$addr, Variadic<I32>:$datas);
let assemblyFormat = "operands attr-dict `:` type(operands)";
}

Expand Down
25 changes: 15 additions & 10 deletions lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,16 @@ struct AllocMBarrierOpConversion : public ConvertTritonGPUOpToLLVMPattern<
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getResult());
auto resultTy = op.getType();
auto resultTensorTy = resultTy.dyn_cast<RankedTensorType>();
Type elemPtrTy;
Type elemPtrTy = ptr_ty(rewriter.getContext(), 3);
Type llvmElemTy;
if (resultTensorTy) {
auto llvmElemTy =
llvmElemTy =
getTypeConverter()->convertType(resultTensorTy.getElementType());
elemPtrTy = ptr_ty(llvmElemTy, 3);
} else {
elemPtrTy = getTypeConverter()->convertType(resultTy);
auto resultPtrTy = resultTy.dyn_cast<triton::PointerType>();
assert(resultPtrTy && "Unknown type for AllocMBarrierOp");
llvmElemTy =
getTypeConverter()->convertType(resultPtrTy.getPointeeType());
}
smemBase = bitcast(smemBase, elemPtrTy);
auto threadId = getThreadId(rewriter, loc);
Expand All @@ -85,14 +88,16 @@ struct AllocMBarrierOpConversion : public ConvertTritonGPUOpToLLVMPattern<
for (int i = 0; i < numMBarriers; ++i) {
Value smem = smemBase;
if (i > 0) {
smem = gep(elemPtrTy, smem, i32_val(i));
smem = gep(elemPtrTy, llvmElemTy, smem, i32_val(i));
}
rewriter.create<triton::nvgpu::MBarrierInitOp>(loc, smem, pred,
op.getCount());
}
if (resultTensorTy) {
auto smemObj = SharedMemoryObject(smemBase, resultTensorTy.getShape(),
{0}, loc, rewriter);
auto llvmElemTy =
getTypeConverter()->convertType(resultTensorTy.getElementType());
auto smemObj = SharedMemoryObject(
smemBase, llvmElemTy, resultTensorTy.getShape(), {0}, loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
} else {
Expand Down Expand Up @@ -164,11 +169,11 @@ struct ExtractMBarrierOpConversion
op.getTensor().getType().cast<RankedTensorType>().getElementType();
auto tensorStruct = adaptor.getTensor();
auto index = adaptor.getIndex();
auto ptrTy =
LLVM::LLVMPointerType::get(getTypeConverter()->convertType(elemTy), 3);
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
auto basePtr =
extract_val(ptrTy, tensorStruct, rewriter.getDenseI64ArrayAttr(0));
Value result = gep(ptrTy, basePtr, index);
Value result =
gep(ptrTy, getTypeConverter()->convertType(elemTy), basePtr, index);
rewriter.replaceOp(op, result);
return success();
}
Expand Down
54 changes: 28 additions & 26 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,10 @@ struct ConvertLayoutOpConversion
shapePerCTA);
Value offset = linearize(rewriter, loc, multiDimOffsetWrapped,
paddedRepShape, outOrd);
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto elemPtrTy = ptr_ty(rewriter.getContext(), 3);
Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
ptr = bitcast(ptr, ptr_ty(vecTy, 3));
ptr = bitcast(ptr, ptr_ty(rewriter.getContext(), 3));
if (stNotRd) {
Value valVec = undef(vecTy);
for (unsigned v = 0; v < vec; ++v) {
Expand All @@ -326,7 +326,7 @@ struct ConvertLayoutOpConversion
}
store(valVec, ptr);
} else {
Value valVec = load(ptr);
Value valVec = load(vecTy, ptr);
for (unsigned v = 0; v < vec; ++v) {
Value currVal = extract_element(llvmElemTy, valVec, i32_val(v));
if (isInt1)
Expand Down Expand Up @@ -421,10 +421,10 @@ struct ConvertLayoutOpConversion
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
auto coord = coord2valT[elemId].first;
Value offset = linearize(rewriter, loc, coord, paddedRepShape, outOrd);
auto elemPtrTy = ptr_ty(elemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto elemPtrTy = ptr_ty(rewriter.getContext(), 3);
Value ptr = gep(elemPtrTy, elemTy, smemBase, offset);
auto vecTy = vec_ty(elemTy, vec);
ptr = bitcast(ptr, ptr_ty(vecTy, 3));
ptr = bitcast(ptr, ptr_ty(rewriter.getContext(), 3));
if (stNotRd) {
Value valVec = undef(vecTy);
for (unsigned v = 0; v < vec; ++v) {
Expand All @@ -433,7 +433,7 @@ struct ConvertLayoutOpConversion
}
store(valVec, ptr);
} else {
Value valVec = load(ptr);
Value valVec = load(vecTy, ptr);
for (unsigned v = 0; v < vec; ++v) {
Value currVal = extract_element(elemTy, valVec, i32_val(v));
vals[elemId + v] = currVal;
Expand All @@ -460,7 +460,7 @@ struct ConvertLayoutOpConversion
unsigned rank = srcShapePerCTA.size();

auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
auto elemPtrTy = ptr_ty(rewriter.getContext(), 3);

Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
smemBase = bitcast(smemBase, elemPtrTy);
Expand All @@ -478,7 +478,7 @@ struct ConvertLayoutOpConversion

for (unsigned i = 0; i < inIndices.size(); ++i) {
Value offset = linearize(rewriter, loc, inIndices[i], smemShape);
Value ptr = gep(elemPtrTy, smemBase, offset);
Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset);
store(inVals[i], ptr);
}
}
Expand Down Expand Up @@ -511,8 +511,8 @@ struct ConvertLayoutOpConversion
linearize(rewriter, loc, multiDimCTAId, srcCTAsPerCGA, srcCTAOrder);
Value localOffset = linearize(rewriter, loc, localCoord, smemShape);

Value ptr = gep(elemPtrTy, smemBase, localOffset);
outVals.push_back(load_dsmem(ptr, remoteCTAId));
Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, localOffset);
outVals.push_back(load_dsmem(ptr, remoteCTAId, llvmElemTy));
}

Value result =
Expand Down Expand Up @@ -543,10 +543,8 @@ struct ConvertLayoutOpConversion

if (shouldUseDistSmem(srcLayout, dstLayout))
return lowerDistToDistWithDistSmem(op, adaptor, rewriter);

auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
auto elemPtrTy = ptr_ty(rewriter.getContext(), 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto shape = dstTy.getShape();
unsigned rank = dstTy.getRank();
Expand Down Expand Up @@ -705,8 +703,9 @@ struct ConvertLayoutOpConversion
auto dstLayout = dstTy.getEncoding();
auto inOrd = getOrder(srcSharedLayout);

auto smemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
auto smemObj = getSharedMemoryObjectFromStruct(
loc, adaptor.getSrc(),
getTypeConverter()->convertType(srcTy.getElementType()), rewriter);
auto elemTy = getTypeConverter()->convertType(dstTy.getElementType());

auto srcStrides =
Expand Down Expand Up @@ -744,7 +743,7 @@ struct ConvertLayoutOpConversion
auto outOrd = dstSharedLayout.getOrder();
Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
auto elemPtrTy = ptr_ty(rewriter.getContext(), 3);
smemBase = bitcast(smemBase, elemPtrTy);

int32_t elemSize = elemTy.getIntOrFloatBitWidth();
Expand All @@ -771,8 +770,7 @@ struct ConvertLayoutOpConversion
unsigned leadingDimOffset =
numElemsPerSwizzlingRow * srcShapePerCTA[outOrd[1]];

auto ptrI8SharedTy = LLVM::LLVMPointerType::get(
typeConverter->convertType(rewriter.getI8Type()), 3);
auto ptrSharedTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);

uint32_t rowsPerRep = getShapePerCTATile(mmaLayout)[0];

Expand Down Expand Up @@ -801,7 +799,8 @@ struct ConvertLayoutOpConversion
loc, i32_ty, threadId, rowOfWarp, i32_val(idx), leadingDimOffset,
numElemsPerSwizzlingRow, true);

Value addr = gep(elemPtrTy, smemBase, offset);
Value addr = gep(elemPtrTy, getTypeConverter()->convertType(elemTy),
smemBase, offset);

Value words[4];
for (unsigned i = 0; i < 8; ++i) {
Expand All @@ -812,7 +811,7 @@ struct ConvertLayoutOpConversion
}

rewriter.create<triton::nvgpu::StoreMatrixOp>(
loc, bitcast(addr, ptrI8SharedTy),
loc, bitcast(addr, ptrSharedTy),
ValueRange{bitcast(words[0], i32_ty), bitcast(words[1], i32_ty),
bitcast(words[2], i32_ty), bitcast(words[3], i32_ty)});
}
Expand Down Expand Up @@ -841,8 +840,8 @@ struct ConvertLayoutOpConversion
storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices,
dst, smemBase, elemTy, loc, rewriter);
}
auto smemObj =
SharedMemoryObject(smemBase, dstShapePerCTA, outOrd, loc, rewriter);
auto smemObj = SharedMemoryObject(smemBase, elemTy, dstShapePerCTA, outOrd,
loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
Expand Down Expand Up @@ -1011,8 +1010,11 @@ struct ConvertLayoutOpConversion
Value dst = op.getResult();
bool isMMA = supportMMA(dst, mmaLayout.getVersionMajor());

auto smemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
auto llvmElemTy = getTypeConverter()->convertType(
src.getType().cast<RankedTensorType>().getElementType());

auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
llvmElemTy, rewriter);
Value res;
if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2
res = SharedToDotOperandMMAv2::convertLayout(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,

bool isARow = aOrder[0] == 1;

auto aSmem = getSharedMemoryObjectFromStruct(loc, llA, rewriter);
auto aSmem = getSharedMemoryObjectFromStruct(
loc, llA, typeConverter->convertType(aTensorTy.getElementType()),
rewriter);
Value strideAM = aSmem.strides[0];
Value strideAK = aSmem.strides[1];
Value strideA0 = isARow ? strideAK : strideAM;
Expand Down Expand Up @@ -131,10 +133,10 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
auto elemTy = typeConverter->convertType(
A.getType().cast<RankedTensorType>().getElementType());

Type ptrTy = ptr_ty(elemTy, 3);
Type ptrTy = ptr_ty(rewriter.getContext(), 3);
SmallVector<Value> aPtrs(aNumPtr);
for (int i = 0; i < aNumPtr; ++i)
aPtrs[i] = gep(ptrTy, aSmem.base, aOff[i]);
aPtrs[i] = gep(ptrTy, elemTy, aSmem.base, aOff[i]);

SmallVector<Value> vas;

Expand All @@ -146,8 +148,8 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
for (unsigned mm = 0; mm < mSizePerThread; ++mm) {
Value offset =
add(mul(i32_val(m + mm), strideAM), mul(i32_val(k), strideAK));
Value pa = gep(ptrTy, aPtrs[0], offset);
Value va = load(pa);
Value pa = gep(ptrTy, elemTy, aPtrs[0], offset);
Value va = load(elemTy, pa);
vas.emplace_back(va);
}

Expand All @@ -166,7 +168,9 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,

bool isBRow = bOrder[0] == 1;

auto bSmem = getSharedMemoryObjectFromStruct(loc, llB, rewriter);
auto bSmem = getSharedMemoryObjectFromStruct(
loc, llB, typeConverter->convertType(bTensorTy.getElementType()),
rewriter);
Value strideBN = bSmem.strides[1];
Value strideBK = bSmem.strides[0];
Value strideB0 = isBRow ? strideBN : strideBK;
Expand Down Expand Up @@ -196,10 +200,10 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
auto elemTy = typeConverter->convertType(
B.getType().cast<RankedTensorType>().getElementType());

Type ptrTy = ptr_ty(elemTy, 3);
Type ptrTy = ptr_ty(rewriter.getContext(), 3);
SmallVector<Value> bPtrs(bNumPtr);
for (int i = 0; i < bNumPtr; ++i)
bPtrs[i] = gep(ptrTy, bSmem.base, bOff[i]);
bPtrs[i] = gep(ptrTy, elemTy, bSmem.base, bOff[i]);

SmallVector<Value> vbs;

Expand All @@ -211,8 +215,8 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
Value offset =
add(mul(i32_val(n + nn), strideBN), mul(i32_val(k), strideBK));
Value pb = gep(ptrTy, bPtrs[0], offset);
Value vb = load(pb);
Value pb = gep(ptrTy, elemTy, bPtrs[0], offset);
Value vb = load(elemTy, pb);
vbs.emplace_back(vb);
}

Expand Down
Loading