Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,37 @@ def TTNG_AsyncCopyMbarrierArriveOp : TTNG_Op<"async_copy_mbarrier_arrive", [
}


// Asynchronous store of a register tensor into shared memory. The op declares
// the `getBarrier` method of MBarrierOpInterface so that generic barrier
// handling can find the mbarrier operand it uses.
def TTNG_AsyncSharedStoreOp : TTNG_Op<"async_shared_store", [
DeclareOpInterfaceMethods<MBarrierOpInterface, ["getBarrier"]>]> {
let summary = "store a distributed tensor to shared memory asynchronously";

let description = [{
Store a distributed tensor into shared memory using PTX st.async.shared.
The store completion decrements the transaction count of `mbarrier`.
This op requires a CTA cluster with at least two CTAs.
}];

// Operands: the source tensor, the destination memory descriptor and the
// mbarrier memory descriptor. Both descriptors are annotated as
// shared-memory writes.
let arguments = (ins
TT_Tensor:$src,
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$dst,
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$mbarrier
);

// Textual form: `src, dst, mbarrier : type(src) -> type(dst), type(mbarrier)`.
let assemblyFormat = [{
$src `,` $dst `,` $mbarrier attr-dict `:` type($src) `->`
qualified(type($dst)) `,` qualified(type($mbarrier))
}];
// Requests a hand-written verify() method for this op.
let hasVerifier = 1;

// Adds a static helper that answers whether a compute capability can use this
// op; it forwards to TargetFeatures::supportClusterOps().
let extraClassDeclaration = [{
static bool isSupported(int computeCapability) {
return ::mlir::triton::nvidia_gpu::TargetFeatures(computeCapability)
.supportClusterOps();
}
}];
}


