Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions python/test/gluon/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1874,6 +1874,44 @@ def test_atomic_rmw():
ttgl.atomic_add(offset + ptr, val, mask=scalar_mask, sem="acquire", scope="cta")


# FileCheck-driven frontend test for Gluon atomic read-modify-write ops.
# Calls ttgl.atomic_add with four different spellings of `mask` (a tensor,
# the literal True, a constexpr, and a plain local bool), then calls the
# remaining atomic_* ops once each with a literal scalar value and mask.
# The directive comments inside the body are part of the test: each one
# states the op that the call on the following line is expected to produce,
# so their text and order must stay in sync with the calls.
@filecheck_test
@gluon.jit
def test_atomic_rmw_scalar_masks():
    # CHECK-LABEL: test_atomic_rmw_scalar_masks
    BLOCK: ttgl.constexpr = 128
    # Build a block of pointers without any kernel argument: a BLOCK-sized
    # int64 tensor of zeros is bitcast to int32 pointers, then offset by arange.
    x = ttgl.full([BLOCK], 0, ttgl.int64)
    ptr = x.cast(ttgl.pointer_type(ttgl.int32), bitcast=True)
    offs = ttgl.arange(0, BLOCK)
    ptrs = ptr + offs
    val = ttgl.full([BLOCK], 1, ttgl.int32)
    # The mask variants under test: tensor-valued, plain local bool, constexpr.
    mask = offs >= 0
    scalar_mask = True
    constexpr_value: ttgl.constexpr = 1
    constexpr_mask: ttgl.constexpr = True

    # atomic_add with each mask form. None of the calls pass sem/scope, so the
    # expected `acq_rel` / `gpu` are presumably the defaults -- confirm against
    # the atomic_add signature.
    # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu
    ttgl.atomic_add(ptrs, val, mask=mask)
    # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu
    ttgl.atomic_add(ptrs, 1, mask=True)
    # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu
    ttgl.atomic_add(ptrs, constexpr_value, mask=constexpr_mask)
    # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu
    ttgl.atomic_add(ptrs, val, mask=scalar_mask)

    # Remaining atomic ops, each with a literal scalar value and literal mask.
    # CHECK: {{.*}} = tt.atomic_rmw exch, acq_rel, gpu
    ttgl.atomic_xchg(ptrs, 1, mask=True)
    # CHECK: {{.*}} = tt.atomic_rmw max, acq_rel, gpu
    ttgl.atomic_max(ptrs, 1, mask=True)
    # CHECK: {{.*}} = tt.atomic_rmw min, acq_rel, gpu
    ttgl.atomic_min(ptrs, 1, mask=True)
    # CHECK: {{.*}} = tt.atomic_rmw and, acq_rel, gpu
    ttgl.atomic_and(ptrs, 1, mask=True)
    # CHECK: {{.*}} = tt.atomic_rmw or, acq_rel, gpu
    ttgl.atomic_or(ptrs, 1, mask=True)
    # CHECK: {{.*}} = tt.atomic_rmw xor, acq_rel, gpu
    ttgl.atomic_xor(ptrs, 1, mask=True)


@filecheck_test
@gluon.jit
def test_atomic_cas():
Expand Down
37 changes: 37 additions & 0 deletions python/test/unit/language/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,43 @@ def test():
run_parser(test)


