Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ def TMALoadLikeOpInterface : OpInterface<"TMALoadLikeOpInterface", [TMAOpInterfa
/*retType=*/"::mlir::Value",
/*methodName=*/"getPred",
/*args=*/(ins)>,
InterfaceMethod<
/*desc=*/"Return true if this load uses multicast",
/*retType=*/"bool",
/*methodName=*/"getMulticast",
/*args=*/(ins)>,
];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,9 @@ SmallVector<uint16_t> getTensorCoreBarrierBroadcastMasks(Operation *op) {
Value getBarrierRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op);

Value getMemEffectRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) {
if (auto copyOp = dyn_cast<ttng::AsyncTMACopyGlobalToLocalOp>(op)) {
if (copyOp.getMulticast())
return getMulticastRecipientCTAs(b, copyOp.getResult());
if (auto tmaLoad = dyn_cast<ttng::TMALoadLikeOpInterface>(op)) {
if (tmaLoad.getMulticast())
return getMulticastRecipientCTAs(b, tmaLoad.getResult());
return currentCTAMask(b);
}
if (isTensorCoreOp(op))
Expand All @@ -272,14 +272,12 @@ Value getBarrierRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) {
return getLeaderCTA(b, arriveOp.getAlloc());
if (auto arriveOp = dyn_cast<ttng::AsyncCopyMbarrierArriveOp>(op))
return getLeaderCTA(b, arriveOp.getBarrier());
if (auto copyOp = dyn_cast<ttng::AsyncTMACopyGlobalToLocalOp>(op)) {
if (copyOp.getMulticast())
return getMulticastBarrierRecipientCTAs(b, copyOp.getResult(),
copyOp.getBarrier());
return getLeaderCTA(b, copyOp.getBarrier());
}
if (auto tmaLoad = dyn_cast<ttng::TMALoadLikeOpInterface>(op))
if (auto tmaLoad = dyn_cast<ttng::TMALoadLikeOpInterface>(op)) {
if (tmaLoad.getMulticast())
return getMulticastBarrierRecipientCTAs(b, tmaLoad.getResult(),
tmaLoad.getBarrier());
return getLeaderCTA(b, tmaLoad.getBarrier());
}

if (isTensorCoreOp(op))
return getRecipientCTAsForBroadcastMasks(
Expand Down
33 changes: 17 additions & 16 deletions lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ Value getLeaderCTAPredicate(ImplicitLocOpBuilder &b, uint32_t broadcastMask) {
arith::ConstantIntOp::create(b, 0, 32));
}

uint32_t getBlockBroadcastMask(Type type) {
auto memDescTy = cast<ttg::MemDescType>(type);
auto kBlock = StringAttr::get(type.getContext(), "block");
return toLinearLayout(memDescTy).getFreeVariableMasks().lookup(kBlock);
}

} // namespace

class NVIDIAConSanHooks : public tti::ConSanTargetHooks {
Expand Down Expand Up @@ -89,13 +95,12 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks {
mask = getBarrierMask(waitOp.getAlloc());
if (auto invalOp = dyn_cast<ttng::InvalBarrierOp>(op))
mask = getBarrierMask(invalOp.getAlloc());
if (auto copyOp = dyn_cast<ttng::AsyncTMACopyGlobalToLocalOp>(op)) {
if (copyOp.getMulticast()) {
auto dstTy = cast<ttg::MemDescType>(copyOp.getResult().getType());
auto kBlock = StringAttr::get(op->getContext(), "block");
mask = toLinearLayout(dstTy).getFreeVariableMasks().lookup(kBlock);
}
if (auto loadOp = dyn_cast<ttng::TMALoadLikeOpInterface>(op)) {
if (loadOp.getMulticast())
mask = getBlockBroadcastMask(loadOp.getResult().getType());
}
if (auto storeOp = dyn_cast<ttng::TMAStoreLikeOpInterface>(op))
mask = getBlockBroadcastMask(storeOp.getSrc().getType());

// In 2CTA tcgen05 and tmem_copy, only the even CTA in each (i, i^1) pair
// issues the op.
Expand Down Expand Up @@ -205,16 +210,12 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks {
info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier;
info->pred = loadOp.getPred();
int txCount = tti::getMemDescLength(loadOp.getResult());
if (auto copyOp = dyn_cast<ttng::AsyncTMACopyGlobalToLocalOp>(op);
copyOp && copyOp.getMulticast()) {
auto resultTy = cast<ttg::MemDescType>(loadOp.getResult().getType());
auto barrierTy = cast<ttg::MemDescType>(loadOp.getBarrier().getType());
auto kBlock = StringAttr::get(op->getContext(), "block");
uint16_t resultMask =
toLinearLayout(resultTy).getFreeVariableMasks().lookup(kBlock);
uint16_t barrierMask =
toLinearLayout(barrierTy).getFreeVariableMasks().lookup(kBlock);
uint16_t collapsedMask = resultMask & barrierMask;
if (loadOp.getMulticast()) {
uint32_t resultMask =
getBlockBroadcastMask(loadOp.getResult().getType());
uint32_t barrierMask =
getBlockBroadcastMask(loadOp.getBarrier().getType());
uint32_t collapsedMask = resultMask & barrierMask;
for (; collapsedMask; collapsedMask &= collapsedMask - 1)
txCount *= 2;
}
Expand Down
62 changes: 61 additions & 1 deletion test/TritonGPU/consan.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,40 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shar

// -----

#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32, CGALayout = [[0, 0]]}>
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}>
#smem = #ttg.shared_memory
#offset_parent = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0], CGALayout = [[0, 0]]}>
#offsets = #ttg.slice<{dim = 0, parent = #offset_parent}>
module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, "ttng.two-ctas" = true, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
// CHECK-LABEL: @multicast_gather_two_cta_tx_count
tt.func public @multicast_gather_two_cta_tx_count(%desc: !tt.tensordesc<1x32xf32, #shared>) {
%true = arith.constant true
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%x_offsets = arith.constant dense<0> : tensor<32xi32, #offsets>
%bar = ttg.local_alloc {allocation.offset = 65536 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable>
%result = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable>
// CHECK: scf.for
scf.for %i = %c0 to %c2 step %c1 {
// CHECK: arith.constant 8192 : i64
// CHECK: tt.call @__triton_consan_verify_barrier_arrive
// CHECK: ttng.barrier_expect
ttng.barrier_expect %bar, 4096, %true : !ttg.memdesc<1xi64, #shared1, #smem, mutable>
// CHECK: arith.constant -8192 : i64
// CHECK: tt.call @__triton_consan_verify_barrier_arrive
// CHECK: ttng.async_tma_gather
ttng.async_tma_gather %desc[%x_offsets, %c0_i32] %result, %bar, %true {multicast} : !tt.tensordesc<1x32xf32, #shared>, tensor<32xi32, #offsets>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<32x32xf32, #shared, #smem, mutable>, i1
}
tt.return
}
}

// -----

#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
#smem = #ttg.shared_memory
Expand Down Expand Up @@ -479,9 +513,35 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
ttng.warp_group_dot %shmem, %shmem, %acc : !ttg.memdesc<128x128xf16, #shared, #smem, mutable> * !ttg.memdesc<128x128xf16, #shared, #smem, mutable> -> tensor<128x128xf16, #mma>
// CHECK: ttng.warp_group_dot

// CHECK: tt.call @__triton_consan_verify_write_visibility
// CHECK: tt.call @__triton_consan_check_outstanding_commits
// CHECK: tt.call @__triton_consan_stage_access_for_commit
// CHECK: tt.call @__triton_consan_commit_accesses
ttng.async_tma_scatter %arg0[%x_offsets, %c0_i32] %0 : !tt.tensordesc<1x32xf32, #shared>, tensor<32xi32>, i32, !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
tt.return
}
}

// -----

#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
#smem = #ttg.shared_memory
#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
// CHECK-LABEL: @async_tma_reduce
tt.func public @async_tma_reduce(%arg0: !tt.tensordesc<32x32xf32, #shared>, %ptr: tensor<128x128x!tt.ptr<f16>, #blocked>, %acc: tensor<128x128xf16, #mma>) {
%c0_i32 = arith.constant 0 : i32
%0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
%shmem = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
ttg.async_copy_global_to_local %ptr, %shmem : tensor<128x128x!tt.ptr<f16>, #blocked> -> <128x128xf16, #shared, #smem, mutable>

// CHECK: tt.call @__triton_consan_verify_write_visibility
// CHECK: tt.call @__triton_consan_check_outstanding_commits
ttng.async_tma_scatter %arg0[%x_offsets, %c0_i32] %0 : !tt.tensordesc<1x32xf32, #shared>, tensor<32xi32>, i32, !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
// CHECK: tt.call @__triton_consan_stage_access_for_commit
// CHECK: tt.call @__triton_consan_commit_accesses
ttng.async_tma_reduce add, %arg0[%c0_i32, %c0_i32] %0 : !tt.tensordesc<32x32xf32, #shared>, !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
tt.return
}
}
Expand Down
Loading