def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [
AttrSizedOperandSegments, DeclareOpInterfaceMethods<MBarrierOpInterface>,
DeclareOpInterfaceMethods<PredicatedOpInterface>, TMALoadLikeOpInterface]> {
Expand Down
40 changes: 40 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "mlir/Support/LLVM.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
Expand All @@ -39,6 +40,7 @@
#include "triton/Tools/StrUtil.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir::triton::gpu;
Expand Down Expand Up @@ -257,6 +259,44 @@ Type ArriveBarrierOp::getPredicateOperandTypeLike() {
return IntegerType::get(getContext(), 1);
}

// -- AsyncSharedStoreOp --
/// Verifies the structural requirements of `ttng.async_shared_store`.
///
/// The checks run in this order and stop at the first failure:
///  - `lookupNumCTAs` reports at least two CTAs for this op,
///  - the destination memory descriptor is mutable,
///  - `verifyMemoryOpTypes` accepts the source/destination type pair,
///  - `verifyBarrierType` accepts the mbarrier type,
///  - the register-to-shared layout conversion vectorizes to at least 32 bits
///    per store.
///
/// \returns success() when every check passes, failure otherwise.
LogicalResult AsyncSharedStoreOp::verify() {
  // PTX defines weak shared::cluster st.async as UB for a one-CTA cluster.
  if (gpu::lookupNumCTAs(getOperation()) < 2)
    return emitOpError("requires at least two CTAs in the cluster");
  if (!getDst().getType().getMutableMemory())
    return emitOpError("cannot store into immutable memory");
  if (failed(triton::gpu::verifyMemoryOpTypes(*this, getSrc().getType(),
                                              getDst().getType())))
    return failure();
  if (failed(verifyBarrierType(*this, getMbarrier().getType())))
    return failure();

  auto srcTy = getSrc().getType();
  auto dstTy = getDst().getType();
  unsigned bitwidth = getIntOrFloatOrPtrBitWidth(srcTy.getElementType());

  // A padded destination selects a different shared layout and also bounds
  // the vector width, so query the encoding kind once and reuse the answer.
  const bool dstIsPadded = isPaddedEncoding(dstTy.getEncoding());

  auto regLayout = toLinearLayout(srcTy);
  auto sharedLayout =
      dstIsPadded ? paddedLinearLayout(dstTy) : toLinearLayout(dstTy);
  auto cvt = regLayout.invertAndCompose(sharedLayout);

  // Left empty for non-padded encodings; otherwise the encoding's minimum
  // interval is passed on as the maximum number of elements per vector.
  std::optional<int> maybeMaxVecElems;
  if (dstIsPadded)
    maybeMaxVecElems = getMinInterval(dstTy.getEncoding());

  auto vectorization =
      largestVectorisation(getContext(), cvt, bitwidth, maybeMaxVecElems);
  unsigned elemsPerVec = vectorization.first;
  // Reject layouts whose widest contiguous store is narrower than 32 bits.
  if (elemsPerVec * bitwidth < 32)
    return emitOpError("requires a layout vectorizing stores to at least 32 "
                       "bits");
  return success();
}

/// MBarrierOpInterface hook: the barrier of this op is its `mbarrier` operand.
TypedValue<MemDescType> AsyncSharedStoreOp::getBarrier() {
  TypedValue<MemDescType> barrier = getMbarrier();
  return barrier;
}

// -- FenceMBarrierInitReleaseClusterOp --
LogicalResult FenceMBarrierInitReleaseClusterOp::verify() {
int numCTAs = triton::gpu::lookupNumCTAs(getOperation());
Expand Down
11 changes: 11 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,17 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks {
info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write,
loadOp.getResult());
}
if (auto storeOp = dyn_cast<ttng::AsyncSharedStoreOp>(op)) {
info.emplace();
info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier;
info->barriers.push_back(
{storeOp.getBarrier(), nullptr, /*count=*/0,
MemEffectsOpInfo::BarrierTrackingMode::EffectWrites,
/*txCount=*/
-static_cast<int>(tti::getMemDescLength(storeOp.getDst()))});
info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write,
storeOp.getDst());
}
if (auto tryCancelOp = dyn_cast<ttng::CLCTryCancelOp>(op)) {
info.emplace();
info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier;
Expand Down
5 changes: 5 additions & 0 deletions python/src/gluon_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,11 @@ void init_gluon_ir(py::module &&m) {
[](GluonOpBuilder &self, Value memDesc, Value value) {
self.create<ttg::LocalStoreOp>(value, memDesc);
})
.def(
"create_async_shared_store",
[](GluonOpBuilder &self, Value memDesc, Value value, Value mbarrier) {
self.create<ttng::AsyncSharedStoreOp>(value, memDesc, mbarrier);
})
.def("create_local_load",
[](GluonOpBuilder &self, Type resultTy, Value memDesc) -> Value {
return self.create<ttg::LocalLoadOp>(resultTy, memDesc);
Expand Down
41 changes: 41 additions & 0 deletions python/test/gluon/test_consan.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,47 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr):
kernel[(1, )](input_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas)


@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
@pytest.mark.parametrize("EXPECT_DELTA", [0, 4], ids=["match", "mismatch"])
def test_async_shared_store_expect_bytes(EXPECT_DELTA, device, run_wrapper, monkeypatch, num_ctas):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a very similar test for TMA. Can you see if it's possible to merge them?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see a way to cleanly merge those

if num_ctas == 1:
pytest.skip("st.async.shared requires at least 2 CTAs")
if run_wrapper:
result = run_in_process(test_async_shared_store_expect_bytes,
(EXPECT_DELTA, device, False, monkeypatch, num_ctas))
if EXPECT_DELTA:
assert_expected_cuda_failure(result.exc)
assert "Deadlock detected" in result.driver_stderr_output
else:
assert result.exc is None
assert result.driver_stderr_output == ""
return

monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
knobs.refresh_knobs()

@gluon.jit
def kernel(out, EXPECT_DELTA: ttgl.constexpr):
cga_layout: ttgl.constexpr = multicast_cga_layout(ttgl.num_ctas(), 1)
layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0], cga_layout=cga_layout)
smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0], cga_layout=cga_layout)
offsets = ttgl.arange(0, XBLOCK, layout=layout)
values = offsets.to(ttgl.int32)
smem = ttgl.allocate_shared_memory(ttgl.int32, [XBLOCK], smem_layout)
bar = mbarrier.allocate_mbarrier()
mbarrier.init(bar, count=1)
mbarrier.expect(bar, smem.nbytes_per_cta + EXPECT_DELTA)
hopper.async_store(smem, values, bar)
mbarrier.wait(bar, 0, deps=[smem])
result = smem.load(layout)
mbarrier.invalidate(bar)
ttgl.store(out + offsets, result)