# FileCheck-driven frontend test for Triton (`tl`) atomic ops.
# Calls tl.atomic_add with four different spellings of `mask` (a tensor, the
# literal True, a constexpr, and a plain local bool), then calls the remaining
# atomic_* ops once each with a literal scalar value and mask.
# The directive comments inside the body are part of the test: each one
# states the op that the call on the following line is expected to produce,
# so their text and order must stay in sync with the calls.
@filecheck_test
@triton.jit
def test_atomic_scalar_masks():
    # CHECK-LABEL: test_atomic_scalar_masks
    BLOCK: tl.constexpr = 128
    # Build a block of pointers without any kernel argument: a BLOCK-sized
    # int64 tensor of zeros is bitcast to int32 pointers, then offset by arange.
    ptr = tl.full((BLOCK, ), 0, tl.int64).to(tl.pointer_type(tl.int32), bitcast=True)
    offs = tl.arange(0, BLOCK)
    ptrs = ptr + offs
    val = tl.full((BLOCK, ), 1, tl.int32)
    # The mask variants under test: tensor-valued, plain local bool, constexpr.
    mask = offs >= 0
    scalar_mask = True
    constexpr_value: tl.constexpr = 1
    constexpr_mask: tl.constexpr = True

    # atomic_add with each mask form. None of the calls pass sem/scope, so the
    # expected `acq_rel` / `gpu` are presumably the defaults -- confirm against
    # the atomic_add signature.
    # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu
    tl.atomic_add(ptrs, val, mask=mask)
    # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu
    tl.atomic_add(ptrs, 1, mask=True)
    # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu
    tl.atomic_add(ptrs, constexpr_value, mask=constexpr_mask)
    # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu
    tl.atomic_add(ptrs, val, mask=scalar_mask)

    # Remaining atomic ops, each with a literal scalar value and literal mask.
    # CHECK: {{.*}} = tt.atomic_rmw exch, acq_rel, gpu
    tl.atomic_xchg(ptrs, 1, mask=True)
    # CHECK: {{.*}} = tt.atomic_rmw max, acq_rel, gpu
    tl.atomic_max(ptrs, 1, mask=True)
    # CHECK: {{.*}} = tt.atomic_rmw min, acq_rel, gpu
    tl.atomic_min(ptrs, 1, mask=True)
    # CHECK: {{.*}} = tt.atomic_rmw and, acq_rel, gpu
    tl.atomic_and(ptrs, 1, mask=True)
    # CHECK: {{.*}} = tt.atomic_rmw or, acq_rel, gpu
    tl.atomic_or(ptrs, 1, mask=True)
    # CHECK: {{.*}} = tt.atomic_rmw xor, acq_rel, gpu
    tl.atomic_xor(ptrs, 1, mask=True)


@pytest.mark.interpreter
def test_return_promotion():

Expand Down
6 changes: 3 additions & 3 deletions python/test/unit/tools/test_slice_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest

from triton.tools.triton_to_gluon_translator.ordered_set import ordered_set
from triton.tools.triton_to_gluon_translator.slice_kernel import get_reference, slice_kernel
from triton.tools.triton_to_gluon_translator.slice_kernel import RewriteSpec, get_reference, slice_kernel
from triton.tools.triton_to_gluon_translator.stable_toposort import stable_toposort


Expand Down Expand Up @@ -154,15 +154,15 @@ def matcher(context, cur_module, decorator):
top = slice_kernel(
[f"{mod('kernel_mod')}:kernel_top"],
["triton", "torch"],
ignored_decorator_matchers=[matcher],
rewrite_spec=RewriteSpec(ignored_decorator_matchers=[matcher]),
)
assert "@keep()" in top
assert "@mock_kernel" not in top

bottom = slice_kernel(
[f"{mod('kernel_mod')}:kernel_bottom"],
["triton", "torch"],
ignored_decorator_matchers=[matcher],
rewrite_spec=RewriteSpec(ignored_decorator_matchers=[matcher]),
)
assert "@keep()" not in bottom
assert "@mock_kernel" not in bottom
Expand Down
3 changes: 3 additions & 0 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,6 +1178,7 @@ def _broadcast_ptr_val_mask(self, ptr, val, mask):
if mask is None:
ptr, val = self.broadcast_tensors(ptr, val)
else:
mask = self.to_tensor(mask)
ptr, val, mask = self.broadcast_tensors(ptr, val, mask)
if ptr_shape != ptr.shape:
raise ValueError(f"Expected pointer argument to have shape {ptr.shape} but got {ptr_shape}")
Expand Down Expand Up @@ -1266,6 +1267,8 @@ def atom_red_typechecking_impl(self, ptr: TensorTy, val: TensorTy, mask: TensorT
mask_ty = ptr.type.with_element_ty(tl.int1)
mask_ir = self.builder.create_splat(mask_ty.to_ir(self.builder), mask_ir)
mask = self.tensor(mask_ir, mask_ty)
elif not mask.type.scalar.is_bool():
raise ValueError("Mask must have boolean scalar type")
return ptr, val, mask

def _signbit(self, x: TensorTy) -> TensorTy:
Expand Down
56 changes: 38 additions & 18 deletions python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,22 +420,8 @@ def tl_load_tensor_descriptor(desc, offsets):
return out


# ---- NVIDIA obj dispatch ----


# Store `value` through `obj` at `offsets` by forwarding unconditionally to
# tl_store_tensor_descriptor; the helper's result is not returned.
# NOTE(review): a second tl_obj_store that dispatches on isinstance appears
# further down this diff view; this looks like the pre-change side -- confirm.
@gluon.jit
def tl_obj_store(obj, offsets, value):
    tl_store_tensor_descriptor(obj, offsets, value)


# Load through `obj` at `offsets` by forwarding unconditionally to
# tl_load_tensor_descriptor and returning its result.
# NOTE(review): a second tl_obj_load that dispatches on isinstance appears
# further down this diff view; this looks like the pre-change side -- confirm.
@gluon.jit
def tl_obj_load(obj, offsets):
    return tl_load_tensor_descriptor(obj, offsets)


@gluon.jit
def tl_obj_gather(obj, x_offsets, y_offset):
desc = obj
def tl_gather_tensor_descriptor(desc, x_offsets, y_offset):
desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]]
alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout)
bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
Expand All @@ -445,7 +431,7 @@ def tl_obj_gather(obj, x_offsets, y_offset):
ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]),
)
x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout)
mbarrier.expect(bar, x_offsets.shape[0] * obj.block_type.nbytes)
mbarrier.expect(bar, x_offsets.shape[0] * desc.block_type.nbytes)
tma_blackwell.async_gather(desc, x_offsets, y_offset, bar, alloc)
mbarrier.wait(bar, phase=0)
mbarrier.invalidate(bar)
Expand All @@ -455,8 +441,7 @@ def tl_obj_gather(obj, x_offsets, y_offset):


