Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
0e39532
[Tools][Translator] Add AMD backend support for Triton-to-Gluon trans…
jammm Mar 19, 2026
e4bb60f
Remove unnecessary null checks in getTensorDescMetadata
jammm Mar 24, 2026
e0d2031
pre-commit run
jammm Mar 24, 2026
cee7ac8
Fix NVIDIA translated dot test failure
jammm Mar 24, 2026
f11be11
Fixes with review
jammm Mar 24, 2026
99b3652
Remove redundant descriptor test and combine with existing one
jammm Mar 24, 2026
cd9dbd9
[Tools][Translator] Use @_aggregate for AMD tensor descriptors instea…
jammm Mar 25, 2026
4e02793
fix for python 3.10
jammm Apr 7, 2026
63e0cc1
[Tools][Translator] Remove _create_tdm_descriptor builtin
jammm Apr 7, 2026
a8d2fd1
review fixes
jammm Apr 8, 2026
b2f9c16
Remove else in tl_dot for NVIDIA path
jammm Apr 8, 2026
17e6c9e
Revert "Remove else in tl_dot for NVIDIA path"
jammm Apr 8, 2026
5197b9b
Use standard current_target() instead of custom _current_target override
jammm Apr 8, 2026
f2708dd
Use isinstance dispatch instead of translator routing for AMD descrip…
jammm Apr 8, 2026
cc3823e
Rename _load/_store/_gather/_scatter_tdm to tl_obj_*_amd naming conve…
jammm Apr 8, 2026
ee0db99
pre-commit run
jammm Apr 8, 2026
00964d9
Add static_assert for gfx1250 target in tl_make_tensor_descriptor_amd
jammm Apr 8, 2026
cf02165
Use current target instead of parametrizing all targets in translator…
jammm Apr 10, 2026
82e7eb2
Merge branch 'main' of https://github.com/triton-lang/triton into jam…
jammm Apr 11, 2026
28095a7
Only allow blackwell/cdna4/gfx1250 for test_simple_matmul
jammm Apr 11, 2026
0aceafa
Merge branch 'main' of https://github.com/triton-lang/triton into jam…
jammm Apr 14, 2026
8a42e7f
Introduce TranslatorTarget StrEnum for translator hardware abstraction
jammm Apr 14, 2026
3cb9ee6
Fix TranslatorTarget for Python 3.10 compat
jammm Apr 14, 2026
d4509de
Merge branch 'main' of https://github.com/triton-lang/triton into jam…
jammm Apr 15, 2026
948972d
Split translator_helpers into target-specific helper modules
jammm Apr 15, 2026
2d05693
pre-commit run
jammm Apr 15, 2026
6c2eb1f
Split translator_helpers into target-specific helper modules
jammm Apr 15, 2026
96f2074
Remove isinstance dispatch from NVIDIA obj helpers
jammm Apr 15, 2026
b1e0599
Merge branch 'main' of https://github.com/triton-lang/triton into jam…
jammm Apr 15, 2026
1e47994
Remove translator_helpers.py backward-compat shim
jammm Apr 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/test/unit/tools/test_slice_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def test_slice_kernel_public_imports():
from triton.tools.triton_to_gluon_translator.slice_kernel import slice_kernel as new_slice_kernel
from triton.tools.triton_to_gluon_translator.translator import translate_paths
from triton.tools.triton_to_gluon_translator.translator import convert_triton_to_gluon
from triton.tools.triton_to_gluon_translator.translator_helpers import convert_host_descriptor
from triton.tools.triton_to_gluon_translator.nvidia_helpers import convert_host_descriptor

assert callable(new_slice_kernel)
assert callable(translate_paths)
Expand Down
108 changes: 90 additions & 18 deletions python/test/unit/tools/test_triton_to_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,26 @@
from triton.tools.tensor_descriptor import TensorDescriptor

from triton.tools.triton_to_gluon_translator.translator import convert_triton_to_gluon
from triton.tools.triton_to_gluon_translator.translator_helpers import convert_host_descriptor
from triton._internal_testing import is_blackwell, is_hopper_or_newer, is_cuda
from triton.tools.triton_to_gluon_translator.target import TranslatorTarget
from triton._internal_testing import is_blackwell, is_hopper_or_newer, is_cuda, is_hip_cdna4, is_hip_gfx1250, is_hip_cdna3_or_newer
from triton.language.target_info import current_target