output = torch.empty((XBLOCK.value, ), device=device, dtype=torch.int32)
kernel[(1, )](output, EXPECT_DELTA=EXPECT_DELTA, num_warps=4, num_ctas=num_ctas)


@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
@pytest.mark.parametrize("FAILURE", [True, False])
def test_async_tma_multicast_kernel(FAILURE, device, run_wrapper, monkeypatch, num_ctas):
Expand Down
58 changes: 58 additions & 0 deletions python/test/gluon/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,64 @@ def test_local_store_transposed_cga_to_non_transposed_alloc():
torch.testing.assert_close(out, inp.T, atol=0, rtol=0)


@gluon.jit
def async_shared_store_kernel(out, BLOCK: ttgl.constexpr):
layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0], cga_layout=[[0]])
shared_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0], cga_layout=[[0]])

offsets = ttgl.arange(0, BLOCK, layout=layout)
values = offsets.to(ttgl.int32)
smem = ttgl.allocate_shared_memory(ttgl.int32, [BLOCK], shared_layout)
bar = mbarrier.allocate_mbarrier()
mbarrier.init(bar, count=1)
mbarrier.expect(bar, smem.nbytes_per_cta)
hopper.async_store(smem, values, bar)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know if there are any lifetime issues with the registers, similar to wgmma, or does the instruction completely finish reading the registers synchronously (via the usual SASS register dependency tracking)?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There aren't any lifetime issues for the registers in this case; it is fully handled by the scoreboard.

mbarrier.wait(bar, phase=0, deps=[smem])
result = smem.load(layout)
mbarrier.invalidate(bar)
ttgl.store(out + offsets, result)


@gluon.jit
def async_shared_store_f16_kernel(out, BLOCK: ttgl.constexpr):
    # Round-trips arange(BLOCK) as float16 through shared memory: async store,
    # wait on the mbarrier, load back, then write to `out`.
    # The register layout holds two elements per thread (sizePerThread = [2]).
    reg_layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0], cga_layout=[[0]])
    smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0], cga_layout=[[0]])

    idx = ttgl.arange(0, BLOCK, layout=reg_layout)
    data = idx.to(ttgl.float16)
    buf = ttgl.allocate_shared_memory(ttgl.float16, [BLOCK], smem_layout)

    barrier = mbarrier.allocate_mbarrier()
    mbarrier.init(barrier, count=1)
    # Announce the expected byte count before issuing the asynchronous store.
    mbarrier.expect(barrier, buf.nbytes_per_cta)
    hopper.async_store(buf, data, barrier)
    mbarrier.wait(barrier, phase=0, deps=[buf])

    loaded = buf.load(reg_layout)
    mbarrier.invalidate(barrier)
    ttgl.store(out + idx, loaded)


@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper")
def test_async_shared_store():
    """Run the int32 async-store kernel on two CTAs; check the PTX and the data."""
    n_elems = 128
    result = torch.empty((n_elems, ), device="cuda", dtype=torch.int32)

    handle = async_shared_store_kernel[(1, )](result, n_elems, num_warps=4, num_ctas=2)

    # The generated PTX must contain the st.async instruction.
    ptx = handle.asm["ptx"]
    assert "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes" in ptx

    # The kernel stores arange(n_elems) and reads it back unchanged.
    expected = torch.arange(n_elems, device="cuda", dtype=torch.int32)
    torch.testing.assert_close(result, expected)


@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper")
def test_async_shared_store_packed_f16():
    """Run the float16 async-store kernel; expect a ``.b32`` store and intact data."""
    n_elems = 256
    result = torch.empty((n_elems, ), device="cuda", dtype=torch.float16)

    handle = async_shared_store_f16_kernel[(1, )](result, n_elems, num_warps=4, num_ctas=2)

    # Each thread holds two float16 values, so the store must be emitted as .b32.
    ptx = handle.asm["ptx"]
    assert "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.b32" in ptx

    expected = torch.arange(n_elems, device="cuda", dtype=torch.float16)
    torch.testing.assert_close(result, expected)