@gluon.jit
def tl_obj_scatter(obj, value, x_offsets, y_offset):
desc = obj
def tl_scatter_tensor_descriptor(desc, value, x_offsets, y_offset):
desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]]
alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout, value)
fence_async_shared()
Expand All @@ -469,6 +454,41 @@ def tl_obj_scatter(obj, value, x_offsets, y_offset):
tma.store_wait(0)


# ---- NVIDIA obj dispatch ----


# Store `value` at `offsets` through `obj`, choosing the implementation by
# the type of `obj`:
#   - a ttgl.nvidia.hopper.tma.tensor_descriptor goes through
#     tl_store_tensor_descriptor;
#   - anything else is expected to provide its own .store(offsets, value).
# Returns whatever the selected implementation returns.
@gluon.jit
def tl_obj_store(obj, offsets, value):
    # Presumably resolved while the JIT traces the function, so only one
    # branch is compiled per `obj` type -- confirm against the Gluon frontend.
    if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor):
        return tl_store_tensor_descriptor(obj, offsets, value)
    else:
        return obj.store(offsets, value)


# Load from `obj` at `offsets`, choosing the implementation by the type of
# `obj`:
#   - a ttgl.nvidia.hopper.tma.tensor_descriptor goes through
#     tl_load_tensor_descriptor;
#   - anything else is expected to provide its own .load(offsets).
# Returns the value produced by the selected implementation.
@gluon.jit
def tl_obj_load(obj, offsets):
    # Presumably resolved while the JIT traces the function, so only one
    # branch is compiled per `obj` type -- confirm against the Gluon frontend.
    if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor):
        return tl_load_tensor_descriptor(obj, offsets)
    else:
        return obj.load(offsets)


# Gather from `obj` using `x_offsets` and `y_offset`, choosing the
# implementation by the type of `obj`:
#   - a ttgl.nvidia.hopper.tma.tensor_descriptor goes through
#     tl_gather_tensor_descriptor;
#   - anything else is expected to provide its own .gather(x_offsets, y_offset).
# Returns the value produced by the selected implementation.
@gluon.jit
def tl_obj_gather(obj, x_offsets, y_offset):
    # Presumably resolved while the JIT traces the function, so only one
    # branch is compiled per `obj` type -- confirm against the Gluon frontend.
    if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor):
        return tl_gather_tensor_descriptor(obj, x_offsets, y_offset)
    else:
        return obj.gather(x_offsets, y_offset)


# Scatter `value` into `obj` using `x_offsets` and `y_offset`, choosing the
# implementation by the type of `obj`:
#   - a ttgl.nvidia.hopper.tma.tensor_descriptor goes through
#     tl_scatter_tensor_descriptor;
#   - anything else is expected to provide its own
#     .scatter(value, x_offsets, y_offset).
# Returns whatever the selected implementation returns.
@gluon.jit
def tl_obj_scatter(obj, value, x_offsets, y_offset):
    # Presumably resolved while the JIT traces the function, so only one
    # branch is compiled per `obj` type -- confirm against the Gluon frontend.
    if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor):
        return tl_scatter_tensor_descriptor(obj, value, x_offsets, y_offset)
    else:
        return obj.scatter(value, x_offsets, y_offset)


# ---- NVIDIA host-side descriptor ----


Expand Down
Loading