def convert_kernel(kernel, kernel_name, tmp_path):
converted = convert_triton_to_gluon([kernel])
def _convert_host_descriptor(desc):
    """Convert ``desc`` using the helper module that matches the active target.

    The AMD helper module is used when ``current_target()`` reports the
    ``"hip"`` backend. In every other case, including when no target is
    reported, the NVIDIA helper module is used.
    """
    active = current_target()
    on_hip = active is not None and active.backend == "hip"
    # The helper is imported lazily so only the module for the active backend is loaded.
    if not on_hip:
        from triton.tools.triton_to_gluon_translator.nvidia_helpers import convert_host_descriptor as to_gluon
    else:
        from triton.tools.triton_to_gluon_translator.amd_helpers import convert_host_descriptor as to_gluon
    return to_gluon(desc)


def convert_kernel(kernel, kernel_name, tmp_path, target=None):
if target is None:
t = current_target()
target = TranslatorTarget("nvidia" if t.backend == "cuda" else t.arch)
converted = convert_triton_to_gluon([kernel], target=target)

# Write converted kernel to a file so @gluon.jit can retrieve source
mod_path = tmp_path / "converted_kernel.py"
Expand All @@ -36,7 +50,6 @@ def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
tl.store(out_ptr + offsets, x + y)


@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
def test_simple_kernel(tmp_path):
kernel = convert_kernel(add_kernel, "add_kernel", tmp_path)