@gluon.jit
def tma_kernel(desc):
layout: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0])
Expand Down
16 changes: 16 additions & 0 deletions python/test/gluon/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,22 @@ def test_mbarrier(target):
""")


@gluon.jit
def async_shared_store_kernel():
    # Minimal frontend kernel: build a 128-element int32 tensor, allocate a
    # matching shared buffer and an mbarrier, and issue one async store.
    reg_layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0], cga_layout=[[0]])
    smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0], cga_layout=[[0]])

    indices = ttgl.arange(0, 128, layout=reg_layout)
    src = indices.to(ttgl.int32)
    smem = ttgl.allocate_shared_memory(ttgl.int32, [128], smem_layout)
    barrier = mbarrier.allocate_mbarrier()
    hopper.async_store(smem, src, barrier)


@pytest.mark.parametrize("target", [HOPPER_TARGET, BLACKWELL_TARGET])
def test_async_shared_store(target):
    """The parsed kernel must contain a ``ttng.async_shared_store`` op on both targets."""
    module = run_parser(async_shared_store_kernel, *make_args(num_ctas=2), target=target)
    ir_text = anonymize_ir(module.str_nodebug())
    assert "ttng.async_shared_store" in ir_text


@gluon.jit
def tcgen05_mma_kernel(nvmma_layout: ttgl.constexpr, acc_layout: ttgl.constexpr):
a = ttgl.allocate_shared_memory(ttgl.float16, [128, 128], nvmma_layout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from . import tma
from . import clc
from ..hopper import fence_async_shared, mbarrier
from ..hopper import async_store, fence_async_shared, mbarrier
from ..ampere import async_copy, mma_v2

from triton._C.libtriton import ir
Expand All @@ -21,6 +21,7 @@
__all__ = [
"allocate_tensor_memory",
"async_copy",
"async_store",
"clc",
"fence_async_shared",
"mbarrier",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

__all__ = [
"async_copy",
"async_store",
"cluster",
"fence_async_shared",
"mbarrier",
Expand All @@ -20,6 +21,11 @@
]


def _check(cond, msg_fn, category=ValueError):
    """Raise ``category`` when ``cond`` is falsy.

    Args:
        cond: Condition that must hold; any truthy value passes.
        msg_fn: Zero-argument callable producing the error message. It is only
            called when the check fails, so building the message costs nothing
            on the success path.
        category: Exception type to raise. Defaults to ``ValueError``.

    Raises:
        category: If ``cond`` is falsy, with ``msg_fn()`` as its argument.
    """
    if cond:
        return
    raise category(msg_fn())


@_core.builtin
def fence_async_shared(cluster=False, _semantic=None):
"""
Expand All @@ -32,6 +38,25 @@ def fence_async_shared(cluster=False, _semantic=None):
_semantic.builder.create_fence_async_shared(cluster)


@_core.builtin
def async_store(dst, value, mbarrier, _semantic=None):
    """
    Store a tensor to shared memory asynchronously and signal an mbarrier on completion.
    Requires a CTA cluster with at least two CTAs.

    Args:
        dst (shared_memory_descriptor): Destination shared memory descriptor.
        value (tensor): Tensor whose contents to store.
        mbarrier (shared_memory_descriptor): Barrier signaled when the store completes.

    Raises:
        ValueError: If an argument has the wrong kind, or if the shape or dtype
            of `value` does not match `dst`.
    """
    # Validate `dst` before reading its attributes so that a wrong argument gets
    # the same descriptive error as `value` and `mbarrier` instead of an
    # AttributeError on `dst.shape`.
    _check(isinstance(dst, _core.shared_memory_descriptor),
           lambda: f"expected 'dst' to be a shared_memory_descriptor, but got a {type(dst)}")
    _check(isinstance(value, _core.tensor), lambda: f"expected 'value' to be a tensor, but got a {type(value)}")
    _check(isinstance(mbarrier, _core.shared_memory_descriptor),
           lambda: f"expected 'mbarrier' to be a shared_memory_descriptor, but got a {type(mbarrier)}")
    _check(value.shape == dst.shape, lambda: f"source shape {value.shape} and destination shape {dst.shape} must match")
    _check(value.dtype == dst.dtype, lambda: f"source dtype {value.dtype} and destination dtype {dst.dtype} must match")
    # The builder expects the operands as (destination, value, barrier).
    _semantic.builder.create_async_shared_store(dst.handle, value.handle, mbarrier.handle)


class warpgroup_mma_accumulator_type(_core.base_type):
tensor_type: _core.dtype

Expand Down
53 changes: 53 additions & 0 deletions test/Conversion/tritonnvidiagpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,59 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

// -----

// Basic lowering case: an i32 tensor with one element per thread is stored
// with ttng.async_shared_store in a two-CTA module. The expected output has
// two nvvm.mapa ops followed by a .b32 st.async instruction.
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CGALayout = [[0]]}>
#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}>
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: async_shared_store
// CHECK: nvvm.mapa
// CHECK: nvvm.mapa
// CHECK: st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.b32
tt.func @async_shared_store(%src: tensor<128xi32, #blocked>, %dst: !ttg.memdesc<128xi32, #shared1, #smem, mutable>, %mbarrier: !ttg.memdesc<2xi64, #shared0, #smem, mutable>) {
ttng.async_shared_store %src, %dst, %mbarrier : tensor<128xi32, #blocked> -> !ttg.memdesc<128xi32, #shared1, #smem, mutable>, !ttg.memdesc<2xi64, #shared0, #smem, mutable>
tt.return
}
}

// -----

// Variant with "ttng.two-ctas" = true on the module and a 1xi64 barrier whose
// layout uses CGALayout = [[0]]. Besides the two nvvm.mapa ops, the expected
// output has a ptrtoint / and / inttoptr sequence before the st.async.
// NOTE(review): presumably that sequence masks the barrier address; confirm
// against the lowering.
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CGALayout = [[0]]}>
#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}>
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttng.two-ctas" = true} {
// CHECK-LABEL: async_shared_store_two_cta_barrier
// CHECK: nvvm.mapa
// CHECK: nvvm.mapa
// CHECK: llvm.ptrtoint
// CHECK: llvm.and
// CHECK: llvm.inttoptr
// CHECK: st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.b32
tt.func @async_shared_store_two_cta_barrier(%src: tensor<128xi32, #blocked>, %dst: !ttg.memdesc<128xi32, #shared1, #smem, mutable>, %mbarrier: !ttg.memdesc<1xi64, #shared0, #smem, mutable>) {
ttng.async_shared_store %src, %dst, %mbarrier : tensor<128xi32, #blocked> -> !ttg.memdesc<128xi32, #shared1, #smem, mutable>, !ttg.memdesc<1xi64, #shared0, #smem, mutable>
tt.return
}
}

// -----

// Packed f16 case: with two f16 elements per thread (sizePerThread = [2]),
// the expected output bitcasts a vector<2xf16> to i32 and stores it with a
// .b32 st.async instruction.
#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CGALayout = [[0]]}>
#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}>
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: async_shared_store_f16
// CHECK: llvm.bitcast {{.*}} : vector<2xf16> to i32
// CHECK: st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.b32
tt.func @async_shared_store_f16(%src: tensor<256xf16, #blocked>, %dst: !ttg.memdesc<256xf16, #shared1, #smem, mutable>, %mbarrier: !ttg.memdesc<2xi64, #shared0, #smem, mutable>) {
ttng.async_shared_store %src, %dst, %mbarrier : tensor<256xf16, #blocked> -> !ttg.memdesc<256xf16, #shared1, #smem, mutable>, !ttg.memdesc<2xi64, #shared0, #smem, mutable>
tt.return
}
}

// -----

// TMA copy with barrier mask zero: barrier has no CGALayout -> shared::cta
#shared0_cta = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8}>
Expand Down
Loading
Loading