Expand Down Expand Up @@ -70,9 +83,9 @@ def matmul_tile_kernel(a_ptr, b_ptr, c_ptr, BLOCK_M: tl.constexpr, BLOCK_N: tl.c
impl_matmul_tile_kernel(a_ptr, b_ptr, c_ptr, BLOCK_M, BLOCK_N, BLOCK_K)


@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_triton_to_gluon_dot_minimal(tmp_path):
# Convert directly from the Triton kernel object
if not (is_blackwell() or is_hip_cdna3_or_newer() or is_hip_gfx1250()):
pytest.skip("Requires Blackwell, CDNA3+, or gfx1250")
kernel = convert_kernel(matmul_tile_kernel, "matmul_tile_kernel", tmp_path)
M, N, K = 128, 128, 128
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
Expand Down Expand Up @@ -131,8 +144,9 @@ def matmul_kernel( #
@pytest.mark.parametrize("dtype_dst_str", ["float32"])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES", [(128, 128, 64, 1)])
@pytest.mark.parametrize("NUM_WARPS", [4])
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, NUM_WARPS, tmp_path):
if not (is_blackwell() or is_hip_cdna4() or is_hip_gfx1250()):
pytest.skip("Requires Blackwell, CDNA4, or gfx1250")
device = "cuda"
M, N, K = 1024, 512, 256
torch.manual_seed(42)
Expand Down Expand Up @@ -161,16 +175,23 @@ def descriptor_store_kernel(desc, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
desc.store([0, 0], tile)


@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer")
def _skip_unless_descriptor_target():
    """Skip the calling test unless the current device passes the descriptor gate.

    On CUDA the test runs only when ``is_hopper_or_newer()`` holds. On any
    non-CUDA device it runs only when ``is_hip_gfx1250()`` holds.
    """
    if is_cuda():
        if not is_hopper_or_newer():
            pytest.skip("Requires Hopper+")
        return
    if not is_hip_gfx1250():
        pytest.skip("Requires descriptor support")


def test_triton_to_gluon_descriptor_roundtrip(tmp_path):
_skip_unless_descriptor_target()
kernel = convert_kernel(descriptor_store_kernel, "descriptor_store_kernel", tmp_path)

M = N = 64
y = torch.zeros((M, N), device="cuda", dtype=torch.float16)
grid = (1, )
block_shape = [M, N]
desc = TensorDescriptor(y, y.shape, y.stride(), block_shape)
gluon_desc = convert_host_descriptor(desc)
gluon_desc = _convert_host_descriptor(desc)
kernel[grid](gluon_desc, M, N, 1.0)

y_ref = torch.zeros((M, N), device="cuda", dtype=torch.float16)
Expand All @@ -185,8 +206,8 @@ def descriptor_copy_kernel(in_desc, out_desc, BLOCK_M: tl.constexpr, BLOCK_N: tl
out_desc.store([0, 0], tile)


@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer")
def test_triton_to_gluon_descriptor_load_roundtrip(tmp_path):
_skip_unless_descriptor_target()
kernel = convert_kernel(descriptor_copy_kernel, "descriptor_copy_kernel", tmp_path)

M = N = 64
Expand All @@ -196,8 +217,8 @@ def test_triton_to_gluon_descriptor_load_roundtrip(tmp_path):
block_shape = [M, N]

in_desc = TensorDescriptor(x, x.shape, x.stride(), block_shape)
gluon_desc = convert_host_descriptor(in_desc)
out_desc = convert_host_descriptor(TensorDescriptor(y, y.shape, y.stride(), block_shape))
gluon_desc = _convert_host_descriptor(in_desc)
out_desc = _convert_host_descriptor(TensorDescriptor(y, y.shape, y.stride(), block_shape))
kernel[grid](gluon_desc, out_desc, M, N)

y_ref = torch.zeros((M, N), device="cuda", dtype=torch.float16)
Expand All @@ -224,8 +245,8 @@ def make_tensor_descriptor_copy_kernel(x_ptr, y_ptr, M, N, BLOCK_M: tl.constexpr
out_desc.store([0, 0], tile)


@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer")
def test_triton_to_gluon_make_tensor_descriptor(tmp_path, with_allocator):
_skip_unless_descriptor_target()
kernel = convert_kernel(make_tensor_descriptor_copy_kernel, "make_tensor_descriptor_copy_kernel", tmp_path)

M = N = 64
Expand Down Expand Up @@ -258,7 +279,6 @@ def reshape_trans_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr,


@pytest.mark.parametrize("TRANS_KIND", ["trans_method", "tl_trans_separate", "tl_trans_tuple", "tl_trans"])
@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
def test_triton_reshape_trans(tmp_path, TRANS_KIND):
kernel = convert_kernel(reshape_trans_kernel, "reshape_trans_kernel", tmp_path)

Expand Down Expand Up @@ -289,7 +309,6 @@ def split_kernel(x_ptr, out_ptr):
tl.store(p, a)


@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
def test_split(tmp_path):
kernel = convert_kernel(split_kernel, "split_kernel", tmp_path)

Expand Down Expand Up @@ -339,7 +358,6 @@ def reduce_to_scalar_kernel(out_ptr):
tl.store(out_ptr, x)


@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
def test_reduce_to_scalar(tmp_path):
kernel = convert_kernel(reduce_to_scalar_kernel, "reduce_to_scalar_kernel", tmp_path)
grid = (1, )
Expand Down Expand Up @@ -403,7 +421,6 @@ def atomic_add_kernel(out_ptr, BLOCK: tl.constexpr):
tl.atomic_add(out_ptr + idx, idx, mask=scalar_mask, sem="release", scope="cta")


@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
def test_atomic_add(tmp_path):
kernel = convert_kernel(atomic_add_kernel, "atomic_add_kernel", tmp_path)

Expand All @@ -414,3 +431,58 @@ def test_atomic_add(tmp_path):
out = torch.zeros((block, ), device="cuda")
kernel[(1, )](out, BLOCK=block)
torch.testing.assert_close(out, ref)


# ---- additional op coverage ----


# Loads BLOCK elements from each of x_ptr and y_ptr, joins them with tl.cat
# (can_reorder=True) and stores the 2 * BLOCK results to out_ptr.
@triton.jit
def cat_kernel(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr):
    lane = tl.arange(0, BLOCK)
    lhs = tl.load(x_ptr + lane)
    rhs = tl.load(y_ptr + lane)
    joined = tl.cat(lhs, rhs, can_reorder=True)
    # The output is twice as wide as either input, so it needs its own index range.
    out_lane = tl.arange(0, 2 * BLOCK)
    tl.store(out_ptr + out_lane, joined)


def test_cat(tmp_path):
    """Check that the translated ``cat_kernel`` produces the same values as the Triton original.

    Both kernels run on the same random inputs with a single-program grid, and
    their outputs are compared for exact equality after sorting.
    """
    kernel = convert_kernel(cat_kernel, "cat_kernel", tmp_path)

    BLOCK = 256
    x = torch.randn(BLOCK, device="cuda", dtype=torch.float32)
    y = torch.randn(BLOCK, device="cuda", dtype=torch.float32)
    out = torch.empty(2 * BLOCK, device="cuda", dtype=torch.float32)
    kernel[(1, )](x, y, out, BLOCK)

    ref = torch.empty_like(out)
    cat_kernel[(1, )](x, y, ref, BLOCK)

    # cat_kernel calls tl.cat with can_reorder=True, so the comparison is made
    # order-insensitive by sorting both outputs first. torch.sort keeps the data
    # as tensors instead of Python-sorting a list of 0-d tensors.
    out_sorted = torch.sort(out).values
    ref_sorted = torch.sort(ref).values
    torch.testing.assert_close(out_sorted, ref_sorted, atol=0, rtol=0)


# Loads BLOCK_X indices from idx_ptr, gathers through a descriptor built over
# in_ptr, and scatters the gathered data through a descriptor built over
# out_ptr using the same indices.
@triton.jit
def gather_scatter_roundtrip_kernel(out_ptr, in_ptr, idx_ptr, X: tl.constexpr, Y: tl.constexpr, BLOCK_X: tl.constexpr,
                                    BLOCK_Y: tl.constexpr):
    rows = tl.load(idx_ptr + tl.arange(0, BLOCK_X))
    # Both descriptors use shape [X, Y], strides [Y, 1] and block shape [1, BLOCK_Y].
    src_desc = tl.make_tensor_descriptor(in_ptr, [X, Y], [Y, 1], [1, BLOCK_Y])
    dst_desc = tl.make_tensor_descriptor(out_ptr, [X, Y], [Y, 1], [1, BLOCK_Y])
    # NOTE(review): the trailing 0 is presumably the column offset - confirm against the descriptor API.
    gathered = src_desc.gather(rows, 0)
    dst_desc.scatter(gathered, rows, 0)


# TODO: parametrize over _descriptor_targets once NVIDIA gather/scatter translation is supported.
# The translator currently routes gather/scatter to AMD-specific helpers for AMD targets only.
@pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires gfx1250")
def test_gather_scatter_roundtrip(tmp_path):
    """Check that the translated gather/scatter kernel copies exactly the indexed rows.

    The kernel is translated for ``TranslatorTarget.GFX1250`` and launched on a
    single-program grid. Rows listed in ``idx`` must match the input; all other
    rows of the output must stay zero.
    """
    kernel = convert_kernel(gather_scatter_roundtrip_kernel, "gather_scatter_roundtrip_kernel", tmp_path,
                            target=TranslatorTarget.GFX1250)

    X, Y, BLOCK_X, BLOCK_Y = 64, 64, 8, 64
    inp = torch.arange(X * Y, device="cuda", dtype=torch.float16).reshape(X, Y)
    idx = torch.tensor([0, 2, 4, 6, 1, 3, 5, 7], device="cuda", dtype=torch.int32)
    out = torch.zeros((X, Y), device="cuda", dtype=torch.float16)
    kernel[(1, )](out, inp, idx, X, Y, BLOCK_X, BLOCK_Y)

    # The row indices are unique, so one indexed assignment builds the same
    # reference as copying row by row.
    rows = idx.long()
    expected = torch.zeros_like(out)
    expected[rows] = inp[rows]
    torch.testing.assert_close(out, expected, atol=0, rtol=0)
4 changes: 4 additions & 0 deletions python/triton/_internal_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ def is_hip_gfx1250():
return target is not None and target.backend == 'hip' and 'gfx1250' in target.arch


def is_hip_cdna3_or_newer():
    """Return a truthy value when ``is_hip_cdna3()`` or ``is_hip_cdna4()`` holds."""
    cdna3 = is_hip_cdna3()
    # Only query CDNA4 when the CDNA3 check came back falsy.
    return cdna3 if cdna3 else is_hip_cdna4()


def is_hip_cdna():
    """Return a truthy value when any of the CDNA2, CDNA3 or CDNA4 checks holds."""
    verdict = is_hip_cdna2()
    # Later checks are evaluated only while the earlier ones are falsy.
    for probe in (is_hip_cdna3, is_hip_cdna4):
        if verdict:
            break
        verdict = probe()
    return verdict

Expand Down
Loading
Loading