From 0e395327a4d7a1b6efd4463673c59f1827946b34 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Thu, 19 Mar 2026 02:58:27 -0700 Subject: [PATCH 01/26] [Tools][Translator] Add AMD backend support for Triton-to-Gluon translator Add support for translating Triton kernels to Gluon for AMD targets (gfx1250, gfx942, gfx950). This includes: - Architecture detection helpers (_is_gfx1250, _is_cdna, etc.) - AMD WMMA dot path for gfx1250 and MFMA dot path for CDNA3/CDNA4 - TDM tensor descriptor support (load/store/gather/scatter) for gfx1250 - Correct warp size handling (32 for gfx1250, 64 for CDNA) - Cross-compilation via _current_target / _make_target - tl_atomic_add and convert_to_expand_dims_layout as builtins - Parametrized tests across all targets (nvidia, gfx1250, gfx942, gfx950) - Fix segfault in getTensorDescMetadata for unencoded tensor descriptors Made-with: Cursor --- python/src/ir.cc | 5 +- .../test/unit/tools/test_triton_to_gluon.py | 156 ++++++-- .../inline_helpers.py | 41 +- .../slice_kernel.py | 17 +- .../triton_to_gluon_translator/translator.py | 34 +- .../translator_helpers.py | 359 ++++++++++++++++-- 6 files changed, 530 insertions(+), 82 deletions(-) diff --git a/python/src/ir.cc b/python/src/ir.cc index db769960eb87..604d80b028ca 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -187,7 +187,7 @@ py::list getTensorDescMetadata(ModuleOp &mod) { auto encoding = blockType.getEncoding(); py::dict metadata; - if (isa(encoding)) { + if (isa_and_nonnull(encoding)) { auto mmaEncoding = dyn_cast(encoding); auto swizzle = ttng::getTMASwizzleMode(arg.getLoc(), descTy); auto elemType = ttng::getTMAElementType(arg.getLoc(), descTy); @@ -209,7 +209,8 @@ py::list getTensorDescMetadata(ModuleOp &mod) { std::vector(blockShape.begin(), blockShape.end()); metadata["elem_bits"] = blockType.getElementTypeBitWidth(); - if (auto paddedEnc = dyn_cast(encoding)) { + if (auto paddedEnc = + dyn_cast_or_null(encoding)) { py::list intervalPaddingPairs; for (auto [interval, 
padding] : llvm::zip_equal( paddedEnc.getIntervals(), paddedEnc.getPaddings())) { diff --git a/python/test/unit/tools/test_triton_to_gluon.py b/python/test/unit/tools/test_triton_to_gluon.py index 697da0dc5ebd..47adad8c6f1c 100644 --- a/python/test/unit/tools/test_triton_to_gluon.py +++ b/python/test/unit/tools/test_triton_to_gluon.py @@ -8,11 +8,23 @@ from triton.tools.triton_to_gluon_translator.translator import convert_triton_to_gluon from triton.tools.triton_to_gluon_translator.translator_helpers import convert_host_descriptor -from triton._internal_testing import is_blackwell, is_hopper_or_newer, is_cuda +from triton._internal_testing import is_blackwell, is_hopper_or_newer, is_cuda, is_hip_gfx1250, is_hip_cdna3, is_hip_cdna4 +_all_targets = { + "nvidia": is_cuda, + "gfx1250": is_hip_gfx1250, + "gfx942": is_hip_cdna3, + "gfx950": is_hip_cdna4, +} -def convert_kernel(kernel, kernel_name, tmp_path): - converted = convert_triton_to_gluon([kernel]) + +def _skip_unless_target(target): + if not _all_targets[target](): + pytest.skip(f"Requires {target}") + + +def convert_kernel(kernel, kernel_name, tmp_path, target="nvidia"): + converted = convert_triton_to_gluon([kernel], target=target) # Write converted kernel to a file so @gluon.jit can retrieve source mod_path = tmp_path / "converted_kernel.py" @@ -36,9 +48,10 @@ def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr): tl.store(out_ptr + offsets, x + y) -@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA") -def test_simple_kernel(tmp_path): - kernel = convert_kernel(add_kernel, "add_kernel", tmp_path) +@pytest.mark.parametrize("target", _all_targets.keys()) +def test_simple_kernel(tmp_path, target): + _skip_unless_target(target) + kernel = convert_kernel(add_kernel, "add_kernel", tmp_path, target=target) n = 1024 BLOCK = 128 @@ -70,10 +83,10 @@ def matmul_tile_kernel(a_ptr, b_ptr, c_ptr, BLOCK_M: tl.constexpr, BLOCK_N: tl.c impl_matmul_tile_kernel(a_ptr, b_ptr, c_ptr, BLOCK_M, BLOCK_N, 
BLOCK_K) -@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell") -def test_triton_to_gluon_dot_minimal(tmp_path): - # Convert directly from the Triton kernel object - kernel = convert_kernel(matmul_tile_kernel, "matmul_tile_kernel", tmp_path) +@pytest.mark.parametrize("target", _all_targets.keys()) +def test_triton_to_gluon_dot_minimal(tmp_path, target): + _skip_unless_target(target) + kernel = convert_kernel(matmul_tile_kernel, "matmul_tile_kernel", tmp_path, target=target) M, N, K = 128, 128, 128 a = torch.randn((M, K), device="cuda", dtype=torch.float16) b = torch.randn((K, N), device="cuda", dtype=torch.float16) @@ -161,9 +174,17 @@ def descriptor_store_kernel(desc, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, desc.store([0, 0], tile) -@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer") -def test_triton_to_gluon_descriptor_roundtrip(tmp_path): - kernel = convert_kernel(descriptor_store_kernel, "descriptor_store_kernel", tmp_path) +_descriptor_targets = { + "nvidia": is_hopper_or_newer, + "gfx1250": is_hip_gfx1250, +} + + +@pytest.mark.parametrize("target", _descriptor_targets.keys()) +def test_triton_to_gluon_descriptor_roundtrip(tmp_path, target): + if not _descriptor_targets[target](): + pytest.skip(f"Requires {target} with descriptor support") + kernel = convert_kernel(descriptor_store_kernel, "descriptor_store_kernel", tmp_path, target=target) M = N = 64 y = torch.zeros((M, N), device="cuda", dtype=torch.float16) @@ -185,9 +206,11 @@ def descriptor_copy_kernel(in_desc, out_desc, BLOCK_M: tl.constexpr, BLOCK_N: tl out_desc.store([0, 0], tile) -@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer") -def test_triton_to_gluon_descriptor_load_roundtrip(tmp_path): - kernel = convert_kernel(descriptor_copy_kernel, "descriptor_copy_kernel", tmp_path) +@pytest.mark.parametrize("target", _descriptor_targets.keys()) +def test_triton_to_gluon_descriptor_load_roundtrip(tmp_path, target): + if not 
_descriptor_targets[target](): + pytest.skip(f"Requires {target} with descriptor support") + kernel = convert_kernel(descriptor_copy_kernel, "descriptor_copy_kernel", tmp_path, target=target) M = N = 64 x = torch.ones((M, N), device="cuda", dtype=torch.float16) * 3.0 @@ -258,9 +281,10 @@ def reshape_trans_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr, @pytest.mark.parametrize("TRANS_KIND", ["trans_method", "tl_trans_separate", "tl_trans_tuple", "tl_trans"]) -@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA") -def test_triton_reshape_trans(tmp_path, TRANS_KIND): - kernel = convert_kernel(reshape_trans_kernel, "reshape_trans_kernel", tmp_path) +@pytest.mark.parametrize("target", _all_targets.keys()) +def test_triton_reshape_trans(tmp_path, TRANS_KIND, target): + _skip_unless_target(target) + kernel = convert_kernel(reshape_trans_kernel, "reshape_trans_kernel", tmp_path, target=target) n = 1024 BLOCK = 256 @@ -289,9 +313,10 @@ def split_kernel(x_ptr, out_ptr): tl.store(p, a) -@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA") -def test_split(tmp_path): - kernel = convert_kernel(split_kernel, "split_kernel", tmp_path) +@pytest.mark.parametrize("target", _all_targets.keys()) +def test_split(tmp_path, target): + _skip_unless_target(target) + kernel = convert_kernel(split_kernel, "split_kernel", tmp_path, target=target) n = 1024 x = torch.randn(2 * n, device="cuda", dtype=torch.float32) @@ -339,9 +364,10 @@ def reduce_to_scalar_kernel(out_ptr): tl.store(out_ptr, x) -@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA") -def test_reduce_to_scalar(tmp_path): - kernel = convert_kernel(reduce_to_scalar_kernel, "reduce_to_scalar_kernel", tmp_path) +@pytest.mark.parametrize("target", _all_targets.keys()) +def test_reduce_to_scalar(tmp_path, target): + _skip_unless_target(target) + kernel = convert_kernel(reduce_to_scalar_kernel, "reduce_to_scalar_kernel", tmp_path, target=target) grid = (1, ) out = torch.empty((1, ), device="cuda", 
dtype=torch.int32) @@ -414,3 +440,83 @@ def test_atomic_add(tmp_path): out = torch.zeros((block, ), device="cuda") kernel[(1, )](out, BLOCK=block) torch.testing.assert_close(out, ref) + + +# ---- additional op coverage ---- + + +@triton.jit +def cat_kernel(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr): + offs = tl.arange(0, BLOCK) + x = tl.load(x_ptr + offs) + y = tl.load(y_ptr + offs) + z = tl.cat(x, y, can_reorder=True) + tl.store(out_ptr + tl.arange(0, 2 * BLOCK), z) + + +@pytest.mark.parametrize("target", _all_targets.keys()) +def test_cat(tmp_path, target): + _skip_unless_target(target) + kernel = convert_kernel(cat_kernel, "cat_kernel", tmp_path, target=target) + + BLOCK = 256 + x = torch.randn(BLOCK, device="cuda", dtype=torch.float32) + y = torch.randn(BLOCK, device="cuda", dtype=torch.float32) + out = torch.empty(2 * BLOCK, device="cuda", dtype=torch.float32) + kernel[(1, )](x, y, out, BLOCK) + + ref = torch.empty_like(out) + cat_kernel[(1, )](x, y, ref, BLOCK) + torch.testing.assert_close(sorted(out.cpu()), sorted(ref.cpu()), atol=0, rtol=0) + + +@triton.jit +def make_desc_copy_kernel(in_ptr, out_ptr, M, N, stride_m, stride_n, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + in_desc = tl.make_tensor_descriptor(in_ptr, shape=[M, N], strides=[stride_m, stride_n], + block_shape=[BLOCK_M, BLOCK_N]) + out_desc = tl.make_tensor_descriptor(out_ptr, shape=[M, N], strides=[stride_m, stride_n], + block_shape=[BLOCK_M, BLOCK_N]) + tile = in_desc.load([0, 0]) + out_desc.store([0, 0], tile) + + +@pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires gfx1250") +def test_make_tensor_descriptor_gfx1250(tmp_path): + kernel = convert_kernel(make_desc_copy_kernel, "make_desc_copy_kernel", tmp_path, target="gfx1250") + + M, N = 64, 64 + x = torch.randn((M, N), device="cuda", dtype=torch.float16) + y = torch.zeros((M, N), device="cuda", dtype=torch.float16) + grid = (1, ) + kernel[grid](x, y, M, N, x.stride(0), x.stride(1), M, N) + + y_ref = torch.zeros_like(y) + 
make_desc_copy_kernel[grid](x, y_ref, M, N, x.stride(0), x.stride(1), M, N) + torch.testing.assert_close(y, y_ref, atol=0, rtol=0) + + +@triton.jit +def gather_scatter_roundtrip_kernel(out_ptr, in_ptr, idx_ptr, X: tl.constexpr, Y: tl.constexpr, BLOCK_X: tl.constexpr, + BLOCK_Y: tl.constexpr): + idx = tl.load(idx_ptr + tl.arange(0, BLOCK_X)) + in_desc = tl.make_tensor_descriptor(in_ptr, [X, Y], [Y, 1], [1, BLOCK_Y]) + out_desc = tl.make_tensor_descriptor(out_ptr, [X, Y], [Y, 1], [1, BLOCK_Y]) + data = in_desc.gather(idx, 0) + out_desc.scatter(data, idx, 0) + + +@pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires gfx1250") +def test_gather_scatter_roundtrip(tmp_path): + kernel = convert_kernel(gather_scatter_roundtrip_kernel, "gather_scatter_roundtrip_kernel", tmp_path, + target="gfx1250") + + X, Y, BLOCK_X, BLOCK_Y = 64, 64, 8, 64 + inp = torch.arange(X * Y, device="cuda", dtype=torch.float16).reshape(X, Y) + idx = torch.tensor([0, 2, 4, 6, 1, 3, 5, 7], device="cuda", dtype=torch.int32) + out = torch.zeros((X, Y), device="cuda", dtype=torch.float16) + kernel[(1, )](out, inp, idx, X, Y, BLOCK_X, BLOCK_Y) + + expected = torch.zeros_like(out) + for i, row in enumerate(idx.tolist()): + expected[row] = inp[row] + torch.testing.assert_close(out, expected, atol=0, rtol=0) diff --git a/python/triton/tools/triton_to_gluon_translator/inline_helpers.py b/python/triton/tools/triton_to_gluon_translator/inline_helpers.py index 93f91a8ecfef..235add27f869 100644 --- a/python/triton/tools/triton_to_gluon_translator/inline_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/inline_helpers.py @@ -1,25 +1,42 @@ +_torch_dtype_to_triton_def = R""" +def _torch_dtype_to_triton(dtype): + import torch + + if dtype == torch.float8_e5m2: + return gl.float8e5 + if dtype == torch.float8_e4m3fn: + return gl.float8e4nv + return getattr(gl, str(dtype).split(".")[1]) +""" + defs: dict[str, str] = { "convert_host_descriptor": - R""" + _torch_dtype_to_triton_def + R""" def 
convert_host_descriptor(desc): - def torch_dtype_to_triton(dtype): - import torch - - if dtype == torch.float8_e5m2: - return gl.float8e5 - if dtype == torch.float8_e4m3fn: - return gl.float8e4nv - return getattr(gl, str(dtype).split(".")[1]) - from triton.tools.tensor_descriptor import TensorDescriptor assert isinstance(desc, TensorDescriptor) block_shape = desc.block_shape dtype = desc.base.dtype tensor = desc.base - layout = gl.NVMMASharedLayout.get_default_for(block_shape, torch_dtype_to_triton(dtype)) + layout = gl.NVMMASharedLayout.get_default_for(block_shape, _torch_dtype_to_triton(dtype)) return gluon.nvidia.hopper.TensorDescriptor( tensor, desc.shape, desc.strides, block_shape, layout ) -""" +""", + "convert_host_descriptor_amd": + _torch_dtype_to_triton_def + R""" +def convert_host_descriptor(desc): + from triton.tools.tensor_descriptor import TensorDescriptor + + assert isinstance(desc, TensorDescriptor) + block_shape = desc.block_shape + dtype = desc.base.dtype + layout = gl.PaddedSharedLayout.with_identity_for( + [[block_shape[-1], 4]], list(block_shape), [1, 0] + ) + return gluon.amd.gfx1250.TensorDescriptor( + desc.base, list(desc.shape), list(desc.strides), block_shape, layout + ) +""", } diff --git a/python/triton/tools/triton_to_gluon_translator/slice_kernel.py b/python/triton/tools/triton_to_gluon_translator/slice_kernel.py index 81af3a89fe3d..337939681c9b 100644 --- a/python/triton/tools/triton_to_gluon_translator/slice_kernel.py +++ b/python/triton/tools/triton_to_gluon_translator/slice_kernel.py @@ -521,6 +521,7 @@ class SliceRewriter(ReferenceRewriter): translate_to_gluon: bool = False inline_helpers: ordered_set[str] = field(default_factory=ordered_set[str]) cvt_context: list[bool] = field(default_factory=lambda: [False]) + target: str = "nvidia" def __post_init__(self) -> None: # Special rules for sugaring imports. 
@@ -551,6 +552,9 @@ def emit_reference(self, node: ast.AST) -> Any: return node raise e + def _is_amd_target(self) -> bool: + return self.target.startswith("gfx") + def visit_Attribute(self, node: ast.Attribute) -> ast.AST: if not self.translate_to_gluon: return super().visit_Attribute(node) @@ -562,7 +566,10 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.AST: self.imports.add("import triton.experimental.gluon._runtime as gluon_runtime") new_node = parse_expr("gluon_runtime.GluonJITFunction") elif value is tl.tensor_descriptor: - self.imports.add("from triton.experimental.gluon.language.nvidia.hopper.tma import tensor_descriptor") + if self._is_amd_target(): + self.imports.add("from triton.experimental.gluon.language.amd.gfx1250.tdm import tensor_descriptor") + else: + self.imports.add("from triton.experimental.gluon.language.nvidia.hopper.tma import tensor_descriptor") new_node = ast.Name(id="tensor_descriptor", ctx=ast.Load()) return new_node @@ -709,6 +716,7 @@ def slice_kernel( leaf_paths: list[str] | None = None, translate_to_gluon: bool = False, ignored_decorator_matchers: Sequence[DecoratorMatcher] | None = None, + target: str = "nvidia", ) -> str: base_values: list[GlobalValue] = [get_base_value(root_path) for root_path in root_paths] base_value_ids: set[int] = set() @@ -736,7 +744,7 @@ def filter(value: ModuleType | GlobalValue) -> bool: ignored_decorator_matchers=ignored_decorator_matchers, ) jit_functions = [fn for fn in jit_functions if not fn.original_value.is_gluon()] - converted_functions = translate_kernels(jit_functions) + converted_functions = translate_kernels(jit_functions, target=target) module_file = tempfile.NamedTemporaryFile(delete=False, prefix="translated_", suffix=".py") module_path = Path(module_file.name) module_path.write_text(converted_functions) @@ -782,6 +790,7 @@ def filter(value: ModuleType | GlobalValue) -> bool: ignored_decorator_matchers=tuple(ignored_decorator_matchers or ()), translate_to_gluon=translate_to_gluon, 
inline_helpers=inline_helpers, + target=target, ) tree = rewriter.visit(tree) source = ast.unparse(tree) @@ -807,6 +816,7 @@ def slice_kernel_from_trace( translate_to_gluon: bool, extra_modules: dict[str, str], ignored_decorator_matchers: Sequence[DecoratorMatcher] | None = None, + target: str = "nvidia", ) -> str: module_remap: dict[str, str] = {} for name, path in extra_modules.items(): @@ -831,6 +841,7 @@ def slice_kernel_from_trace( leaf_paths=sorted(leaf_paths), translate_to_gluon=translate_to_gluon, ignored_decorator_matchers=ignored_decorator_matchers, + target=target, ) fn_name = lambda path: path.split(":")[1] @@ -853,6 +864,7 @@ def main( translate_to_gluon: bool = False, output_path: str = "/tmp/reference.py", ignored_decorator_matchers: Sequence[DecoratorMatcher] | None = None, + target: str = "nvidia", ) -> None: output = slice_kernel( root_paths, @@ -861,6 +873,7 @@ def main( leaf_paths, translate_to_gluon, ignored_decorator_matchers, + target=target, ) with open(output_path, "w") as f: f.write(output) diff --git a/python/triton/tools/triton_to_gluon_translator/translator.py b/python/triton/tools/triton_to_gluon_translator/translator.py index 4d7837931879..2ef193787615 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator.py +++ b/python/triton/tools/triton_to_gluon_translator/translator.py @@ -108,6 +108,7 @@ def add_expr_rewrites(rewrites: list[RewriteFn]) -> None: @dataclass class Translator(ReferenceRewriter): tensor_member_match_fns: list[str] = field(default_factory=list) + target: str = "nvidia" def __post_init__(self) -> None: import triton @@ -160,7 +161,13 @@ def visit_Call(self, node: ast.Call) -> ast.AST: "gather", "scatter", ]: - new_callee = parse_expr(f"helpers.tl_obj_{node.func.attr}") + attr = node.func.attr + # Use AMD-specific helpers for gather/scatter on AMD targets + if attr in ["gather", "scatter"] and self.target.startswith("gfx"): + helper_name = f"tl_obj_{attr}_amd" + else: + helper_name = f"tl_obj_{attr}" + 
new_callee = parse_expr(f"helpers.{helper_name}") node = ast.Call(func=new_callee, args=[node.func.value] + node.args, keywords=node.keywords) return self.generic_visit(node) value, _, _ = ref @@ -198,7 +205,7 @@ def visit_Subscript(self, node: ast.Subscript) -> ast.AST: return self.generic_visit(node) -def translate_kernels(kernels: list[GlobalValue]) -> str: +def translate_kernels(kernels: list[GlobalValue], target: str = "nvidia") -> str: def filter(value: ModuleType | GlobalValue) -> bool: if isinstance(value, ModuleType): @@ -240,22 +247,26 @@ def filter(value: ModuleType | GlobalValue) -> bool: imports, filter, value_remap={}, + target=target, ) tree = rewriter.visit(tree) source = ast.unparse(tree) assert reference.mangled_name is not None source = reference.value.mangle_source(source, reference.mangled_name) output += source + "\n\n\n" - output = "\n".join(imports) + "\n\n" + output - return output + header = "\n".join(imports) + "\n" + if target != "nvidia": + header += f'\nhelpers._current_target = helpers._make_target("{target}")\n' + header += "\n" + return header + output -def translate_paths(kernel_paths: list[str]) -> str: +def translate_paths(kernel_paths: list[str], target: str = "nvidia") -> str: kernels = [get_base_value(kernel_path) for kernel_path in kernel_paths] - return translate_kernels(kernels) + return translate_kernels(kernels, target=target) -def convert_triton_to_gluon(src: list[JITCallable]) -> str: +def convert_triton_to_gluon(src: list[JITCallable], target: str = "nvidia") -> str: kernels = [ GlobalValue.wrap( kernel, @@ -263,11 +274,11 @@ def convert_triton_to_gluon(src: list[JITCallable]) -> str: lambda: builtins, ) for kernel in src ] - return translate_kernels(kernels) + return translate_kernels(kernels, target=target) -def main(kernels: list[str], output_path: str) -> None: - output = translate_paths(kernels) +def main(kernels: list[str], output_path: str, target: str = "nvidia") -> None: + output = translate_paths(kernels, 
target=target) with open(output_path, "w") as f: f.write(output) @@ -276,8 +287,9 @@ def _main_cli() -> None: parser = argparse.ArgumentParser(description="Translate Triton kernels to Gluon source.") parser.add_argument("kernels", nargs="+", help="Kernel symbols in module.path:object format.") parser.add_argument("--output-path", required=True, help="Path to write the translated source.") + parser.add_argument("--target", default="nvidia", help="Target architecture (e.g. nvidia, gfx1250, gfx942, gfx950).") args = parser.parse_args() - main(args.kernels, args.output_path) + main(args.kernels, args.output_path, target=args.target) if __name__ == "__main__": diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py index bdf94145f654..da69dc5aab44 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py @@ -1,5 +1,6 @@ # type: ignore +import math from typing import Any from triton.experimental import gluon @@ -15,10 +16,9 @@ ) from triton.experimental.gluon.language.nvidia.blackwell import tma as tma_blackwell from triton.experimental.gluon.language.nvidia.hopper import fence_async_shared, mbarrier, tma - -# hack to workaround limited dependencies tracking. -# TODO: fix this by pulling imports into the generated file. 
-from triton.language.target_info import current_target # noqa: F401 +from triton.experimental.gluon.language.amd.gfx1250 import wmma as amd_wmma +from triton.experimental.gluon.language.amd.gfx1250 import tdm as amd_tdm +from triton.experimental.gluon.language.amd.cdna3 import mfma as amd_mfma @gluon.constexpr_function @@ -187,11 +187,17 @@ def tl_dot( max_num_imprecise_acc=None, out_dtype=ttgl.float32, ): - num_warps: ttgl.constexpr = ttgl.num_warps() - if tl_dot_mmav5_supported(a.type, b.type, num_warps, input_precision, allow_tf32, max_num_imprecise_acc): - return tl_dot_blackwell(a, b, acc, input_precision, allow_tf32, max_num_imprecise_acc, out_dtype) + target: ttgl.constexpr = current_target() + if _is_gfx1250(target): + return tl_dot_wmma(a, b, acc, out_dtype) + elif _is_cdna(target): + return tl_dot_mfma(a, b, acc, out_dtype) else: - return tl_dot_mma_sync(a, b, acc, input_precision, out_dtype) + num_warps: ttgl.constexpr = ttgl.num_warps() + if tl_dot_mmav5_supported(a.type, b.type, num_warps, input_precision, allow_tf32, max_num_imprecise_acc): + return tl_dot_blackwell(a, b, acc, input_precision, allow_tf32, max_num_imprecise_acc, out_dtype) + else: + return tl_dot_mma_sync(a, b, acc, input_precision, out_dtype) @gluon.constexpr_function @@ -339,8 +345,8 @@ def tl_dot_scaled( rhs_k_pack=True, out_dtype=ttgl.float32, ): - if (tl_dot_scaled_mmav5_supported(lhs.type, rhs.type, ttgl.num_warps()) and lhs_scale is not None - and rhs_scale is not None): + if (_is_nvidia(current_target()) and tl_dot_scaled_mmav5_supported(lhs.type, rhs.type, ttgl.num_warps()) + and lhs_scale is not None and rhs_scale is not None): return tl_dot_scaled_blackwell( lhs, lhs_scale, @@ -499,28 +505,29 @@ def tl_dot_scaled_blackwell( @gluon.constexpr_function -def get_num_threads_per_warp() -> ttgl.constexpr: +def get_num_threads_per_warp(target=None) -> ttgl.constexpr: + if target is None: + target = current_target() + if target is not None and target.backend == "hip": + gfx_major = 
int(target.arch[3:-2]) + return ttgl.constexpr(32 if gfx_major >= 10 else 64) return ttgl.constexpr(32) @gluon.jit def get_num_threads_per_program(): - return ttgl.num_warps() * get_num_threads_per_warp() + return ttgl.num_warps() * get_num_threads_per_warp(current_target()) @gluon.constexpr_function -def default_blocked_layout(shape: ttgl.constexpr, num_warps: ttgl.constexpr) -> ttgl.constexpr: +def default_blocked_layout(shape: ttgl.constexpr, num_warps: ttgl.constexpr, target=None) -> ttgl.constexpr: rank = len(shape) # 1 element per thread for all dimensions size_per_thread = [1] * rank - # Distribute 32 threads per warp across dimensions (simple heuristic: last-fastest) + # Distribute threads per warp across dimensions (simple heuristic: last-fastest) threads_per_warp = [1] * rank # TODO: pick a better layout based on shape. Using this allows to not have to convert layout when broadcasting but may blow up register pressure. - threads_per_warp[rank - 1] = get_num_threads_per_warp() - # remaining_threads = get_num_threads_per_warp() - # for dim in range(rank - 1, -1, -1): - # threads_per_warp[dim] = min(remaining_threads, shape[dim]) - # remaining_threads = remaining_threads // threads_per_warp[dim] + threads_per_warp[rank - 1] = get_num_threads_per_warp(target) # Use provided num_warps to distribute warps per CTA (put all on first dim) warps_per_cta = [1] * rank warps_per_cta[0] = num_warps @@ -534,10 +541,186 @@ def default_blocked_layout(shape: ttgl.constexpr, num_warps: ttgl.constexpr) -> ) +# ---- architecture detection ---- + + +@gluon.constexpr_function +def _is_nvidia(target=None): + return target is None or target.backend == "cuda" + + +@gluon.constexpr_function +def _is_gfx1250(target=None): + return target is not None and target.arch == "gfx1250" + + +@gluon.constexpr_function +def _is_cdna(target=None): + return target is not None and target.arch in ("gfx942", "gfx950") + + +@gluon.constexpr_function +def _cdna_version(target=None): + """Returns 3 
for gfx942, 4 for gfx950.""" + return 4 if target is not None and target.arch == "gfx950" else 3 + + +# ---- AMD WMMA layout helpers (gfx1250) ---- + + +@gluon.constexpr_function +def compute_warp_bases(num_warps): + """Distribute warps across M/N: first bit to N, rest to M.""" + n_bits = int(math.log2(num_warps)) + if n_bits == 0: + return [] + warp_bases = [[0, 1]] + for i in range(n_bits - 1): + warp_bases.append([1 << i, 0]) + return warp_bases + + +@gluon.constexpr_function +def get_wmma_layout(shape, num_warps): + warp_bases = compute_warp_bases(num_warps) + return ttgl.amd.AMDWMMALayout(3, True, warp_bases, [], [16, 16, 32]) + + +@gluon.constexpr_function +def get_wmma_k_width(a_ty, b_ty): + min_bitwidth = min(a_ty.element_ty.primitive_bitwidth, b_ty.element_ty.primitive_bitwidth) + return max(128 // min_bitwidth, 1) + + +# ---- AMD MFMA layout helpers (cdna3/cdna4) ---- + + +@gluon.constexpr_function +def get_mfma_instr_k(element_bitwidth, target=None): + """K dimension of the MFMA instruction for [32, 32, K].""" + k_bits = 128 if _cdna_version(target) == 3 else 256 + return k_bits // element_bitwidth + + +@gluon.constexpr_function +def get_mfma_layout(num_warps, element_bitwidth, target=None): + instr_k = get_mfma_instr_k(element_bitwidth, target) + return ttgl.amd.AMDMFMALayout( + version=_cdna_version(target), + instr_shape=[32, 32, instr_k], + transposed=True, + warps_per_cta=[num_warps, 1], + ) + + +@gluon.constexpr_function +def get_mfma_k_width(a_ty, b_ty, target=None): + min_bitwidth = min(a_ty.element_ty.primitive_bitwidth, b_ty.element_ty.primitive_bitwidth) + instr_k = get_mfma_instr_k(min_bitwidth, target) + return instr_k // 2 + + +# ---- AMD dot paths ---- + + +@gluon.jit +def tl_dot_wmma(a, b, acc, out_dtype): + """gfx1250 WMMA path.""" + M: ttgl.constexpr = a.type.shape[0] + N: ttgl.constexpr = b.type.shape[1] + num_warps: ttgl.constexpr = ttgl.num_warps() + + wmma_layout: ttgl.constexpr = get_wmma_layout([M, N], num_warps) + k_width: 
ttgl.constexpr = get_wmma_k_width(a.type, b.type) + a_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=wmma_layout, k_width=k_width) + b_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=1, parent=wmma_layout, k_width=k_width) + + a = ttgl.convert_layout(a, a_layout) + b = ttgl.convert_layout(b, b_layout) + + if acc is not None: + accumulator = ttgl.convert_layout(acc, wmma_layout) + else: + accumulator = ttgl.zeros([M, N], out_dtype, layout=wmma_layout) + + result = amd_wmma(a, b, accumulator) + + if acc is not None: + ret_layout: ttgl.constexpr = acc.type.layout + else: + ret_layout: ttgl.constexpr = default_blocked_layout(result.type.shape, num_warps) + return ttgl.convert_layout(result, ret_layout) + + +@gluon.jit +def tl_dot_mfma(a, b, acc, out_dtype): + """CDNA3/CDNA4 MFMA path.""" + M: ttgl.constexpr = a.type.shape[0] + N: ttgl.constexpr = b.type.shape[1] + num_warps: ttgl.constexpr = ttgl.num_warps() + min_bitwidth: ttgl.constexpr = min(a.type.element_ty.primitive_bitwidth, b.type.element_ty.primitive_bitwidth) + target: ttgl.constexpr = current_target() + + mfma_layout: ttgl.constexpr = get_mfma_layout(num_warps, min_bitwidth, target) + k_width: ttgl.constexpr = get_mfma_k_width(a.type, b.type, target) + a_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout, k_width=k_width) + b_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout, k_width=k_width) + + a = ttgl.convert_layout(a, a_layout) + b = ttgl.convert_layout(b, b_layout) + + if acc is not None: + accumulator = ttgl.convert_layout(acc, mfma_layout) + else: + accumulator = ttgl.zeros([M, N], out_dtype, layout=mfma_layout) + + result = amd_mfma(a, b, accumulator) + + if acc is not None: + ret_layout: ttgl.constexpr = acc.type.layout + else: + ret_layout: ttgl.constexpr = default_blocked_layout(result.type.shape, num_warps) + return ttgl.convert_layout(result, ret_layout) + + +# ---- AMD TDM tensor descriptors 
(gfx1250 only) ---- + + +@gluon.constexpr_function +def get_default_tdm_layout(block_shape, element_bitwidth): + return ttgl.PaddedSharedLayout.with_identity_for( + [[block_shape[-1], 4]], + list(block_shape), + [1, 0], + ) + + +@gluon.jit +def tl_load_tensor_descriptor_amd(desc, offsets): + smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout) + amd_tdm.async_load(desc, offsets, smem) + amd_tdm.async_wait(0) + ret_layout: ttgl.constexpr = default_blocked_layout(desc.block_shape, ttgl.num_warps()) + out = smem.load(ret_layout) + return out + + +@gluon.jit +def tl_store_tensor_descriptor_amd(desc, offsets, value): + smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout, value) + amd_tdm.async_store(desc, offsets, smem) + amd_tdm.async_wait(0) + + +# ---- obj dispatch (routes desc.load/store/gather/scatter to TMA or TDM) ---- + + @gluon.jit def tl_obj_store(obj, offsets, value): if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): return tl_store_tensor_descriptor(obj, offsets, value) + elif isinstance(obj, amd_tdm.tensor_descriptor): + return tl_store_tensor_descriptor_amd(obj, offsets, value) else: return obj.store(offsets, value) @@ -546,6 +729,8 @@ def tl_obj_store(obj, offsets, value): def tl_obj_load(obj, offsets): if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): return tl_load_tensor_descriptor(obj, offsets) + elif isinstance(obj, amd_tdm.tensor_descriptor): + return tl_load_tensor_descriptor_amd(obj, offsets) else: return obj.load(offsets) @@ -593,10 +778,66 @@ def tl_obj_scatter(obj, value, x_offsets, y_offset): obj.scatter(value, x_offsets, y_offset) -@gluon.jit -def tl_make_tensor_descriptor(base, shape, strides, block_shape, padding_option: ttgl.constexpr = "zero"): - layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for(block_shape, base.dtype.element_ty) - return tma.make_tensor_descriptor(base, shape, strides, block_shape, layout, padding_option) +@ttgl._core.builtin +def 
tl_obj_gather_amd(desc, x_offsets, y_offset, _semantic=None, _generator=None): + # TDM gather: recreate descriptor with block_shape=[num_idx, block_n], then async gather. + # Triton's API creates descriptors with block_shape=[1, block_n], but TDM hardware requires + # block_shape to match the shared memory allocation [num_idx, block_n] for gather/scatter. + num_idx = x_offsets.shape[0] + block_n = desc.block_shape[1] + gather_shape = [num_idx, block_n] + smem_layout = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) + gather_desc = amd_tdm.make_tensor_descriptor(desc._tdm_base, list(desc._tdm_shape), list(desc._tdm_strides), + gather_shape, smem_layout, _semantic=_semantic) + num_warps = ttgl.num_warps(_semantic=_semantic, _generator=_generator) + idx_base = ttgl.BlockedLayout([num_idx, 1], [1, get_num_threads_per_warp(current_target())], [1, num_warps], [1, 0]) + idx_layout = ttgl.SliceLayout(1, idx_base) + x_offsets = ttgl.convert_layout(x_offsets, idx_layout, _semantic=_semantic) + alloc = ttgl.allocate_shared_memory(desc.dtype, gather_shape, smem_layout, _semantic=_semantic) + y_off = ttgl.to_tensor(y_offset, _semantic=_semantic) + amd_tdm.async_gather(gather_desc, x_offsets, y_off, alloc, _semantic=_semantic) + amd_tdm.async_wait(0, _semantic=_semantic) + ret_layout = default_blocked_layout(gather_shape, num_warps, current_target()) + out = alloc.load(ret_layout, _semantic=_semantic) + return out + + +@ttgl._core.builtin +def tl_obj_scatter_amd(desc, value, x_offsets, y_offset, _semantic=None, _generator=None): + # TDM scatter: recreate descriptor with block_shape=[num_idx, block_n], then async scatter. + # Triton's API creates descriptors with block_shape=[1, block_n], but TDM hardware requires + # block_shape to match the shared memory allocation [num_idx, block_n] for gather/scatter. 
+ num_idx = x_offsets.shape[0] + block_n = desc.block_shape[1] + scatter_shape = [num_idx, block_n] + smem_layout = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) + scatter_desc = amd_tdm.make_tensor_descriptor(desc._tdm_base, list(desc._tdm_shape), list(desc._tdm_strides), + scatter_shape, smem_layout, _semantic=_semantic) + num_warps = ttgl.num_warps(_semantic=_semantic, _generator=_generator) + idx_base = ttgl.BlockedLayout([num_idx, 1], [1, get_num_threads_per_warp(current_target())], [1, num_warps], [1, 0]) + idx_layout = ttgl.SliceLayout(1, idx_base) + x_offsets = ttgl.convert_layout(x_offsets, idx_layout, _semantic=_semantic) + alloc = ttgl.allocate_shared_memory(desc.dtype, scatter_shape, smem_layout, value, _semantic=_semantic) + y_off = ttgl.to_tensor(y_offset, _semantic=_semantic) + amd_tdm.async_scatter(scatter_desc, x_offsets, y_off, alloc, _semantic=_semantic) + amd_tdm.async_wait(0, _semantic=_semantic) + + +@ttgl._core.builtin +def tl_make_tensor_descriptor(base, shape, strides, block_shape, padding_option="zero", _semantic=None): + if _is_gfx1250(current_target()): + element_bitwidth = base.dtype.element_ty.primitive_bitwidth + layout = get_default_tdm_layout(block_shape, element_bitwidth) + desc = amd_tdm.make_tensor_descriptor(base, shape, strides, block_shape, layout, _semantic=_semantic) + # Stash construction args so tl_obj_gather_amd/tl_obj_scatter_amd can recreate the + # descriptor with a different block_shape. TDM gather/scatter require block_shape to + # match [num_idx, block_n], but Triton creates descriptors with block_shape=[1, block_n]. 
+ desc._tdm_base = base + desc._tdm_shape = shape + desc._tdm_strides = strides + return desc + layout = ttgl.NVMMASharedLayout.get_default_for(block_shape, base.dtype.element_ty) + return tma.make_tensor_descriptor(base, shape, strides, block_shape, layout, padding_option, _semantic=_semantic) @gluon.jit @@ -683,12 +924,12 @@ def reset_to_default_layout(value): @gluon.constexpr_function -def get_split_src_layout(shape: ttgl.constexpr, num_warps: ttgl.constexpr) -> ttgl.constexpr: +def get_split_src_layout(shape: ttgl.constexpr, num_warps: ttgl.constexpr, target=None) -> ttgl.constexpr: rank = len(shape) size_per_thread = [1 if i != rank - 1 else 2 for i in range(rank)] - # Distribute 32 threads per warp across dimensions (simple heuristic: last-fastest) + # Distribute threads per warp across dimensions (simple heuristic: last-fastest) threads_per_warp = [1 for _ in range(rank)] - remaining_threads = get_num_threads_per_warp() + remaining_threads = get_num_threads_per_warp(target) for dim in range(rank - 2, -1, -1): threads_per_warp[dim] = min(shape[dim], remaining_threads) remaining_threads = remaining_threads // threads_per_warp[dim] @@ -727,6 +968,13 @@ def torch_dtype_to_triton(dtype): assert isinstance(desc, TensorDescriptor) block_shape = desc.block_shape dtype = desc.base.dtype + + target = current_target() + if target is not None and target.backend == "hip" and target.arch == "gfx1250": + element_bitwidth = torch_dtype_to_triton(dtype).primitive_bitwidth + layout = get_default_tdm_layout(block_shape, element_bitwidth) + return gluon.amd.gfx1250.TensorDescriptor(desc.base, list(desc.shape), list(desc.strides), block_shape, layout) + tensor = desc.base layout = ttgl.NVMMASharedLayout.get_default_for(block_shape, torch_dtype_to_triton(dtype)) return gluon.nvidia.hopper.TensorDescriptor(tensor, desc.shape, desc.strides, block_shape, layout) @@ -746,7 +994,58 @@ def build_expand_dims_layout(shape, expand_dims, num_warps): return layout -@gluon.jit -def 
convert_to_expand_dims_layout(value, expand_dims: list[int]) -> Any: - layout: ttgl.constexpr = build_expand_dims_layout(value.shape, expand_dims, ttgl.num_warps()) - return ttgl.convert_layout(value, layout) +@ttgl._core.builtin +def convert_to_expand_dims_layout(value, expand_dims: list[int], _semantic=None, _generator=None) -> Any: + parent_shape = _unwrap_if_constexpr(value.type.shape) + if isinstance(parent_shape, ttgl.tuple): + parent_shape = parent_shape.values + assert isinstance(parent_shape, + list), (f"expected parent shape to be a list, got {parent_shape} which is {type(parent_shape)}") + + num_warps = ttgl.num_warps(_semantic=_semantic, _generator=_generator) + layout = build_expand_dims_layout(parent_shape, expand_dims, num_warps) + return ttgl.convert_layout(value, layout, _semantic=_semantic) + + +@ttgl._core.builtin +def tl_atomic_add(ptr, val, mask=None, sem=None, scope=None, _semantic=None): + if ptr.type.is_block(): + if isinstance(val, ttgl.constexpr) or not val.type.is_block(): + val = ttgl.to_tensor(val, _semantic=_semantic) + val = ttgl.full(ptr.shape, val, val.dtype, ptr.type.layout, _semantic=_semantic) + if mask is not None and isinstance(mask, ttgl.constexpr) or not mask.type.is_block(): + mask = ttgl.to_tensor(mask, _semantic=_semantic) + mask = ttgl.full(ptr.shape, mask, mask.dtype, ptr.type.layout, _semantic=_semantic) + return ttgl.atomic_add(ptr, val=val, mask=mask, sem=sem, scope=scope, _semantic=_semantic) + + +# Module-level target, set by the translator via _make_target(). +# Falls back to the active driver's target if not set. +_current_target = None + + +def current_target(): + if _current_target is not None: + return _current_target + from triton.runtime import driver + + try: + active_driver = driver.active + except RuntimeError: + return None + return active_driver.get_current_target() + + +current_target.__triton_builtin__ = True + + +def _make_target(arch): + """Construct a GPUTarget from an architecture string (e.g. 
'gfx1250', 'nvidia').""" + if arch.startswith("gfx"): + from triton.backends.amd.compiler import GPUTarget + warp_size = 32 if int(arch[3:-2]) >= 10 else 64 + return GPUTarget("hip", arch, warp_size) + return None + + +_make_target.__triton_builtin__ = True From e4bb60f15966de6a75138903e1d5efa00fe59316 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Tue, 24 Mar 2026 00:23:28 -0700 Subject: [PATCH 02/26] Remove unnecessary null checks in getTensorDescMetadata --- python/src/ir.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/src/ir.cc b/python/src/ir.cc index 604d80b028ca..a0bc030814bc 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -187,7 +187,7 @@ py::list getTensorDescMetadata(ModuleOp &mod) { auto encoding = blockType.getEncoding(); py::dict metadata; - if (isa_and_nonnull(encoding)) { + if (isa(encoding)) { auto mmaEncoding = dyn_cast(encoding); auto swizzle = ttng::getTMASwizzleMode(arg.getLoc(), descTy); auto elemType = ttng::getTMAElementType(arg.getLoc(), descTy); @@ -210,7 +210,7 @@ py::list getTensorDescMetadata(ModuleOp &mod) { metadata["elem_bits"] = blockType.getElementTypeBitWidth(); if (auto paddedEnc = - dyn_cast_or_null(encoding)) { + dyn_cast(encoding)) { py::list intervalPaddingPairs; for (auto [interval, padding] : llvm::zip_equal( paddedEnc.getIntervals(), paddedEnc.getPaddings())) { From e0d2031c7acd98049205db6a5bb4a333f76177ca Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Tue, 24 Mar 2026 01:32:37 -0700 Subject: [PATCH 03/26] pre-commit run --- python/src/ir.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/src/ir.cc b/python/src/ir.cc index a0bc030814bc..db769960eb87 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -209,8 +209,7 @@ py::list getTensorDescMetadata(ModuleOp &mod) { std::vector(blockShape.begin(), blockShape.end()); metadata["elem_bits"] = blockType.getElementTypeBitWidth(); - if (auto paddedEnc = - dyn_cast(encoding)) { + if (auto 
paddedEnc = dyn_cast(encoding)) { py::list intervalPaddingPairs; for (auto [interval, padding] : llvm::zip_equal( paddedEnc.getIntervals(), paddedEnc.getPaddings())) { From cee7ac81707b105893f55eef235833088a4a480e Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Tue, 24 Mar 2026 03:01:21 -0700 Subject: [PATCH 04/26] Fix NVIDIA translated dot test failure --- python/test/unit/tools/test_triton_to_gluon.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/python/test/unit/tools/test_triton_to_gluon.py b/python/test/unit/tools/test_triton_to_gluon.py index 47adad8c6f1c..3c5bc61509b3 100644 --- a/python/test/unit/tools/test_triton_to_gluon.py +++ b/python/test/unit/tools/test_triton_to_gluon.py @@ -17,9 +17,18 @@ "gfx950": is_hip_cdna4, } +_dot_targets = { + "nvidia": is_blackwell, + "gfx1250": is_hip_gfx1250, + "gfx942": is_hip_cdna3, + "gfx950": is_hip_cdna4, +} -def _skip_unless_target(target): - if not _all_targets[target](): + +def _skip_unless_target(target, targets=_all_targets): + """Skip test if the required hardware for the given target is not available. + Specify targets for tests that require specific targets e.g. 
Blackwell on NVIDIA or gfx1250 on AMD.""" + if not targets[target](): pytest.skip(f"Requires {target}") @@ -83,9 +92,9 @@ def matmul_tile_kernel(a_ptr, b_ptr, c_ptr, BLOCK_M: tl.constexpr, BLOCK_N: tl.c impl_matmul_tile_kernel(a_ptr, b_ptr, c_ptr, BLOCK_M, BLOCK_N, BLOCK_K) -@pytest.mark.parametrize("target", _all_targets.keys()) +@pytest.mark.parametrize("target", _dot_targets.keys()) def test_triton_to_gluon_dot_minimal(tmp_path, target): - _skip_unless_target(target) + _skip_unless_target(target, _dot_targets) kernel = convert_kernel(matmul_tile_kernel, "matmul_tile_kernel", tmp_path, target=target) M, N, K = 128, 128, 128 a = torch.randn((M, K), device="cuda", dtype=torch.float16) From f11be112c03c71034b254b0b4c06dbd166696c59 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Tue, 24 Mar 2026 06:02:40 -0700 Subject: [PATCH 05/26] Fixes with review --- .../test/unit/tools/test_triton_to_gluon.py | 7 +++-- .../translator_helpers.py | 31 ++++--------------- 2 files changed, 10 insertions(+), 28 deletions(-) diff --git a/python/test/unit/tools/test_triton_to_gluon.py b/python/test/unit/tools/test_triton_to_gluon.py index 3c5bc61509b3..53c293e43881 100644 --- a/python/test/unit/tools/test_triton_to_gluon.py +++ b/python/test/unit/tools/test_triton_to_gluon.py @@ -438,9 +438,10 @@ def atomic_add_kernel(out_ptr, BLOCK: tl.constexpr): tl.atomic_add(out_ptr + idx, idx, mask=scalar_mask, sem="release", scope="cta") -@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA") -def test_atomic_add(tmp_path): - kernel = convert_kernel(atomic_add_kernel, "atomic_add_kernel", tmp_path) +@pytest.mark.parametrize("target", _all_targets.keys()) +def test_atomic_add(tmp_path, target): + _skip_unless_target(target) + kernel = convert_kernel(atomic_add_kernel, "atomic_add_kernel", tmp_path, target=target) block = 32 * 4 ref = torch.zeros((block, ), device="cuda") diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py 
b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py index da69dc5aab44..8cc582e77432 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py @@ -968,14 +968,14 @@ def torch_dtype_to_triton(dtype): assert isinstance(desc, TensorDescriptor) block_shape = desc.block_shape dtype = desc.base.dtype + tensor = desc.base target = current_target() if target is not None and target.backend == "hip" and target.arch == "gfx1250": element_bitwidth = torch_dtype_to_triton(dtype).primitive_bitwidth layout = get_default_tdm_layout(block_shape, element_bitwidth) - return gluon.amd.gfx1250.TensorDescriptor(desc.base, list(desc.shape), list(desc.strides), block_shape, layout) + return gluon.amd.gfx1250.TensorDescriptor(tensor, list(desc.shape), list(desc.strides), block_shape, layout) - tensor = desc.base layout = ttgl.NVMMASharedLayout.get_default_for(block_shape, torch_dtype_to_triton(dtype)) return gluon.nvidia.hopper.TensorDescriptor(tensor, desc.shape, desc.strides, block_shape, layout) @@ -994,29 +994,10 @@ def build_expand_dims_layout(shape, expand_dims, num_warps): return layout -@ttgl._core.builtin -def convert_to_expand_dims_layout(value, expand_dims: list[int], _semantic=None, _generator=None) -> Any: - parent_shape = _unwrap_if_constexpr(value.type.shape) - if isinstance(parent_shape, ttgl.tuple): - parent_shape = parent_shape.values - assert isinstance(parent_shape, - list), (f"expected parent shape to be a list, got {parent_shape} which is {type(parent_shape)}") - - num_warps = ttgl.num_warps(_semantic=_semantic, _generator=_generator) - layout = build_expand_dims_layout(parent_shape, expand_dims, num_warps) - return ttgl.convert_layout(value, layout, _semantic=_semantic) - - -@ttgl._core.builtin -def tl_atomic_add(ptr, val, mask=None, sem=None, scope=None, _semantic=None): - if ptr.type.is_block(): - if isinstance(val, ttgl.constexpr) or not 
val.type.is_block(): - val = ttgl.to_tensor(val, _semantic=_semantic) - val = ttgl.full(ptr.shape, val, val.dtype, ptr.type.layout, _semantic=_semantic) - if mask is not None and isinstance(mask, ttgl.constexpr) or not mask.type.is_block(): - mask = ttgl.to_tensor(mask, _semantic=_semantic) - mask = ttgl.full(ptr.shape, mask, mask.dtype, ptr.type.layout, _semantic=_semantic) - return ttgl.atomic_add(ptr, val=val, mask=mask, sem=sem, scope=scope, _semantic=_semantic) +@gluon.jit +def convert_to_expand_dims_layout(value, expand_dims: list[int]) -> Any: + layout: ttgl.constexpr = build_expand_dims_layout(value.shape, expand_dims, ttgl.num_warps()) + return ttgl.convert_layout(value, layout) # Module-level target, set by the translator via _make_target(). From 99b3652e38fe3ac739492a4d12ead85b7418e77f Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Tue, 24 Mar 2026 06:44:08 -0700 Subject: [PATCH 06/26] Remove redundant descriptor test and combine with existing one --- .../test/unit/tools/test_triton_to_gluon.py | 39 +++++-------------- .../translator_helpers.py | 3 ++ 2 files changed, 13 insertions(+), 29 deletions(-) diff --git a/python/test/unit/tools/test_triton_to_gluon.py b/python/test/unit/tools/test_triton_to_gluon.py index 53c293e43881..73d1fdd4148e 100644 --- a/python/test/unit/tools/test_triton_to_gluon.py +++ b/python/test/unit/tools/test_triton_to_gluon.py @@ -27,7 +27,7 @@ def _skip_unless_target(target, targets=_all_targets): """Skip test if the required hardware for the given target is not available. - Specify targets for tests that require specific targets e.g. 
Blackwell on NVIDIA or gfx1250 on AMD.""" + Use _dot_targets for dot tests (Blackwell on NVIDIA), _descriptor_targets for descriptor tests.""" if not targets[target](): pytest.skip(f"Requires {target}") @@ -256,9 +256,13 @@ def make_tensor_descriptor_copy_kernel(x_ptr, y_ptr, M, N, BLOCK_M: tl.constexpr out_desc.store([0, 0], tile) -@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer") -def test_triton_to_gluon_make_tensor_descriptor(tmp_path, with_allocator): - kernel = convert_kernel(make_tensor_descriptor_copy_kernel, "make_tensor_descriptor_copy_kernel", tmp_path) +# Parametrized over _descriptor_targets: tests tl.make_tensor_descriptor translation +# for both NVIDIA TMA (Hopper+) and AMD TDM (gfx1250). +@pytest.mark.parametrize("target", _descriptor_targets.keys()) +def test_triton_to_gluon_make_tensor_descriptor(tmp_path, target, with_allocator): + _skip_unless_target(target, _descriptor_targets) + kernel = convert_kernel(make_tensor_descriptor_copy_kernel, "make_tensor_descriptor_copy_kernel", tmp_path, + target=target) M = N = 64 x = torch.randn((M, N), device="cuda", dtype=torch.float16) @@ -480,31 +484,6 @@ def test_cat(tmp_path, target): torch.testing.assert_close(sorted(out.cpu()), sorted(ref.cpu()), atol=0, rtol=0) -@triton.jit -def make_desc_copy_kernel(in_ptr, out_ptr, M, N, stride_m, stride_n, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): - in_desc = tl.make_tensor_descriptor(in_ptr, shape=[M, N], strides=[stride_m, stride_n], - block_shape=[BLOCK_M, BLOCK_N]) - out_desc = tl.make_tensor_descriptor(out_ptr, shape=[M, N], strides=[stride_m, stride_n], - block_shape=[BLOCK_M, BLOCK_N]) - tile = in_desc.load([0, 0]) - out_desc.store([0, 0], tile) - - -@pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires gfx1250") -def test_make_tensor_descriptor_gfx1250(tmp_path): - kernel = convert_kernel(make_desc_copy_kernel, "make_desc_copy_kernel", tmp_path, target="gfx1250") - - M, N = 64, 64 - x = torch.randn((M, N), 
device="cuda", dtype=torch.float16) - y = torch.zeros((M, N), device="cuda", dtype=torch.float16) - grid = (1, ) - kernel[grid](x, y, M, N, x.stride(0), x.stride(1), M, N) - - y_ref = torch.zeros_like(y) - make_desc_copy_kernel[grid](x, y_ref, M, N, x.stride(0), x.stride(1), M, N) - torch.testing.assert_close(y, y_ref, atol=0, rtol=0) - - @triton.jit def gather_scatter_roundtrip_kernel(out_ptr, in_ptr, idx_ptr, X: tl.constexpr, Y: tl.constexpr, BLOCK_X: tl.constexpr, BLOCK_Y: tl.constexpr): @@ -515,6 +494,8 @@ def gather_scatter_roundtrip_kernel(out_ptr, in_ptr, idx_ptr, X: tl.constexpr, Y out_desc.scatter(data, idx, 0) +# TODO: parametrize over _descriptor_targets once NVIDIA gather/scatter translation is supported. +# The translator currently routes gather/scatter to AMD-specific helpers for AMD targets only. @pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires gfx1250") def test_gather_scatter_roundtrip(tmp_path): kernel = convert_kernel(gather_scatter_roundtrip_kernel, "gather_scatter_roundtrip_kernel", tmp_path, diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py index 8cc582e77432..46f4db9b038d 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py @@ -780,6 +780,7 @@ def tl_obj_scatter(obj, value, x_offsets, y_offset): @ttgl._core.builtin def tl_obj_gather_amd(desc, x_offsets, y_offset, _semantic=None, _generator=None): + # Builtin required: reads desc._tdm_* attributes stashed by tl_make_tensor_descriptor. # TDM gather: recreate descriptor with block_shape=[num_idx, block_n], then async gather. # Triton's API creates descriptors with block_shape=[1, block_n], but TDM hardware requires # block_shape to match the shared memory allocation [num_idx, block_n] for gather/scatter. 
@@ -804,6 +805,7 @@ def tl_obj_gather_amd(desc, x_offsets, y_offset, _semantic=None, _generator=None @ttgl._core.builtin def tl_obj_scatter_amd(desc, value, x_offsets, y_offset, _semantic=None, _generator=None): + # Builtin required: reads desc._tdm_* attributes stashed by tl_make_tensor_descriptor. # TDM scatter: recreate descriptor with block_shape=[num_idx, block_n], then async scatter. # Triton's API creates descriptors with block_shape=[1, block_n], but TDM hardware requires # block_shape to match the shared memory allocation [num_idx, block_n] for gather/scatter. @@ -825,6 +827,7 @@ def tl_obj_scatter_amd(desc, value, x_offsets, y_offset, _semantic=None, _genera @ttgl._core.builtin def tl_make_tensor_descriptor(base, shape, strides, block_shape, padding_option="zero", _semantic=None): + # Builtin required: attribute assignment (desc._tdm_*) is not supported in @gluon.jit. if _is_gfx1250(current_target()): element_bitwidth = base.dtype.element_ty.primitive_bitwidth layout = get_default_tdm_layout(block_shape, element_bitwidth) From cd9dbd9f884c2e4c8160f44a8caf05168f360c77 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Wed, 25 Mar 2026 08:53:44 -0700 Subject: [PATCH 07/26] [Tools][Translator] Use @_aggregate for AMD tensor descriptors instead of @builtin Replace attribute-stashing @ttgl._core.builtin with a @tl.core._aggregate (AMDTensorDescriptorArgs) holding desc + base_ptr. Load/store use desc directly via @gluon.jit (generic 1D-5D). Gather/scatter reconstruct the descriptor with [num_indices, N] block_shape using desc.block_shape (plain ints from type metadata) and base_ptr, avoiding constexpr-to-tensor conversion in JIT list literals. Only _create_tdm_descriptor remains as a thin builtin for block_shape list construction. 
Made-with: Cursor --- .../triton_to_gluon_translator/translator.py | 43 ++++- .../translator_helpers.py | 154 ++++++++++-------- 2 files changed, 131 insertions(+), 66 deletions(-) diff --git a/python/triton/tools/triton_to_gluon_translator/translator.py b/python/triton/tools/triton_to_gluon_translator/translator.py index 2ef193787615..39d137ecc126 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator.py +++ b/python/triton/tools/triton_to_gluon_translator/translator.py @@ -109,6 +109,7 @@ def add_expr_rewrites(rewrites: list[RewriteFn]) -> None: class Translator(ReferenceRewriter): tensor_member_match_fns: list[str] = field(default_factory=list) target: str = "nvidia" + _desc_block_shapes: dict = field(default_factory=dict) def __post_init__(self) -> None: import triton @@ -150,6 +151,23 @@ def uncanonicalize_call(self, node: ast.Call, fn_name: str | None) -> ast.Call: new_callable = ast.Attribute(value, fn_name, ctx=ast.Load()) return ast.Call(func=new_callable, args=node.args[1:], keywords=node.keywords) + def visit_Assign(self, node: ast.Assign) -> ast.AST: + if self.target.startswith("gfx") and isinstance(node.value, ast.Call): + ref = self.get_reference(node.value.func) if isinstance(node.value.func, + (ast.Attribute, ast.Name)) else None + if ref is not None and ref[0] is tl.make_tensor_descriptor: + for target in node.targets: + if isinstance(target, ast.Name): + block_shape_arg = None + for kw in node.value.keywords: + if kw.arg == "block_shape": + block_shape_arg = kw.value + if block_shape_arg is None and len(node.value.args) >= 4: + block_shape_arg = node.value.args[3] + if block_shape_arg is not None: + self._desc_block_shapes[target.id] = ast.unparse(block_shape_arg) + return self.generic_visit(node) + def visit_Call(self, node: ast.Call) -> ast.AST: node, canonicalized = self.canonicalize_call(node) ref = self.get_reference(node.func) @@ -162,7 +180,27 @@ def visit_Call(self, node: ast.Call) -> ast.AST: "scatter", ]: attr = 
node.func.attr - # Use AMD-specific helpers for gather/scatter on AMD targets + desc_var = node.func.value + # Route AMD descriptor ops to TDM-specific helpers. + if self.target.startswith("gfx") and isinstance(desc_var, ast.Name): + if desc_var.id in self._desc_block_shapes: + if attr in ["load", "store"]: + # Load/store use the descriptor directly (via AMDTensorDescriptorArgs aggregate). + helper_name = f"tl_{attr}_tensor_descriptor_amd" + new_callee = parse_expr(f"helpers.{helper_name}") + node = ast.Call(func=new_callee, args=[node.func.value] + node.args, keywords=node.keywords) + return self.generic_visit(node) + elif attr in ["gather", "scatter"]: + # Gather/scatter reconstruct the descriptor with [num_idx, block_n] + # block_shape since TDM requires it to match the actual operation dims. + helper_name = f"tl_obj_{attr}_amd" + new_callee = parse_expr(f"helpers.{helper_name}") + idx_arg = node.args[0] + num_idx_node = parse_expr(f"{ast.unparse(idx_arg)}.shape[0]") + node = ast.Call(func=new_callee, args=[node.func.value] + node.args + [num_idx_node], + keywords=node.keywords) + return self.generic_visit(node) + # Use AMD-specific helpers for gather/scatter on AMD targets (host descriptor path). 
if attr in ["gather", "scatter"] and self.target.startswith("gfx"): helper_name = f"tl_obj_{attr}_amd" else: @@ -171,6 +209,9 @@ def visit_Call(self, node: ast.Call) -> ast.AST: node = ast.Call(func=new_callee, args=[node.func.value] + node.args, keywords=node.keywords) return self.generic_visit(node) value, _, _ = ref + if value is tl.make_tensor_descriptor and self.target.startswith("gfx"): + node.func = parse_expr("helpers.tl_make_tensor_descriptor_amd") + node.keywords = [kw for kw in node.keywords if kw.arg != "padding_option"] if value in [tl.reshape, tl.ravel]: node.keywords = [kw for kw in node.keywords if kw.arg != "can_reorder"] elif value is tl.split: diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py index 46f4db9b038d..fd2953fd3be1 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py @@ -3,6 +3,7 @@ import math from typing import Any +import triton.language as tl from triton.experimental import gluon from triton.experimental.gluon import language as ttgl from triton.experimental.gluon.language.nvidia.ampere import mma_v2 @@ -686,17 +687,44 @@ def tl_dot_mfma(a, b, acc, out_dtype): # ---- AMD TDM tensor descriptors (gfx1250 only) ---- +# Builtin because list literals in @gluon.jit convert constexpr values to tensors, +# but block_shape requires constexpr ints. Varargs also aren't supported in JIT. 
+@ttgl._core.builtin +def _create_tdm_descriptor(base, shape, strides, *block_shape_and_layout, _semantic=None): + layout = block_shape_and_layout[-1] + block_shape = list(block_shape_and_layout[:-1]) + return amd_tdm.make_tensor_descriptor(base, shape, strides, block_shape, layout, _semantic=_semantic) + + @gluon.constexpr_function -def get_default_tdm_layout(block_shape, element_bitwidth): +def get_default_tdm_layout(*block_shape): + block_shape = list(block_shape) return ttgl.PaddedSharedLayout.with_identity_for( [[block_shape[-1], 4]], - list(block_shape), - [1, 0], + block_shape, + list(range(len(block_shape) - 1, -1, -1)), ) +@tl.core._aggregate +class AMDTensorDescriptorArgs: + """Wraps a real TDM descriptor alongside the original base pointer. + + The base_ptr is needed by gather/scatter to recreate the descriptor with a different + block_shape -- Triton uses block_shape=[1, N] but TDM hardware requires [num_indices, N]. + Shape, strides, and block_shape are read from desc (type metadata gives plain Python ints + for block_shape, tuples for shape/strides).""" + desc: amd_tdm.tensor_descriptor + base_ptr: tl.core.tensor + + @gluon.jit -def tl_load_tensor_descriptor_amd(desc, offsets): +def tl_load_tensor_descriptor_amd(obj, offsets): + if isinstance(obj, AMDTensorDescriptorArgs): + desc = obj.desc + else: + # Real TDM descriptor from convert_host_descriptor (passed as kernel arg). + desc = obj smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout) amd_tdm.async_load(desc, offsets, smem) amd_tdm.async_wait(0) @@ -706,7 +734,12 @@ def tl_load_tensor_descriptor_amd(desc, offsets): @gluon.jit -def tl_store_tensor_descriptor_amd(desc, offsets, value): +def tl_store_tensor_descriptor_amd(obj, offsets, value): + if isinstance(obj, AMDTensorDescriptorArgs): + desc = obj.desc + else: + # Real TDM descriptor from convert_host_descriptor (passed as kernel arg). 
+ desc = obj smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout, value) amd_tdm.async_store(desc, offsets, smem) amd_tdm.async_wait(0) @@ -778,69 +811,61 @@ def tl_obj_scatter(obj, value, x_offsets, y_offset): obj.scatter(value, x_offsets, y_offset) -@ttgl._core.builtin -def tl_obj_gather_amd(desc, x_offsets, y_offset, _semantic=None, _generator=None): - # Builtin required: reads desc._tdm_* attributes stashed by tl_make_tensor_descriptor. - # TDM gather: recreate descriptor with block_shape=[num_idx, block_n], then async gather. - # Triton's API creates descriptors with block_shape=[1, block_n], but TDM hardware requires - # block_shape to match the shared memory allocation [num_idx, block_n] for gather/scatter. - num_idx = x_offsets.shape[0] - block_n = desc.block_shape[1] - gather_shape = [num_idx, block_n] - smem_layout = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) - gather_desc = amd_tdm.make_tensor_descriptor(desc._tdm_base, list(desc._tdm_shape), list(desc._tdm_strides), - gather_shape, smem_layout, _semantic=_semantic) - num_warps = ttgl.num_warps(_semantic=_semantic, _generator=_generator) - idx_base = ttgl.BlockedLayout([num_idx, 1], [1, get_num_threads_per_warp(current_target())], [1, num_warps], [1, 0]) - idx_layout = ttgl.SliceLayout(1, idx_base) - x_offsets = ttgl.convert_layout(x_offsets, idx_layout, _semantic=_semantic) - alloc = ttgl.allocate_shared_memory(desc.dtype, gather_shape, smem_layout, _semantic=_semantic) - y_off = ttgl.to_tensor(y_offset, _semantic=_semantic) - amd_tdm.async_gather(gather_desc, x_offsets, y_off, alloc, _semantic=_semantic) - amd_tdm.async_wait(0, _semantic=_semantic) - ret_layout = default_blocked_layout(gather_shape, num_warps, current_target()) - out = alloc.load(ret_layout, _semantic=_semantic) +@gluon.jit +def tl_obj_gather_amd(desc_args, x_offsets, y_offset, NUM_IDX: ttgl.constexpr): + # Triton creates gather descriptors with block_shape=[1, block_n], but TDM hardware + # operates on the 
full batch, requiring block_shape=[num_indices, block_n]. + BLOCK_N: ttgl.constexpr = desc_args.desc.block_shape[1] + smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) + gather_desc = _create_tdm_descriptor(desc_args.base_ptr, desc_args.desc.shape, desc_args.desc.strides, NUM_IDX, + BLOCK_N, smem_layout) + num_warps: ttgl.constexpr = ttgl.num_warps() + gather_shape: ttgl.constexpr = gather_desc.block_shape + idx_base: ttgl.constexpr = ttgl.BlockedLayout([gather_shape[0], 1], + [1, get_num_threads_per_warp(current_target())], [1, num_warps], + [1, 0]) + idx_layout: ttgl.constexpr = ttgl.SliceLayout(1, idx_base) + x_offsets = ttgl.convert_layout(x_offsets, idx_layout) + alloc = ttgl.allocate_shared_memory(desc_args.desc.dtype, list(gather_shape), smem_layout) + y_off = ttgl.to_tensor(y_offset) + amd_tdm.async_gather(gather_desc, x_offsets, y_off, alloc) + amd_tdm.async_wait(0) + ret_layout: ttgl.constexpr = default_blocked_layout(list(gather_shape), num_warps, current_target()) + out = alloc.load(ret_layout) return out -@ttgl._core.builtin -def tl_obj_scatter_amd(desc, value, x_offsets, y_offset, _semantic=None, _generator=None): - # Builtin required: reads desc._tdm_* attributes stashed by tl_make_tensor_descriptor. - # TDM scatter: recreate descriptor with block_shape=[num_idx, block_n], then async scatter. - # Triton's API creates descriptors with block_shape=[1, block_n], but TDM hardware requires - # block_shape to match the shared memory allocation [num_idx, block_n] for gather/scatter. 
- num_idx = x_offsets.shape[0] - block_n = desc.block_shape[1] - scatter_shape = [num_idx, block_n] - smem_layout = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) - scatter_desc = amd_tdm.make_tensor_descriptor(desc._tdm_base, list(desc._tdm_shape), list(desc._tdm_strides), - scatter_shape, smem_layout, _semantic=_semantic) - num_warps = ttgl.num_warps(_semantic=_semantic, _generator=_generator) - idx_base = ttgl.BlockedLayout([num_idx, 1], [1, get_num_threads_per_warp(current_target())], [1, num_warps], [1, 0]) - idx_layout = ttgl.SliceLayout(1, idx_base) - x_offsets = ttgl.convert_layout(x_offsets, idx_layout, _semantic=_semantic) - alloc = ttgl.allocate_shared_memory(desc.dtype, scatter_shape, smem_layout, value, _semantic=_semantic) - y_off = ttgl.to_tensor(y_offset, _semantic=_semantic) - amd_tdm.async_scatter(scatter_desc, x_offsets, y_off, alloc, _semantic=_semantic) - amd_tdm.async_wait(0, _semantic=_semantic) +@gluon.jit +def tl_obj_scatter_amd(desc_args, value, x_offsets, y_offset, NUM_IDX: ttgl.constexpr): + # See tl_obj_gather_amd for why the descriptor is recreated with a different block_shape. 
+ BLOCK_N: ttgl.constexpr = desc_args.desc.block_shape[1] + smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) + scatter_desc = _create_tdm_descriptor(desc_args.base_ptr, desc_args.desc.shape, desc_args.desc.strides, NUM_IDX, + BLOCK_N, smem_layout) + num_warps: ttgl.constexpr = ttgl.num_warps() + scatter_shape: ttgl.constexpr = scatter_desc.block_shape + idx_base: ttgl.constexpr = ttgl.BlockedLayout([scatter_shape[0], 1], + [1, get_num_threads_per_warp(current_target())], [1, num_warps], + [1, 0]) + idx_layout: ttgl.constexpr = ttgl.SliceLayout(1, idx_base) + x_offsets = ttgl.convert_layout(x_offsets, idx_layout) + alloc = ttgl.allocate_shared_memory(desc_args.desc.dtype, list(scatter_shape), smem_layout, value) + y_off = ttgl.to_tensor(y_offset) + amd_tdm.async_scatter(scatter_desc, x_offsets, y_off, alloc) + amd_tdm.async_wait(0) -@ttgl._core.builtin -def tl_make_tensor_descriptor(base, shape, strides, block_shape, padding_option="zero", _semantic=None): - # Builtin required: attribute assignment (desc._tdm_*) is not supported in @gluon.jit. - if _is_gfx1250(current_target()): - element_bitwidth = base.dtype.element_ty.primitive_bitwidth - layout = get_default_tdm_layout(block_shape, element_bitwidth) - desc = amd_tdm.make_tensor_descriptor(base, shape, strides, block_shape, layout, _semantic=_semantic) - # Stash construction args so tl_obj_gather_amd/tl_obj_scatter_amd can recreate the - # descriptor with a different block_shape. TDM gather/scatter require block_shape to - # match [num_idx, block_n], but Triton creates descriptors with block_shape=[1, block_n]. 
- desc._tdm_base = base - desc._tdm_shape = shape - desc._tdm_strides = strides - return desc - layout = ttgl.NVMMASharedLayout.get_default_for(block_shape, base.dtype.element_ty) - return tma.make_tensor_descriptor(base, shape, strides, block_shape, layout, padding_option, _semantic=_semantic) +@gluon.jit +def tl_make_tensor_descriptor(base, shape, strides, block_shape, padding_option: ttgl.constexpr = "zero"): + layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for(block_shape, base.dtype.element_ty) + return tma.make_tensor_descriptor(base, shape, strides, block_shape, layout, padding_option) + + +@gluon.jit +def tl_make_tensor_descriptor_amd(base, shape, strides, block_shape): + layout: ttgl.constexpr = get_default_tdm_layout(*block_shape) + desc = _create_tdm_descriptor(base, shape, strides, *block_shape, layout) + return AMDTensorDescriptorArgs(desc, base) @gluon.jit @@ -975,8 +1000,7 @@ def torch_dtype_to_triton(dtype): target = current_target() if target is not None and target.backend == "hip" and target.arch == "gfx1250": - element_bitwidth = torch_dtype_to_triton(dtype).primitive_bitwidth - layout = get_default_tdm_layout(block_shape, element_bitwidth) + layout = get_default_tdm_layout(*block_shape) return gluon.amd.gfx1250.TensorDescriptor(tensor, list(desc.shape), list(desc.strides), block_shape, layout) layout = ttgl.NVMMASharedLayout.get_default_for(block_shape, torch_dtype_to_triton(dtype)) From 4e02793fee835bceb075114cd7812f30896daba1 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Tue, 7 Apr 2026 01:59:43 -0700 Subject: [PATCH 08/26] fix for python 3.10 --- .../tools/triton_to_gluon_translator/translator_helpers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py index fd2953fd3be1..436ccff25d23 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py 
+++ b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py @@ -1,7 +1,6 @@ # type: ignore import math -from typing import Any import triton.language as tl from triton.experimental import gluon @@ -1022,7 +1021,7 @@ def build_expand_dims_layout(shape, expand_dims, num_warps): @gluon.jit -def convert_to_expand_dims_layout(value, expand_dims: list[int]) -> Any: +def convert_to_expand_dims_layout(value, expand_dims: list[int]): layout: ttgl.constexpr = build_expand_dims_layout(value.shape, expand_dims, ttgl.num_warps()) return ttgl.convert_layout(value, layout) From 63e0cc1d1f15d4ace0e6a55e576dbe3460d53ed0 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Tue, 7 Apr 2026 02:39:08 -0700 Subject: [PATCH 09/26] [Tools][Translator] Remove _create_tdm_descriptor builtin Use constexpr annotations on block_shape parameters to prevent constexpr-to-tensor decay in JIT list literals. This eliminates the last builtin helper -- all functions are now @gluon.jit. Made-with: Cursor --- .../translator_helpers.py | 21 +++++++------------ 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py index 436ccff25d23..4e83f42c1d9d 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py @@ -686,15 +686,6 @@ def tl_dot_mfma(a, b, acc, out_dtype): # ---- AMD TDM tensor descriptors (gfx1250 only) ---- -# Builtin because list literals in @gluon.jit convert constexpr values to tensors, -# but block_shape requires constexpr ints. Varargs also aren't supported in JIT. 
-@ttgl._core.builtin -def _create_tdm_descriptor(base, shape, strides, *block_shape_and_layout, _semantic=None): - layout = block_shape_and_layout[-1] - block_shape = list(block_shape_and_layout[:-1]) - return amd_tdm.make_tensor_descriptor(base, shape, strides, block_shape, layout, _semantic=_semantic) - - @gluon.constexpr_function def get_default_tdm_layout(*block_shape): block_shape = list(block_shape) @@ -816,8 +807,9 @@ def tl_obj_gather_amd(desc_args, x_offsets, y_offset, NUM_IDX: ttgl.constexpr): # operates on the full batch, requiring block_shape=[num_indices, block_n]. BLOCK_N: ttgl.constexpr = desc_args.desc.block_shape[1] smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) - gather_desc = _create_tdm_descriptor(desc_args.base_ptr, desc_args.desc.shape, desc_args.desc.strides, NUM_IDX, - BLOCK_N, smem_layout) + gather_block_shape: ttgl.constexpr = [NUM_IDX, BLOCK_N] + gather_desc = amd_tdm.make_tensor_descriptor(desc_args.base_ptr, desc_args.desc.shape, desc_args.desc.strides, + gather_block_shape, smem_layout) num_warps: ttgl.constexpr = ttgl.num_warps() gather_shape: ttgl.constexpr = gather_desc.block_shape idx_base: ttgl.constexpr = ttgl.BlockedLayout([gather_shape[0], 1], @@ -839,8 +831,9 @@ def tl_obj_scatter_amd(desc_args, value, x_offsets, y_offset, NUM_IDX: ttgl.cons # See tl_obj_gather_amd for why the descriptor is recreated with a different block_shape. 
BLOCK_N: ttgl.constexpr = desc_args.desc.block_shape[1] smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) - scatter_desc = _create_tdm_descriptor(desc_args.base_ptr, desc_args.desc.shape, desc_args.desc.strides, NUM_IDX, - BLOCK_N, smem_layout) + scatter_block_shape: ttgl.constexpr = [NUM_IDX, BLOCK_N] + scatter_desc = amd_tdm.make_tensor_descriptor(desc_args.base_ptr, desc_args.desc.shape, desc_args.desc.strides, + scatter_block_shape, smem_layout) num_warps: ttgl.constexpr = ttgl.num_warps() scatter_shape: ttgl.constexpr = scatter_desc.block_shape idx_base: ttgl.constexpr = ttgl.BlockedLayout([scatter_shape[0], 1], @@ -863,7 +856,7 @@ def tl_make_tensor_descriptor(base, shape, strides, block_shape, padding_option: @gluon.jit def tl_make_tensor_descriptor_amd(base, shape, strides, block_shape): layout: ttgl.constexpr = get_default_tdm_layout(*block_shape) - desc = _create_tdm_descriptor(base, shape, strides, *block_shape, layout) + desc = amd_tdm.make_tensor_descriptor(base, shape, strides, block_shape, layout) return AMDTensorDescriptorArgs(desc, base) From a8d2fd19a6f784e858b743e660c435513e340870 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Wed, 8 Apr 2026 07:44:54 -0700 Subject: [PATCH 10/26] review fixes --- .../triton_to_gluon_translator/translator.py | 32 ++++++------------- .../translator_helpers.py | 6 ++-- 2 files changed, 13 insertions(+), 25 deletions(-) diff --git a/python/triton/tools/triton_to_gluon_translator/translator.py b/python/triton/tools/triton_to_gluon_translator/translator.py index 39d137ecc126..22b4883da307 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator.py +++ b/python/triton/tools/triton_to_gluon_translator/translator.py @@ -109,7 +109,7 @@ def add_expr_rewrites(rewrites: list[RewriteFn]) -> None: class Translator(ReferenceRewriter): tensor_member_match_fns: list[str] = field(default_factory=list) target: str = "nvidia" - _desc_block_shapes: dict = field(default_factory=dict) + 
_amd_descriptor_vars: set = field(default_factory=set) def __post_init__(self) -> None: import triton @@ -152,20 +152,15 @@ def uncanonicalize_call(self, node: ast.Call, fn_name: str | None) -> ast.Call: return ast.Call(func=new_callable, args=node.args[1:], keywords=node.keywords) def visit_Assign(self, node: ast.Assign) -> ast.AST: + # Track variables assigned from tl.make_tensor_descriptor so we can route + # their load/store/gather/scatter to AMD-specific helpers. if self.target.startswith("gfx") and isinstance(node.value, ast.Call): ref = self.get_reference(node.value.func) if isinstance(node.value.func, (ast.Attribute, ast.Name)) else None if ref is not None and ref[0] is tl.make_tensor_descriptor: for target in node.targets: if isinstance(target, ast.Name): - block_shape_arg = None - for kw in node.value.keywords: - if kw.arg == "block_shape": - block_shape_arg = kw.value - if block_shape_arg is None and len(node.value.args) >= 4: - block_shape_arg = node.value.args[3] - if block_shape_arg is not None: - self._desc_block_shapes[target.id] = ast.unparse(block_shape_arg) + self._amd_descriptor_vars.add(target.id) return self.generic_visit(node) def visit_Call(self, node: ast.Call) -> ast.AST: @@ -183,23 +178,14 @@ def visit_Call(self, node: ast.Call) -> ast.AST: desc_var = node.func.value # Route AMD descriptor ops to TDM-specific helpers. if self.target.startswith("gfx") and isinstance(desc_var, ast.Name): - if desc_var.id in self._desc_block_shapes: + if desc_var.id in self._amd_descriptor_vars: if attr in ["load", "store"]: - # Load/store use the descriptor directly (via AMDTensorDescriptorArgs aggregate). 
helper_name = f"tl_{attr}_tensor_descriptor_amd" - new_callee = parse_expr(f"helpers.{helper_name}") - node = ast.Call(func=new_callee, args=[node.func.value] + node.args, keywords=node.keywords) - return self.generic_visit(node) - elif attr in ["gather", "scatter"]: - # Gather/scatter reconstruct the descriptor with [num_idx, block_n] - # block_shape since TDM requires it to match the actual operation dims. + else: helper_name = f"tl_obj_{attr}_amd" - new_callee = parse_expr(f"helpers.{helper_name}") - idx_arg = node.args[0] - num_idx_node = parse_expr(f"{ast.unparse(idx_arg)}.shape[0]") - node = ast.Call(func=new_callee, args=[node.func.value] + node.args + [num_idx_node], - keywords=node.keywords) - return self.generic_visit(node) + new_callee = parse_expr(f"helpers.{helper_name}") + node = ast.Call(func=new_callee, args=[node.func.value] + node.args, keywords=node.keywords) + return self.generic_visit(node) # Use AMD-specific helpers for gather/scatter on AMD targets (host descriptor path). if attr in ["gather", "scatter"] and self.target.startswith("gfx"): helper_name = f"tl_obj_{attr}_amd" diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py index 4e83f42c1d9d..c3e9f1ab1158 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py @@ -802,9 +802,10 @@ def tl_obj_scatter(obj, value, x_offsets, y_offset): @gluon.jit -def tl_obj_gather_amd(desc_args, x_offsets, y_offset, NUM_IDX: ttgl.constexpr): +def tl_obj_gather_amd(desc_args, x_offsets, y_offset): # Triton creates gather descriptors with block_shape=[1, block_n], but TDM hardware # operates on the full batch, requiring block_shape=[num_indices, block_n]. 
+ NUM_IDX: ttgl.constexpr = x_offsets.shape[0] BLOCK_N: ttgl.constexpr = desc_args.desc.block_shape[1] smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) gather_block_shape: ttgl.constexpr = [NUM_IDX, BLOCK_N] @@ -827,8 +828,9 @@ def tl_obj_gather_amd(desc_args, x_offsets, y_offset, NUM_IDX: ttgl.constexpr): @gluon.jit -def tl_obj_scatter_amd(desc_args, value, x_offsets, y_offset, NUM_IDX: ttgl.constexpr): +def tl_obj_scatter_amd(desc_args, value, x_offsets, y_offset): # See tl_obj_gather_amd for why the descriptor is recreated with a different block_shape. + NUM_IDX: ttgl.constexpr = x_offsets.shape[0] BLOCK_N: ttgl.constexpr = desc_args.desc.block_shape[1] smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) scatter_block_shape: ttgl.constexpr = [NUM_IDX, BLOCK_N] From b2f9c161a0284b036945f34164d4172f3062acf6 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Wed, 8 Apr 2026 08:00:34 -0700 Subject: [PATCH 11/26] Remove else in tl_dot for NVIDIA path --- .../tools/triton_to_gluon_translator/translator_helpers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py index c3e9f1ab1158..d78b9bd28bb1 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py @@ -196,8 +196,7 @@ def tl_dot( num_warps: ttgl.constexpr = ttgl.num_warps() if tl_dot_mmav5_supported(a.type, b.type, num_warps, input_precision, allow_tf32, max_num_imprecise_acc): return tl_dot_blackwell(a, b, acc, input_precision, allow_tf32, max_num_imprecise_acc, out_dtype) - else: - return tl_dot_mma_sync(a, b, acc, input_precision, out_dtype) + return tl_dot_mma_sync(a, b, acc, input_precision, out_dtype) @gluon.constexpr_function From 17e6c9e2721ec8337ba0e0a96cff6d45c78904ff Mon Sep 17 00:00:00 2001 From: 
Aaryaman Vasishta Date: Wed, 8 Apr 2026 08:31:23 -0700 Subject: [PATCH 12/26] Revert "Remove else in tl_dot for NVIDIA path" This reverts commit b2f9c161a0284b036945f34164d4172f3062acf6. --- .../tools/triton_to_gluon_translator/translator_helpers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py index d78b9bd28bb1..c3e9f1ab1158 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py @@ -196,7 +196,8 @@ def tl_dot( num_warps: ttgl.constexpr = ttgl.num_warps() if tl_dot_mmav5_supported(a.type, b.type, num_warps, input_precision, allow_tf32, max_num_imprecise_acc): return tl_dot_blackwell(a, b, acc, input_precision, allow_tf32, max_num_imprecise_acc, out_dtype) - return tl_dot_mma_sync(a, b, acc, input_precision, out_dtype) + else: + return tl_dot_mma_sync(a, b, acc, input_precision, out_dtype) @gluon.constexpr_function From 5197b9be33e3a04aba0d90ac102763ab85f42a94 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Wed, 8 Apr 2026 09:27:31 -0700 Subject: [PATCH 13/26] Use standard current_target() instead of custom _current_target override --- .../triton_to_gluon_translator/translator.py | 7 ++-- .../translator_helpers.py | 34 +++---------------- 2 files changed, 6 insertions(+), 35 deletions(-) diff --git a/python/triton/tools/triton_to_gluon_translator/translator.py b/python/triton/tools/triton_to_gluon_translator/translator.py index 22b4883da307..22a7076cdcda 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator.py +++ b/python/triton/tools/triton_to_gluon_translator/translator.py @@ -281,11 +281,8 @@ def filter(value: ModuleType | GlobalValue) -> bool: assert reference.mangled_name is not None source = reference.value.mangle_source(source, reference.mangled_name) output += source + 
"\n\n\n" - header = "\n".join(imports) + "\n" - if target != "nvidia": - header += f'\nhelpers._current_target = helpers._make_target("{target}")\n' - header += "\n" - return header + output + output = "\n".join(imports) + "\n\n" + output + return output def translate_paths(kernel_paths: list[str], target: str = "nvidia") -> str: diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py index c3e9f1ab1158..e3a8730f99ce 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py @@ -17,6 +17,10 @@ from triton.experimental.gluon.language.nvidia.blackwell import tma as tma_blackwell from triton.experimental.gluon.language.nvidia.hopper import fence_async_shared, mbarrier, tma from triton.experimental.gluon.language.amd.gfx1250 import wmma as amd_wmma + +# hack to workaround limited dependencies tracking. +# TODO: fix this by pulling imports into the generated file. +from triton.language.target_info import current_target # noqa: F401 from triton.experimental.gluon.language.amd.gfx1250 import tdm as amd_tdm from triton.experimental.gluon.language.amd.cdna3 import mfma as amd_mfma @@ -1021,33 +1025,3 @@ def convert_to_expand_dims_layout(value, expand_dims: list[int]): return ttgl.convert_layout(value, layout) -# Module-level target, set by the translator via _make_target(). -# Falls back to the active driver's target if not set. -_current_target = None - - -def current_target(): - if _current_target is not None: - return _current_target - from triton.runtime import driver - - try: - active_driver = driver.active - except RuntimeError: - return None - return active_driver.get_current_target() - - -current_target.__triton_builtin__ = True - - -def _make_target(arch): - """Construct a GPUTarget from an architecture string (e.g. 
'gfx1250', 'nvidia').""" - if arch.startswith("gfx"): - from triton.backends.amd.compiler import GPUTarget - warp_size = 32 if int(arch[3:-2]) >= 10 else 64 - return GPUTarget("hip", arch, warp_size) - return None - - -_make_target.__triton_builtin__ = True From f2708dd82db71b0b0ee3d45fe830cc3e3d33c0ef Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Wed, 8 Apr 2026 09:44:57 -0700 Subject: [PATCH 14/26] Use isinstance dispatch instead of translator routing for AMD descriptor ops Made-with: Cursor --- .../triton_to_gluon_translator/translator.py | 32 +---- .../translator_helpers.py | 121 +++++++++--------- 2 files changed, 60 insertions(+), 93 deletions(-) diff --git a/python/triton/tools/triton_to_gluon_translator/translator.py b/python/triton/tools/triton_to_gluon_translator/translator.py index 22a7076cdcda..600aa57bb279 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator.py +++ b/python/triton/tools/triton_to_gluon_translator/translator.py @@ -109,7 +109,6 @@ def add_expr_rewrites(rewrites: list[RewriteFn]) -> None: class Translator(ReferenceRewriter): tensor_member_match_fns: list[str] = field(default_factory=list) target: str = "nvidia" - _amd_descriptor_vars: set = field(default_factory=set) def __post_init__(self) -> None: import triton @@ -151,18 +150,6 @@ def uncanonicalize_call(self, node: ast.Call, fn_name: str | None) -> ast.Call: new_callable = ast.Attribute(value, fn_name, ctx=ast.Load()) return ast.Call(func=new_callable, args=node.args[1:], keywords=node.keywords) - def visit_Assign(self, node: ast.Assign) -> ast.AST: - # Track variables assigned from tl.make_tensor_descriptor so we can route - # their load/store/gather/scatter to AMD-specific helpers. 
- if self.target.startswith("gfx") and isinstance(node.value, ast.Call): - ref = self.get_reference(node.value.func) if isinstance(node.value.func, - (ast.Attribute, ast.Name)) else None - if ref is not None and ref[0] is tl.make_tensor_descriptor: - for target in node.targets: - if isinstance(target, ast.Name): - self._amd_descriptor_vars.add(target.id) - return self.generic_visit(node) - def visit_Call(self, node: ast.Call) -> ast.AST: node, canonicalized = self.canonicalize_call(node) ref = self.get_reference(node.func) @@ -174,24 +161,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST: "gather", "scatter", ]: - attr = node.func.attr - desc_var = node.func.value - # Route AMD descriptor ops to TDM-specific helpers. - if self.target.startswith("gfx") and isinstance(desc_var, ast.Name): - if desc_var.id in self._amd_descriptor_vars: - if attr in ["load", "store"]: - helper_name = f"tl_{attr}_tensor_descriptor_amd" - else: - helper_name = f"tl_obj_{attr}_amd" - new_callee = parse_expr(f"helpers.{helper_name}") - node = ast.Call(func=new_callee, args=[node.func.value] + node.args, keywords=node.keywords) - return self.generic_visit(node) - # Use AMD-specific helpers for gather/scatter on AMD targets (host descriptor path). 
- if attr in ["gather", "scatter"] and self.target.startswith("gfx"): - helper_name = f"tl_obj_{attr}_amd" - else: - helper_name = f"tl_obj_{attr}" - new_callee = parse_expr(f"helpers.{helper_name}") + new_callee = parse_expr(f"helpers.tl_obj_{node.func.attr}") node = ast.Call(func=new_callee, args=[node.func.value] + node.args, keywords=node.keywords) return self.generic_visit(node) value, _, _ = ref diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py index e3a8730f99ce..4c57fd9d6749 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py @@ -713,27 +713,16 @@ class AMDTensorDescriptorArgs: @gluon.jit -def tl_load_tensor_descriptor_amd(obj, offsets): - if isinstance(obj, AMDTensorDescriptorArgs): - desc = obj.desc - else: - # Real TDM descriptor from convert_host_descriptor (passed as kernel arg). - desc = obj +def _load_tdm(desc, offsets): smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout) amd_tdm.async_load(desc, offsets, smem) amd_tdm.async_wait(0) ret_layout: ttgl.constexpr = default_blocked_layout(desc.block_shape, ttgl.num_warps()) - out = smem.load(ret_layout) - return out + return smem.load(ret_layout) @gluon.jit -def tl_store_tensor_descriptor_amd(obj, offsets, value): - if isinstance(obj, AMDTensorDescriptorArgs): - desc = obj.desc - else: - # Real TDM descriptor from convert_host_descriptor (passed as kernel arg). 
- desc = obj +def _store_tdm(desc, offsets, value): smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout, value) amd_tdm.async_store(desc, offsets, smem) amd_tdm.async_wait(0) @@ -746,8 +735,10 @@ def tl_store_tensor_descriptor_amd(obj, offsets, value): def tl_obj_store(obj, offsets, value): if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): return tl_store_tensor_descriptor(obj, offsets, value) + elif isinstance(obj, AMDTensorDescriptorArgs): + _store_tdm(obj.desc, offsets, value) elif isinstance(obj, amd_tdm.tensor_descriptor): - return tl_store_tensor_descriptor_amd(obj, offsets, value) + _store_tdm(obj, offsets, value) else: return obj.store(offsets, value) @@ -756,57 +747,16 @@ def tl_obj_store(obj, offsets, value): def tl_obj_load(obj, offsets): if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): return tl_load_tensor_descriptor(obj, offsets) + elif isinstance(obj, AMDTensorDescriptorArgs): + return _load_tdm(obj.desc, offsets) elif isinstance(obj, amd_tdm.tensor_descriptor): - return tl_load_tensor_descriptor_amd(obj, offsets) + return _load_tdm(obj, offsets) else: return obj.load(offsets) @gluon.jit -def tl_obj_gather(obj, x_offsets, y_offset): - if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): - desc = obj - desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] - alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout) - bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) - mbarrier.init(bar, count=1) - x_offsets_layout: ttgl.constexpr = ttgl.SliceLayout( - 0, - ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]), - ) - x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout) - mbarrier.expect(bar, x_offsets.shape[0] * obj.block_type.nbytes) - tma_blackwell.async_gather(desc, x_offsets, y_offset, bar, alloc) - mbarrier.wait(bar, phase=0) - mbarrier.invalidate(bar) - # Load from shared memory into a 
register tensor using a reasonable default layout - ret_layout: ttgl.constexpr = default_blocked_layout(desc.block_shape, ttgl.num_warps()) - out = alloc.load(ret_layout) - return out - else: - return obj.gather(x_offsets, y_offset) - - -@gluon.jit -def tl_obj_scatter(obj, value, x_offsets, y_offset): - if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): - desc = obj - desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] - alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout, value) - fence_async_shared() - x_offsets_layout: ttgl.constexpr = ttgl.SliceLayout( - 0, - ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]), - ) - x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout) - tma_blackwell.async_scatter(desc, x_offsets, y_offset, alloc) - tma.store_wait(0) - else: - obj.scatter(value, x_offsets, y_offset) - - -@gluon.jit -def tl_obj_gather_amd(desc_args, x_offsets, y_offset): +def _gather_tdm(desc_args, x_offsets, y_offset): # Triton creates gather descriptors with block_shape=[1, block_n], but TDM hardware # operates on the full batch, requiring block_shape=[num_indices, block_n]. NUM_IDX: ttgl.constexpr = x_offsets.shape[0] @@ -832,8 +782,8 @@ def tl_obj_gather_amd(desc_args, x_offsets, y_offset): @gluon.jit -def tl_obj_scatter_amd(desc_args, value, x_offsets, y_offset): - # See tl_obj_gather_amd for why the descriptor is recreated with a different block_shape. +def _scatter_tdm(desc_args, value, x_offsets, y_offset): + # See _gather_tdm for why the descriptor is recreated with a different block_shape. 
NUM_IDX: ttgl.constexpr = x_offsets.shape[0] BLOCK_N: ttgl.constexpr = desc_args.desc.block_shape[1] smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) @@ -853,6 +803,53 @@ def tl_obj_scatter_amd(desc_args, value, x_offsets, y_offset): amd_tdm.async_wait(0) +@gluon.jit +def tl_obj_gather(obj, x_offsets, y_offset): + if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): + desc = obj + desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] + alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout) + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + x_offsets_layout: ttgl.constexpr = ttgl.SliceLayout( + 0, + ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]), + ) + x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout) + mbarrier.expect(bar, x_offsets.shape[0] * obj.block_type.nbytes) + tma_blackwell.async_gather(desc, x_offsets, y_offset, bar, alloc) + mbarrier.wait(bar, phase=0) + mbarrier.invalidate(bar) + # Load from shared memory into a register tensor using a reasonable default layout + ret_layout: ttgl.constexpr = default_blocked_layout(desc.block_shape, ttgl.num_warps()) + out = alloc.load(ret_layout) + return out + elif isinstance(obj, AMDTensorDescriptorArgs): + return _gather_tdm(obj, x_offsets, y_offset) + else: + return obj.gather(x_offsets, y_offset) + + +@gluon.jit +def tl_obj_scatter(obj, value, x_offsets, y_offset): + if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): + desc = obj + desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] + alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout, value) + fence_async_shared() + x_offsets_layout: ttgl.constexpr = ttgl.SliceLayout( + 0, + ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]), + ) + x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout) + 
tma_blackwell.async_scatter(desc, x_offsets, y_offset, alloc) + tma.store_wait(0) + elif isinstance(obj, AMDTensorDescriptorArgs): + _scatter_tdm(obj, value, x_offsets, y_offset) + else: + obj.scatter(value, x_offsets, y_offset) + + @gluon.jit def tl_make_tensor_descriptor(base, shape, strides, block_shape, padding_option: ttgl.constexpr = "zero"): layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for(block_shape, base.dtype.element_ty) From cc3823e2fb403c287d2dd93ac0a81cc75375227f Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Wed, 8 Apr 2026 09:49:45 -0700 Subject: [PATCH 15/26] Rename _load/_store/_gather/_scatter_tdm to tl_obj_*_amd naming convention Made-with: Cursor --- .../translator_helpers.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py index 4c57fd9d6749..fadd1c803aed 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py @@ -713,7 +713,7 @@ class AMDTensorDescriptorArgs: @gluon.jit -def _load_tdm(desc, offsets): +def tl_obj_load_amd(desc, offsets): smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout) amd_tdm.async_load(desc, offsets, smem) amd_tdm.async_wait(0) @@ -722,7 +722,7 @@ def _load_tdm(desc, offsets): @gluon.jit -def _store_tdm(desc, offsets, value): +def tl_obj_store_amd(desc, offsets, value): smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout, value) amd_tdm.async_store(desc, offsets, smem) amd_tdm.async_wait(0) @@ -736,9 +736,9 @@ def tl_obj_store(obj, offsets, value): if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): return tl_store_tensor_descriptor(obj, offsets, value) elif isinstance(obj, AMDTensorDescriptorArgs): - _store_tdm(obj.desc, offsets, value) + 
tl_obj_store_amd(obj.desc, offsets, value) elif isinstance(obj, amd_tdm.tensor_descriptor): - _store_tdm(obj, offsets, value) + tl_obj_store_amd(obj, offsets, value) else: return obj.store(offsets, value) @@ -748,15 +748,15 @@ def tl_obj_load(obj, offsets): if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): return tl_load_tensor_descriptor(obj, offsets) elif isinstance(obj, AMDTensorDescriptorArgs): - return _load_tdm(obj.desc, offsets) + return tl_obj_load_amd(obj.desc, offsets) elif isinstance(obj, amd_tdm.tensor_descriptor): - return _load_tdm(obj, offsets) + return tl_obj_load_amd(obj, offsets) else: return obj.load(offsets) @gluon.jit -def _gather_tdm(desc_args, x_offsets, y_offset): +def tl_obj_gather_amd(desc_args, x_offsets, y_offset): # Triton creates gather descriptors with block_shape=[1, block_n], but TDM hardware # operates on the full batch, requiring block_shape=[num_indices, block_n]. NUM_IDX: ttgl.constexpr = x_offsets.shape[0] @@ -782,8 +782,8 @@ def _gather_tdm(desc_args, x_offsets, y_offset): @gluon.jit -def _scatter_tdm(desc_args, value, x_offsets, y_offset): - # See _gather_tdm for why the descriptor is recreated with a different block_shape. +def tl_obj_scatter_amd(desc_args, value, x_offsets, y_offset): + # See tl_obj_gather_amd for why the descriptor is recreated with a different block_shape. 
NUM_IDX: ttgl.constexpr = x_offsets.shape[0] BLOCK_N: ttgl.constexpr = desc_args.desc.block_shape[1] smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) @@ -825,7 +825,7 @@ def tl_obj_gather(obj, x_offsets, y_offset): out = alloc.load(ret_layout) return out elif isinstance(obj, AMDTensorDescriptorArgs): - return _gather_tdm(obj, x_offsets, y_offset) + return tl_obj_gather_amd(obj, x_offsets, y_offset) else: return obj.gather(x_offsets, y_offset) @@ -845,7 +845,7 @@ def tl_obj_scatter(obj, value, x_offsets, y_offset): tma_blackwell.async_scatter(desc, x_offsets, y_offset, alloc) tma.store_wait(0) elif isinstance(obj, AMDTensorDescriptorArgs): - _scatter_tdm(obj, value, x_offsets, y_offset) + tl_obj_scatter_amd(obj, value, x_offsets, y_offset) else: obj.scatter(value, x_offsets, y_offset) From ee0db99f1fb27a8dab35722fb642de0183c84261 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Wed, 8 Apr 2026 09:50:14 -0700 Subject: [PATCH 16/26] pre-commit run --- .../tools/triton_to_gluon_translator/translator_helpers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py index fadd1c803aed..554add5fda4f 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py @@ -1020,5 +1020,3 @@ def build_expand_dims_layout(shape, expand_dims, num_warps): def convert_to_expand_dims_layout(value, expand_dims: list[int]): layout: ttgl.constexpr = build_expand_dims_layout(value.shape, expand_dims, ttgl.num_warps()) return ttgl.convert_layout(value, layout) - - From 00964d95858fb5b4074af8dfc889ceebd1e188ef Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Wed, 8 Apr 2026 09:56:01 -0700 Subject: [PATCH 17/26] Add static_assert for gfx1250 target in tl_make_tensor_descriptor_amd Made-with: Cursor --- 
.../tools/triton_to_gluon_translator/translator_helpers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py index 554add5fda4f..94fd8693fc24 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py @@ -858,6 +858,7 @@ def tl_make_tensor_descriptor(base, shape, strides, block_shape, padding_option: @gluon.jit def tl_make_tensor_descriptor_amd(base, shape, strides, block_shape): + ttgl.static_assert(_is_gfx1250(current_target()), "tl_make_tensor_descriptor_amd requires gfx1250 target") layout: ttgl.constexpr = get_default_tdm_layout(*block_shape) desc = amd_tdm.make_tensor_descriptor(base, shape, strides, block_shape, layout) return AMDTensorDescriptorArgs(desc, base) From cf02165e1a7c23dc8f39f311158ffba900ca16d0 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Fri, 10 Apr 2026 10:52:32 -0700 Subject: [PATCH 18/26] Use current target instead of parametrizing all targets in translator tests Avoid pytest-xdist load imbalance by detecting the available target at runtime via current_target() rather than parametrizing over every target and skipping the unavailable ones. 
Made-with: Cursor --- .../test/unit/tools/test_triton_to_gluon.py | 111 ++++++------------ python/triton/_internal_testing.py | 4 + 2 files changed, 42 insertions(+), 73 deletions(-) diff --git a/python/test/unit/tools/test_triton_to_gluon.py b/python/test/unit/tools/test_triton_to_gluon.py index 73d1fdd4148e..f0bc99b2e401 100644 --- a/python/test/unit/tools/test_triton_to_gluon.py +++ b/python/test/unit/tools/test_triton_to_gluon.py @@ -8,31 +8,14 @@ from triton.tools.triton_to_gluon_translator.translator import convert_triton_to_gluon from triton.tools.triton_to_gluon_translator.translator_helpers import convert_host_descriptor -from triton._internal_testing import is_blackwell, is_hopper_or_newer, is_cuda, is_hip_gfx1250, is_hip_cdna3, is_hip_cdna4 +from triton._internal_testing import is_blackwell, is_hopper_or_newer, is_cuda, is_hip_gfx1250, is_hip_cdna3_or_newer +from triton.language.target_info import current_target -_all_targets = { - "nvidia": is_cuda, - "gfx1250": is_hip_gfx1250, - "gfx942": is_hip_cdna3, - "gfx950": is_hip_cdna4, -} -_dot_targets = { - "nvidia": is_blackwell, - "gfx1250": is_hip_gfx1250, - "gfx942": is_hip_cdna3, - "gfx950": is_hip_cdna4, -} - - -def _skip_unless_target(target, targets=_all_targets): - """Skip test if the required hardware for the given target is not available. 
- Use _dot_targets for dot tests (Blackwell on NVIDIA), _descriptor_targets for descriptor tests.""" - if not targets[target](): - pytest.skip(f"Requires {target}") - - -def convert_kernel(kernel, kernel_name, tmp_path, target="nvidia"): +def convert_kernel(kernel, kernel_name, tmp_path, target=None): + if target is None: + t = current_target() + target = "nvidia" if t.backend == "cuda" else t.arch converted = convert_triton_to_gluon([kernel], target=target) # Write converted kernel to a file so @gluon.jit can retrieve source @@ -57,10 +40,8 @@ def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr): tl.store(out_ptr + offsets, x + y) -@pytest.mark.parametrize("target", _all_targets.keys()) -def test_simple_kernel(tmp_path, target): - _skip_unless_target(target) - kernel = convert_kernel(add_kernel, "add_kernel", tmp_path, target=target) +def test_simple_kernel(tmp_path): + kernel = convert_kernel(add_kernel, "add_kernel", tmp_path) n = 1024 BLOCK = 128 @@ -92,10 +73,10 @@ def matmul_tile_kernel(a_ptr, b_ptr, c_ptr, BLOCK_M: tl.constexpr, BLOCK_N: tl.c impl_matmul_tile_kernel(a_ptr, b_ptr, c_ptr, BLOCK_M, BLOCK_N, BLOCK_K) -@pytest.mark.parametrize("target", _dot_targets.keys()) -def test_triton_to_gluon_dot_minimal(tmp_path, target): - _skip_unless_target(target, _dot_targets) - kernel = convert_kernel(matmul_tile_kernel, "matmul_tile_kernel", tmp_path, target=target) +def test_triton_to_gluon_dot_minimal(tmp_path): + if not (is_blackwell() or is_hip_cdna3_or_newer() or is_hip_gfx1250()): + pytest.skip("Requires Blackwell, CDNA3+, or gfx1250") + kernel = convert_kernel(matmul_tile_kernel, "matmul_tile_kernel", tmp_path) M, N, K = 128, 128, 128 a = torch.randn((M, K), device="cuda", dtype=torch.float16) b = torch.randn((K, N), device="cuda", dtype=torch.float16) @@ -153,8 +134,9 @@ def matmul_kernel( # @pytest.mark.parametrize("dtype_dst_str", ["float32"]) @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES", [(128, 128, 64, 1)]) 
@pytest.mark.parametrize("NUM_WARPS", [4]) -@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell") def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, NUM_WARPS, tmp_path): + if not (is_blackwell() or is_hip_cdna3_or_newer() or is_hip_gfx1250()): + pytest.skip("Requires Blackwell, CDNA3+, or gfx1250") device = "cuda" M, N, K = 1024, 512, 256 torch.manual_seed(42) @@ -183,17 +165,16 @@ def descriptor_store_kernel(desc, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, desc.store([0, 0], tile) -_descriptor_targets = { - "nvidia": is_hopper_or_newer, - "gfx1250": is_hip_gfx1250, -} +def _skip_unless_descriptor_target(): + if is_cuda() and not is_hopper_or_newer(): + pytest.skip("Requires Hopper+") + elif not is_cuda() and not is_hip_gfx1250(): + pytest.skip("Requires descriptor support") -@pytest.mark.parametrize("target", _descriptor_targets.keys()) -def test_triton_to_gluon_descriptor_roundtrip(tmp_path, target): - if not _descriptor_targets[target](): - pytest.skip(f"Requires {target} with descriptor support") - kernel = convert_kernel(descriptor_store_kernel, "descriptor_store_kernel", tmp_path, target=target) +def test_triton_to_gluon_descriptor_roundtrip(tmp_path): + _skip_unless_descriptor_target() + kernel = convert_kernel(descriptor_store_kernel, "descriptor_store_kernel", tmp_path) M = N = 64 y = torch.zeros((M, N), device="cuda", dtype=torch.float16) @@ -215,11 +196,9 @@ def descriptor_copy_kernel(in_desc, out_desc, BLOCK_M: tl.constexpr, BLOCK_N: tl out_desc.store([0, 0], tile) -@pytest.mark.parametrize("target", _descriptor_targets.keys()) -def test_triton_to_gluon_descriptor_load_roundtrip(tmp_path, target): - if not _descriptor_targets[target](): - pytest.skip(f"Requires {target} with descriptor support") - kernel = convert_kernel(descriptor_copy_kernel, "descriptor_copy_kernel", tmp_path, target=target) +def test_triton_to_gluon_descriptor_load_roundtrip(tmp_path): + _skip_unless_descriptor_target() + 
kernel = convert_kernel(descriptor_copy_kernel, "descriptor_copy_kernel", tmp_path) M = N = 64 x = torch.ones((M, N), device="cuda", dtype=torch.float16) * 3.0 @@ -256,13 +235,9 @@ def make_tensor_descriptor_copy_kernel(x_ptr, y_ptr, M, N, BLOCK_M: tl.constexpr out_desc.store([0, 0], tile) -# Parametrized over _descriptor_targets: tests tl.make_tensor_descriptor translation -# for both NVIDIA TMA (Hopper+) and AMD TDM (gfx1250). -@pytest.mark.parametrize("target", _descriptor_targets.keys()) -def test_triton_to_gluon_make_tensor_descriptor(tmp_path, target, with_allocator): - _skip_unless_target(target, _descriptor_targets) - kernel = convert_kernel(make_tensor_descriptor_copy_kernel, "make_tensor_descriptor_copy_kernel", tmp_path, - target=target) +def test_triton_to_gluon_make_tensor_descriptor(tmp_path, with_allocator): + _skip_unless_descriptor_target() + kernel = convert_kernel(make_tensor_descriptor_copy_kernel, "make_tensor_descriptor_copy_kernel", tmp_path) M = N = 64 x = torch.randn((M, N), device="cuda", dtype=torch.float16) @@ -294,10 +269,8 @@ def reshape_trans_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr, @pytest.mark.parametrize("TRANS_KIND", ["trans_method", "tl_trans_separate", "tl_trans_tuple", "tl_trans"]) -@pytest.mark.parametrize("target", _all_targets.keys()) -def test_triton_reshape_trans(tmp_path, TRANS_KIND, target): - _skip_unless_target(target) - kernel = convert_kernel(reshape_trans_kernel, "reshape_trans_kernel", tmp_path, target=target) +def test_triton_reshape_trans(tmp_path, TRANS_KIND): + kernel = convert_kernel(reshape_trans_kernel, "reshape_trans_kernel", tmp_path) n = 1024 BLOCK = 256 @@ -326,10 +299,8 @@ def split_kernel(x_ptr, out_ptr): tl.store(p, a) -@pytest.mark.parametrize("target", _all_targets.keys()) -def test_split(tmp_path, target): - _skip_unless_target(target) - kernel = convert_kernel(split_kernel, "split_kernel", tmp_path, target=target) +def test_split(tmp_path): + kernel = 
convert_kernel(split_kernel, "split_kernel", tmp_path) n = 1024 x = torch.randn(2 * n, device="cuda", dtype=torch.float32) @@ -377,10 +348,8 @@ def reduce_to_scalar_kernel(out_ptr): tl.store(out_ptr, x) -@pytest.mark.parametrize("target", _all_targets.keys()) -def test_reduce_to_scalar(tmp_path, target): - _skip_unless_target(target) - kernel = convert_kernel(reduce_to_scalar_kernel, "reduce_to_scalar_kernel", tmp_path, target=target) +def test_reduce_to_scalar(tmp_path): + kernel = convert_kernel(reduce_to_scalar_kernel, "reduce_to_scalar_kernel", tmp_path) grid = (1, ) out = torch.empty((1, ), device="cuda", dtype=torch.int32) @@ -442,10 +411,8 @@ def atomic_add_kernel(out_ptr, BLOCK: tl.constexpr): tl.atomic_add(out_ptr + idx, idx, mask=scalar_mask, sem="release", scope="cta") -@pytest.mark.parametrize("target", _all_targets.keys()) -def test_atomic_add(tmp_path, target): - _skip_unless_target(target) - kernel = convert_kernel(atomic_add_kernel, "atomic_add_kernel", tmp_path, target=target) +def test_atomic_add(tmp_path): + kernel = convert_kernel(atomic_add_kernel, "atomic_add_kernel", tmp_path) block = 32 * 4 ref = torch.zeros((block, ), device="cuda") @@ -468,10 +435,8 @@ def cat_kernel(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr): tl.store(out_ptr + tl.arange(0, 2 * BLOCK), z) -@pytest.mark.parametrize("target", _all_targets.keys()) -def test_cat(tmp_path, target): - _skip_unless_target(target) - kernel = convert_kernel(cat_kernel, "cat_kernel", tmp_path, target=target) +def test_cat(tmp_path): + kernel = convert_kernel(cat_kernel, "cat_kernel", tmp_path) BLOCK = 256 x = torch.randn(BLOCK, device="cuda", dtype=torch.float32) diff --git a/python/triton/_internal_testing.py b/python/triton/_internal_testing.py index 8c52bb81d877..ba32f860e002 100644 --- a/python/triton/_internal_testing.py +++ b/python/triton/_internal_testing.py @@ -102,6 +102,10 @@ def is_hip_gfx1250(): return target is not None and target.backend == 'hip' and 'gfx1250' in target.arch +def 
is_hip_cdna3_or_newer(): + return is_hip_cdna3() or is_hip_cdna4() + + def is_hip_cdna(): return is_hip_cdna2() or is_hip_cdna3() or is_hip_cdna4() From 28095a798bbeca57c89fb5a5adbe78a98ec6190c Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Sat, 11 Apr 2026 02:44:22 -0700 Subject: [PATCH 19/26] Only allow blackwell/cdna4/gfx1250 for test_simple_matmul --- python/test/unit/tools/test_triton_to_gluon.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/test/unit/tools/test_triton_to_gluon.py b/python/test/unit/tools/test_triton_to_gluon.py index f0bc99b2e401..4a68b885900a 100644 --- a/python/test/unit/tools/test_triton_to_gluon.py +++ b/python/test/unit/tools/test_triton_to_gluon.py @@ -8,7 +8,7 @@ from triton.tools.triton_to_gluon_translator.translator import convert_triton_to_gluon from triton.tools.triton_to_gluon_translator.translator_helpers import convert_host_descriptor -from triton._internal_testing import is_blackwell, is_hopper_or_newer, is_cuda, is_hip_gfx1250, is_hip_cdna3_or_newer +from triton._internal_testing import is_blackwell, is_hopper_or_newer, is_cuda, is_hip_cdna4, is_hip_gfx1250, is_hip_cdna3_or_newer from triton.language.target_info import current_target @@ -135,8 +135,8 @@ def matmul_kernel( # @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES", [(128, 128, 64, 1)]) @pytest.mark.parametrize("NUM_WARPS", [4]) def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, NUM_WARPS, tmp_path): - if not (is_blackwell() or is_hip_cdna3_or_newer() or is_hip_gfx1250()): - pytest.skip("Requires Blackwell, CDNA3+, or gfx1250") + if not (is_blackwell() or is_hip_cdna4() or is_hip_gfx1250()): + pytest.skip("Requires Blackwell, CDNA4, or gfx1250") device = "cuda" M, N, K = 1024, 512, 256 torch.manual_seed(42) From 8a42e7fed03a7e569bd3c21554f1fd0b46cf016d Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Mon, 13 Apr 2026 23:30:09 -0700 Subject: [PATCH 20/26] Introduce 
TranslatorTarget StrEnum for translator hardware abstraction - Add TranslatorTarget StrEnum with is_amd property and tensor_descriptor_import dispatch, accepting any gfx* string via _missing_() for forward-compat with new AMD architectures. - Replace raw target strings and _is_amd_target() in SliceRewriter and Translator with the enum. - Unify tl_make_tensor_descriptor and tl_make_tensor_descriptor_amd into a single helper that dispatches via current_target() at runtime, matching how tl_dot already works. Made-with: Cursor --- .../test/unit/tools/test_triton_to_gluon.py | 5 +- .../slice_kernel.py | 52 +++++++++++++++---- .../triton_to_gluon_translator/translator.py | 16 +++--- .../translator_helpers.py | 18 +++---- 4 files changed, 59 insertions(+), 32 deletions(-) diff --git a/python/test/unit/tools/test_triton_to_gluon.py b/python/test/unit/tools/test_triton_to_gluon.py index 4a68b885900a..feac1d39ee90 100644 --- a/python/test/unit/tools/test_triton_to_gluon.py +++ b/python/test/unit/tools/test_triton_to_gluon.py @@ -8,6 +8,7 @@ from triton.tools.triton_to_gluon_translator.translator import convert_triton_to_gluon from triton.tools.triton_to_gluon_translator.translator_helpers import convert_host_descriptor +from triton.tools.triton_to_gluon_translator.slice_kernel import TranslatorTarget from triton._internal_testing import is_blackwell, is_hopper_or_newer, is_cuda, is_hip_cdna4, is_hip_gfx1250, is_hip_cdna3_or_newer from triton.language.target_info import current_target @@ -15,7 +16,7 @@ def convert_kernel(kernel, kernel_name, tmp_path, target=None): if target is None: t = current_target() - target = "nvidia" if t.backend == "cuda" else t.arch + target = TranslatorTarget("nvidia" if t.backend == "cuda" else t.arch) converted = convert_triton_to_gluon([kernel], target=target) # Write converted kernel to a file so @gluon.jit can retrieve source @@ -464,7 +465,7 @@ def gather_scatter_roundtrip_kernel(out_ptr, in_ptr, idx_ptr, X: tl.constexpr, Y 
@pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires gfx1250") def test_gather_scatter_roundtrip(tmp_path): kernel = convert_kernel(gather_scatter_roundtrip_kernel, "gather_scatter_roundtrip_kernel", tmp_path, - target="gfx1250") + target=TranslatorTarget.GFX1250) X, Y, BLOCK_X, BLOCK_Y = 64, 64, 8, 64 inp = torch.arange(X * Y, device="cuda", dtype=torch.float16).reshape(X, Y) diff --git a/python/triton/tools/triton_to_gluon_translator/slice_kernel.py b/python/triton/tools/triton_to_gluon_translator/slice_kernel.py index 337939681c9b..df29b9b5bc74 100644 --- a/python/triton/tools/triton_to_gluon_translator/slice_kernel.py +++ b/python/triton/tools/triton_to_gluon_translator/slice_kernel.py @@ -13,6 +13,7 @@ from collections import OrderedDict from collections.abc import Sequence from dataclasses import dataclass, field +from enum import StrEnum from pathlib import Path from types import BuiltinFunctionType, FunctionType, ModuleType from typing import Any, Callable, TypeAlias @@ -30,6 +31,41 @@ logger = logging.getLogger(__name__) +class TranslatorTarget(StrEnum): + """Target architecture for the Triton-to-Gluon translator. + + Known targets are listed as explicit members for discoverability. + Unknown ``gfx*`` strings are accepted via ``_missing_()`` so that + new AMD architectures work without adding an enum member. 
+ """ + + NVIDIA = "nvidia" + # AMD targets currently exercised by the translator test suite: + GFX1250 = "gfx1250" + GFX942 = "gfx942" + GFX950 = "gfx950" + + @classmethod + def _missing_(cls, value: object) -> "TranslatorTarget | None": + """Allow any ``gfx*`` string as a valid AMD target.""" + if isinstance(value, str) and value.startswith("gfx"): + obj = str.__new__(cls, value) + obj._value_ = value + return obj + return None + + @property + def is_amd(self) -> bool: + return self != TranslatorTarget.NVIDIA + + @property + def tensor_descriptor_import(self) -> str: + """Return the import statement for the target's tensor descriptor module.""" + if self.is_amd: + return "from triton.experimental.gluon.language.amd.gfx1250.tdm import tensor_descriptor" + return "from triton.experimental.gluon.language.nvidia.hopper.tma import tensor_descriptor" + + @dataclass class GlobalVariable: name: str @@ -521,7 +557,7 @@ class SliceRewriter(ReferenceRewriter): translate_to_gluon: bool = False inline_helpers: ordered_set[str] = field(default_factory=ordered_set[str]) cvt_context: list[bool] = field(default_factory=lambda: [False]) - target: str = "nvidia" + target: TranslatorTarget = TranslatorTarget.NVIDIA def __post_init__(self) -> None: # Special rules for sugaring imports. 
@@ -552,9 +588,6 @@ def emit_reference(self, node: ast.AST) -> Any: return node raise e - def _is_amd_target(self) -> bool: - return self.target.startswith("gfx") - def visit_Attribute(self, node: ast.Attribute) -> ast.AST: if not self.translate_to_gluon: return super().visit_Attribute(node) @@ -566,10 +599,7 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.AST: self.imports.add("import triton.experimental.gluon._runtime as gluon_runtime") new_node = parse_expr("gluon_runtime.GluonJITFunction") elif value is tl.tensor_descriptor: - if self._is_amd_target(): - self.imports.add("from triton.experimental.gluon.language.amd.gfx1250.tdm import tensor_descriptor") - else: - self.imports.add("from triton.experimental.gluon.language.nvidia.hopper.tma import tensor_descriptor") + self.imports.add(self.target.tensor_descriptor_import) new_node = ast.Name(id="tensor_descriptor", ctx=ast.Load()) return new_node @@ -716,7 +746,7 @@ def slice_kernel( leaf_paths: list[str] | None = None, translate_to_gluon: bool = False, ignored_decorator_matchers: Sequence[DecoratorMatcher] | None = None, - target: str = "nvidia", + target: TranslatorTarget = TranslatorTarget.NVIDIA, ) -> str: base_values: list[GlobalValue] = [get_base_value(root_path) for root_path in root_paths] base_value_ids: set[int] = set() @@ -816,7 +846,7 @@ def slice_kernel_from_trace( translate_to_gluon: bool, extra_modules: dict[str, str], ignored_decorator_matchers: Sequence[DecoratorMatcher] | None = None, - target: str = "nvidia", + target: TranslatorTarget = TranslatorTarget.NVIDIA, ) -> str: module_remap: dict[str, str] = {} for name, path in extra_modules.items(): @@ -864,7 +894,7 @@ def main( translate_to_gluon: bool = False, output_path: str = "/tmp/reference.py", ignored_decorator_matchers: Sequence[DecoratorMatcher] | None = None, - target: str = "nvidia", + target: TranslatorTarget = TranslatorTarget.NVIDIA, ) -> None: output = slice_kernel( root_paths, diff --git 
a/python/triton/tools/triton_to_gluon_translator/translator.py b/python/triton/tools/triton_to_gluon_translator/translator.py index 600aa57bb279..92affdd4cbfe 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator.py +++ b/python/triton/tools/triton_to_gluon_translator/translator.py @@ -14,6 +14,7 @@ GlobalValue, ReferenceRewriter, RewriteFn, + TranslatorTarget, add_sugar_rewrites, find_references, get_base_value, @@ -108,7 +109,7 @@ def add_expr_rewrites(rewrites: list[RewriteFn]) -> None: @dataclass class Translator(ReferenceRewriter): tensor_member_match_fns: list[str] = field(default_factory=list) - target: str = "nvidia" + target: TranslatorTarget = TranslatorTarget.NVIDIA def __post_init__(self) -> None: import triton @@ -165,9 +166,6 @@ def visit_Call(self, node: ast.Call) -> ast.AST: node = ast.Call(func=new_callee, args=[node.func.value] + node.args, keywords=node.keywords) return self.generic_visit(node) value, _, _ = ref - if value is tl.make_tensor_descriptor and self.target.startswith("gfx"): - node.func = parse_expr("helpers.tl_make_tensor_descriptor_amd") - node.keywords = [kw for kw in node.keywords if kw.arg != "padding_option"] if value in [tl.reshape, tl.ravel]: node.keywords = [kw for kw in node.keywords if kw.arg != "can_reorder"] elif value is tl.split: @@ -202,7 +200,7 @@ def visit_Subscript(self, node: ast.Subscript) -> ast.AST: return self.generic_visit(node) -def translate_kernels(kernels: list[GlobalValue], target: str = "nvidia") -> str: +def translate_kernels(kernels: list[GlobalValue], target: TranslatorTarget = TranslatorTarget.NVIDIA) -> str: def filter(value: ModuleType | GlobalValue) -> bool: if isinstance(value, ModuleType): @@ -255,12 +253,12 @@ def filter(value: ModuleType | GlobalValue) -> bool: return output -def translate_paths(kernel_paths: list[str], target: str = "nvidia") -> str: +def translate_paths(kernel_paths: list[str], target: TranslatorTarget = TranslatorTarget.NVIDIA) -> str: kernels = 
[get_base_value(kernel_path) for kernel_path in kernel_paths] return translate_kernels(kernels, target=target) -def convert_triton_to_gluon(src: list[JITCallable], target: str = "nvidia") -> str: +def convert_triton_to_gluon(src: list[JITCallable], target: TranslatorTarget = TranslatorTarget.NVIDIA) -> str: kernels = [ GlobalValue.wrap( kernel, @@ -271,7 +269,7 @@ def convert_triton_to_gluon(src: list[JITCallable], target: str = "nvidia") -> s return translate_kernels(kernels, target=target) -def main(kernels: list[str], output_path: str, target: str = "nvidia") -> None: +def main(kernels: list[str], output_path: str, target: TranslatorTarget = TranslatorTarget.NVIDIA) -> None: output = translate_paths(kernels, target=target) with open(output_path, "w") as f: f.write(output) @@ -283,7 +281,7 @@ def _main_cli() -> None: parser.add_argument("--output-path", required=True, help="Path to write the translated source.") parser.add_argument("--target", default="nvidia", help="Target architecture (e.g. 
nvidia, amd_gfx1250).") args = parser.parse_args() - main(args.kernels, args.output_path, target=args.target) + main(args.kernels, args.output_path, target=TranslatorTarget(args.target)) if __name__ == "__main__": diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py index 94fd8693fc24..9d930d19cd6a 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py @@ -852,16 +852,14 @@ def tl_obj_scatter(obj, value, x_offsets, y_offset): @gluon.jit def tl_make_tensor_descriptor(base, shape, strides, block_shape, padding_option: ttgl.constexpr = "zero"): - layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for(block_shape, base.dtype.element_ty) - return tma.make_tensor_descriptor(base, shape, strides, block_shape, layout, padding_option) - - -@gluon.jit -def tl_make_tensor_descriptor_amd(base, shape, strides, block_shape): - ttgl.static_assert(_is_gfx1250(current_target()), "tl_make_tensor_descriptor_amd requires gfx1250 target") - layout: ttgl.constexpr = get_default_tdm_layout(*block_shape) - desc = amd_tdm.make_tensor_descriptor(base, shape, strides, block_shape, layout) - return AMDTensorDescriptorArgs(desc, base) + target: ttgl.constexpr = current_target() + if _is_gfx1250(target): + layout: ttgl.constexpr = get_default_tdm_layout(*block_shape) + desc = amd_tdm.make_tensor_descriptor(base, shape, strides, block_shape, layout) + return AMDTensorDescriptorArgs(desc, base) + else: + layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for(block_shape, base.dtype.element_ty) + return tma.make_tensor_descriptor(base, shape, strides, block_shape, layout, padding_option) @gluon.jit From 3cb9ee642a7ac53efbd1e59d2f2284ed6b0fa908 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Tue, 14 Apr 2026 00:45:55 -0700 Subject: [PATCH 21/26] Fix TranslatorTarget for 
Python 3.10 compat Use (str, Enum) instead of StrEnum which requires Python 3.11+. Made-with: Cursor --- .../triton/tools/triton_to_gluon_translator/slice_kernel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/triton/tools/triton_to_gluon_translator/slice_kernel.py b/python/triton/tools/triton_to_gluon_translator/slice_kernel.py index df29b9b5bc74..35153777da78 100644 --- a/python/triton/tools/triton_to_gluon_translator/slice_kernel.py +++ b/python/triton/tools/triton_to_gluon_translator/slice_kernel.py @@ -13,7 +13,7 @@ from collections import OrderedDict from collections.abc import Sequence from dataclasses import dataclass, field -from enum import StrEnum +from enum import Enum from pathlib import Path from types import BuiltinFunctionType, FunctionType, ModuleType from typing import Any, Callable, TypeAlias @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) -class TranslatorTarget(StrEnum): +class TranslatorTarget(str, Enum): """Target architecture for the Triton-to-Gluon translator. Known targets are listed as explicit members for discoverability. From 948972d922d5e66bbd134ef12a0387eb2dc10104 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Tue, 14 Apr 2026 23:55:45 -0700 Subject: [PATCH 22/26] Split translator_helpers into target-specific helper modules Split the monolithic translator_helpers.py into: - common_helpers.py: vendor-neutral utilities (layouts, portable ops) - nvidia_helpers.py: NVIDIA-specific helpers (TMA, mbarrier, Blackwell) - amd_helpers.py: AMD-specific helpers (TDM, WMMA, MFMA) Each target module re-exports common helpers via star import so the generated kernel sees a single unified `helpers` namespace. The TranslatorTarget.helpers_module property selects which module to import, so translated kernels no longer pull in unrelated hardware modules. translator_helpers.py is kept as a backward-compat re-export shim. 
Made-with: Cursor --- .../triton_to_gluon_translator/amd_helpers.py | 411 +++++++ .../common_helpers.py | 277 +++++ .../nvidia_helpers.py | 507 ++++++++ .../slice_kernel.py | 8 + .../triton_to_gluon_translator/translator.py | 2 +- .../translator_helpers.py | 1025 +---------------- 6 files changed, 1210 insertions(+), 1020 deletions(-) create mode 100644 python/triton/tools/triton_to_gluon_translator/amd_helpers.py create mode 100644 python/triton/tools/triton_to_gluon_translator/common_helpers.py create mode 100644 python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py diff --git a/python/triton/tools/triton_to_gluon_translator/amd_helpers.py b/python/triton/tools/triton_to_gluon_translator/amd_helpers.py new file mode 100644 index 000000000000..6b78f2c5bee7 --- /dev/null +++ b/python/triton/tools/triton_to_gluon_translator/amd_helpers.py @@ -0,0 +1,411 @@ +# type: ignore + +import math + +import triton.language as tl +from triton.experimental import gluon +from triton.experimental.gluon import language as ttgl +from triton.experimental.gluon.language.amd.gfx1250 import wmma as amd_wmma +from triton.experimental.gluon.language.amd.gfx1250 import tdm as amd_tdm +from triton.experimental.gluon.language.amd.cdna3 import mfma as amd_mfma +from triton.language.target_info import current_target + +from triton.tools.triton_to_gluon_translator.common_helpers import * # noqa: F401,F403 +from triton.tools.triton_to_gluon_translator.common_helpers import ( + default_blocked_layout, + get_num_threads_per_warp, + tl_dot_decomposed_scale_arg, + tl_trans, +) + + +# ---- architecture detection ---- + + +@gluon.constexpr_function +def _is_gfx1250(target=None): + return target is not None and target.arch == "gfx1250" + + +@gluon.constexpr_function +def _is_cdna(target=None): + return target is not None and target.arch in ("gfx942", "gfx950") + + +@gluon.constexpr_function +def _cdna_version(target=None): + """Returns 3 for gfx942, 4 for gfx950.""" + return 4 if target is 
not None and target.arch == "gfx950" else 3 + + +# ---- AMD WMMA layout helpers (gfx1250) ---- + + +@gluon.constexpr_function +def compute_warp_bases(num_warps): + """Distribute warps across M/N: first bit to N, rest to M.""" + n_bits = int(math.log2(num_warps)) + if n_bits == 0: + return [] + warp_bases = [[0, 1]] + for i in range(n_bits - 1): + warp_bases.append([1 << i, 0]) + return warp_bases + + +@gluon.constexpr_function +def get_wmma_layout(shape, num_warps): + warp_bases = compute_warp_bases(num_warps) + return ttgl.amd.AMDWMMALayout(3, True, warp_bases, [], [16, 16, 32]) + + +@gluon.constexpr_function +def get_wmma_k_width(a_ty, b_ty): + min_bitwidth = min(a_ty.element_ty.primitive_bitwidth, b_ty.element_ty.primitive_bitwidth) + return max(128 // min_bitwidth, 1) + + +# ---- AMD MFMA layout helpers (cdna3/cdna4) ---- + + +@gluon.constexpr_function +def get_mfma_instr_k(element_bitwidth, target=None): + """K dimension of the MFMA instruction for [32, 32, K].""" + k_bits = 128 if _cdna_version(target) == 3 else 256 + return k_bits // element_bitwidth + + +@gluon.constexpr_function +def get_mfma_layout(num_warps, element_bitwidth, target=None): + instr_k = get_mfma_instr_k(element_bitwidth, target) + return ttgl.amd.AMDMFMALayout( + version=_cdna_version(target), + instr_shape=[32, 32, instr_k], + transposed=True, + warps_per_cta=[num_warps, 1], + ) + + +@gluon.constexpr_function +def get_mfma_k_width(a_ty, b_ty, target=None): + min_bitwidth = min(a_ty.element_ty.primitive_bitwidth, b_ty.element_ty.primitive_bitwidth) + instr_k = get_mfma_instr_k(min_bitwidth, target) + return instr_k // 2 + + +# ---- AMD dot paths ---- + + +@gluon.jit +def tl_dot_wmma(a, b, acc, out_dtype): + """gfx1250 WMMA path.""" + M: ttgl.constexpr = a.type.shape[0] + N: ttgl.constexpr = b.type.shape[1] + num_warps: ttgl.constexpr = ttgl.num_warps() + + wmma_layout: ttgl.constexpr = get_wmma_layout([M, N], num_warps) + k_width: ttgl.constexpr = get_wmma_k_width(a.type, b.type) + 
@gluon.jit
def tl_dot(
    a,
    b,
    acc=None,
    input_precision=None,
    allow_tf32=None,
    max_num_imprecise_acc=None,
    out_dtype=ttgl.float32,
):
    """Dispatch a ``tl.dot`` call to the AMD path matching the current target.

    gfx1250 uses the WMMA path and CDNA targets use the MFMA path. The
    ``input_precision``, ``allow_tf32`` and ``max_num_imprecise_acc``
    arguments are accepted for signature compatibility but are not forwarded
    to either path.
    """
    target: ttgl.constexpr = current_target()
    if _is_gfx1250(target):
        return tl_dot_wmma(a, b, acc, out_dtype)
    elif _is_cdna(target):
        return tl_dot_mfma(a, b, acc, out_dtype)
    else:
        # Without this branch an unsupported target silently yields None and
        # the failure surfaces later, far from the actual cause.
        ttgl.static_assert(False, "tl_dot: unsupported AMD target, expected gfx1250 (WMMA) or CDNA (MFMA)")
@gluon.constexpr_function
def get_default_tdm_layout(*block_shape):
    """Return the default padded shared layout for a TDM block of ``block_shape``.

    The padding spec is a single ``[innermost_dim, 4]`` pair and the
    dimension order runs from the last dimension down to the first.
    """
    dims = list(block_shape)
    innermost = dims[-1]
    reversed_order = list(reversed(range(len(dims))))
    return ttgl.PaddedSharedLayout.with_identity_for([[innermost, 4]], dims, reversed_order)
@gluon.jit
def tl_obj_gather_amd(desc_args, x_offsets, y_offset):
    """Gather rows through a TDM descriptor and return them as a register tensor.

    ``desc_args`` carries both the original descriptor and its base pointer.
    A fresh descriptor is built here with block shape
    ``[len(x_offsets), block_shape[1]]``, because the gather needs one row per
    index rather than the original block shape.
    """
    NUM_IDX: ttgl.constexpr = x_offsets.shape[0]
    BLOCK_N: ttgl.constexpr = desc_args.desc.block_shape[1]
    smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0])
    gather_block_shape: ttgl.constexpr = [NUM_IDX, BLOCK_N]
    # Rebuild the descriptor from the base pointer with the gather-sized block.
    gather_desc = amd_tdm.make_tensor_descriptor(desc_args.base_ptr, desc_args.desc.shape, desc_args.desc.strides,
                                                 gather_block_shape, smem_layout)
    num_warps: ttgl.constexpr = ttgl.num_warps()
    gather_shape: ttgl.constexpr = gather_desc.block_shape
    # The index vector is converted to a slice (along dim 1) of this 2-D blocked layout.
    idx_base: ttgl.constexpr = ttgl.BlockedLayout([gather_shape[0], 1],
                                                  [1, get_num_threads_per_warp(current_target())], [1, num_warps],
                                                  [1, 0])
    idx_layout: ttgl.constexpr = ttgl.SliceLayout(1, idx_base)
    x_offsets = ttgl.convert_layout(x_offsets, idx_layout)
    alloc = ttgl.allocate_shared_memory(desc_args.desc.dtype, list(gather_shape), smem_layout)
    y_off = ttgl.to_tensor(y_offset)
    amd_tdm.async_gather(gather_desc, x_offsets, y_off, alloc)
    # NOTE(review): presumably waits until no TDM operations remain outstanding -- confirm.
    amd_tdm.async_wait(0)
    ret_layout: ttgl.constexpr = default_blocked_layout(list(gather_shape), num_warps, current_target())
    out = alloc.load(ret_layout)
    return out
def convert_host_descriptor(desc):
    """Convert a host-side Triton ``TensorDescriptor`` into a gfx1250 Gluon one.

    The descriptor keeps the same base tensor, shape, strides and block shape,
    and is given the default TDM shared-memory layout for that block shape.

    Args:
        desc: a ``triton.tools.tensor_descriptor.TensorDescriptor``.

    Returns:
        A ``gluon.amd.gfx1250.TensorDescriptor`` over the same tensor.
    """
    # Imported lazily so that loading this module does not import the host descriptor module.
    from triton.tools.tensor_descriptor import TensorDescriptor

    assert isinstance(desc, TensorDescriptor)
    block_shape = desc.block_shape
    layout = get_default_tdm_layout(*block_shape)
    return gluon.amd.gfx1250.TensorDescriptor(desc.base, list(desc.shape), list(desc.strides), block_shape, layout)
@gluon.constexpr_function
def get_num_threads_per_warp(target=None) -> ttgl.constexpr:
    """Return the warp size for ``target`` (the current target when omitted).

    Non-HIP targets, or no target at all, give 32. For HIP the major version
    is parsed from the arch string: majors below 10 give 64, the rest give 32.
    """
    resolved = current_target() if target is None else target
    if resolved is None or resolved.backend != "hip":
        return ttgl.constexpr(32)
    # Drop the "gfx" prefix and the two trailing characters to get the major version.
    major = int(resolved.arch[3:-2])
    warp_size = 64 if major < 10 else 32
    return ttgl.constexpr(warp_size)
@gluon.constexpr_function
def get_swizzle_byte_width(bitwidth):
    """Clamp ``bitwidth`` to at most 128 and return 0 when the result is below 32."""
    capped = min(bitwidth, 128)
    if capped < 32:
        return 0
    return capped
@gluon.constexpr_function
def get_split_src_layout(shape: ttgl.constexpr, num_warps: ttgl.constexpr, target=None) -> ttgl.constexpr:
    """Build the blocked layout applied to a tensor before it is split.

    Each thread holds two elements along the last dimension. The warp's
    threads are distributed over the other dimensions, starting from the
    second-to-last, and all warps are placed on dimension 0.
    """
    rank = len(shape)
    last_axis = rank - 1
    elems_per_thread = [2 if axis == last_axis else 1 for axis in range(rank)]
    lanes = [1] * rank
    lanes_left = get_num_threads_per_warp(target)
    # Hand out lanes from the second-to-last dimension outward.
    for axis in reversed(range(rank - 1)):
        lanes[axis] = min(shape[axis], lanes_left)
        lanes_left //= lanes[axis]
    warps = [1] * rank
    warps[0] = num_warps
    return ttgl.BlockedLayout(
        size_per_thread=elems_per_thread,
        threads_per_warp=lanes,
        warps_per_cta=warps,
        order=list(reversed(range(rank))),
    )
@gluon.constexpr_function
def tl_dot_get_expand_dims_layout(scale_ty, num_warps, rank):
    """Return the slice layout a scale tensor needs before a trailing unit dim is added.

    The parent is the default blocked layout for the scale shape with ``1``
    appended; the result slices that parent along dimension ``rank``.
    """
    expanded_shape = scale_ty.shape.values + [1]
    parent_layout = default_blocked_layout(expanded_shape, num_warps)
    return ttgl.SliceLayout(rank, parent_layout)
@gluon.constexpr_function
def tl_dot_decomposed_get_transposed_order(rank):
    """Return the permutation that swaps the last two of ``rank`` dimensions."""
    assert rank >= 2
    return [*range(rank - 2), rank - 1, rank - 2]
@gluon.jit
def tl_dot_mma_sync(a, b, acc_init=None, input_precision=None, out_dtype=ttgl.float32):
    """Compute ``a @ b`` with the MMA v2 instruction.

    Both operands are converted to dot-operand layouts of the MMA layout
    chosen for ``a``'s shape. The accumulator is ``acc_init`` when given,
    otherwise zeros of ``out_dtype``. The result is returned in ``acc_init``'s
    layout when it was given, otherwise in the default blocked layout.
    """
    rank: ttgl.constexpr = len(a.type.shape)
    mma_layout: ttgl.constexpr = tl_dot_mma_sync_layout(a.type.shape, ttgl.num_warps())
    k_width: ttgl.constexpr = tl_dot_mma_sync_k_width(a.type, b.type)
    a_layout: ttgl.constexpr = ttgl.DotOperandLayout(parent=mma_layout, operand_index=0, k_width=k_width)
    b_layout: ttgl.constexpr = ttgl.DotOperandLayout(parent=mma_layout, operand_index=1, k_width=k_width)
    a = ttgl.convert_layout(a, a_layout)
    b = ttgl.convert_layout(b, b_layout)
    if acc_init is not None:
        acc = ttgl.convert_layout(acc_init, mma_layout)
    elif rank == 2:
        # 2-D dot: the accumulator is [M, N]. Indexing b.shape[2] here, as the
        # 3-D form does, would be out of range.
        acc = ttgl.full([a.shape[0], b.shape[1]], 0.0, out_dtype, layout=mma_layout)
    else:
        # 3-D (batched) dot: the accumulator is [batch, M, N].
        acc = ttgl.full([a.shape[0], a.shape[1], b.shape[2]], 0.0, out_dtype, layout=mma_layout)
    result = mma_v2(a, b, acc, input_precision)
    if acc_init is not None:
        layout: ttgl.constexpr = acc_init.type.layout
    else:
        layout: ttgl.constexpr = default_blocked_layout(result.type.shape, ttgl.num_warps())
    result = ttgl.convert_layout(result, layout)
    return result
contig_dim_size_in_byte >= 128 and contig_dim_size_in_byte % 128 == 0: + swizzle_byte_width = 128 + elif contig_dim_size_in_byte >= 64 and contig_dim_size_in_byte % 64 == 0: + swizzle_byte_width = 64 + elif contig_dim_size_in_byte >= 32 and contig_dim_size_in_byte % 32 == 0: + swizzle_byte_width = 32 + else: + swizzle_byte_width = 0 + + flatten_outer_dim = 1 + for dim in shape: + flatten_outer_dim *= dim + if len(shape) < 2 or flatten_outer_dim < 8: + swizzle_byte_width = 0 + return ttgl.NVMMASharedLayout( + swizzle_byte_width=swizzle_byte_width, + transposed=transposed, + element_bitwidth=ele_bit_width, + rank=len(shape), + fp4_padded=is_fp4_padded, + ) + + +@gluon.jit +def get_shared_memory_mma_operand(value, operand_index, allow_transpose, is_fp4_padded=False, force_transpose=False): + layout: ttgl.constexpr = get_shared_memory_mma_layout(value.type, operand_index, allow_transpose, is_fp4_padded, + force_transpose) + return ttgl.allocate_shared_memory(value.dtype, value.shape, layout, value) + + +@gluon.jit +def tl_dot_blackwell( + a, + b, + acc=None, + input_precision=None, + allow_tf32=None, + max_num_imprecise_acc=None, + out_dtype=ttgl.float32, +): + M: ttgl.constexpr = a.type.shape[0] + N: ttgl.constexpr = b.type.shape[1] + + allow_transpose = not a.type.element_ty.is_fp32() + a_smem = get_shared_memory_mma_operand(a, 0, allow_transpose) + b_smem = get_shared_memory_mma_operand(b, 1, allow_transpose) + + m: ttgl.constexpr = 128 if M >= 128 else 64 + n: ttgl.constexpr = 256 if N >= 256 else N + + acc_dtype: ttgl.constexpr = acc.dtype if acc is not None else out_dtype + col_stride: ttgl.constexpr = 32 // acc_dtype.primitive_bitwidth + acc_tmem_layout: ttgl.constexpr = TensorMemoryLayout([m, n], col_stride=col_stride) + acc_tmem = allocate_tensor_memory(acc_dtype, [M, N], acc_tmem_layout) + tmem_reg_layout: ttgl.constexpr = acc_tmem.get_reg_layout() + if acc is not None: + acc_temp = ttgl.convert_layout(acc, tmem_reg_layout) + else: + acc_temp = ttgl.zeros([M, 
@gluon.constexpr_function
def tl_dot_scaled_mmav5_supported(a_ty, b_ty, num_warps):
    """Report whether a scaled dot of these operand types can use the MMA v5 path.

    Requires 4 or 8 warps, two 2-D operands, ``K`` of at least
    ``256 // bitwidth(a)``, ``M >= 128`` and ``N >= 16``.
    """
    rows = a_ty.shape[0]
    cols = b_ty.shape[1]
    depth = a_ty.shape[1]
    min_depth = 256 // a_ty.element_ty.primitive_bitwidth
    if num_warps not in [4, 8]:
        return False
    if len(a_ty.shape) != 2 or len(b_ty.shape) != 2:
        return False
    return depth >= min_depth and rows >= 128 and cols >= 16
ttgl.constexpr = is_b_mixed_prec_fp4 or not rhs_k_pack + + a_smem = get_shared_memory_mma_operand( + lhs, + 0, + allow_transpose=not is_a_fp4, + is_fp4_padded=is_mmav5_fp4_padded_a, + force_transpose=not lhs_k_pack, + ) + b_smem = get_shared_memory_mma_operand( + rhs, + 1, + allow_transpose=not is_b_fp4, + is_fp4_padded=is_mmav5_fp4_padded_b, + force_transpose=not rhs_k_pack, + ) + + M: ttgl.constexpr = lhs.type.shape[0] + N: ttgl.constexpr = rhs.type.shape[1] + + m: ttgl.constexpr = 128 + n: ttgl.constexpr = 256 if N >= 256 else N + + acc_dtype: ttgl.constexpr = acc.dtype if acc is not None else out_dtype + col_stride: ttgl.constexpr = 32 // acc_dtype.primitive_bitwidth + acc_tmem_layout: ttgl.constexpr = TensorMemoryLayout([m, n], col_stride=col_stride) + acc_tmem = allocate_tensor_memory(acc_dtype, [M, N], acc_tmem_layout) + tmem_reg_layout: ttgl.constexpr = acc_tmem.get_reg_layout() + if acc is not None: + acc_temp = ttgl.convert_layout(acc, tmem_reg_layout) + else: + acc_temp = ttgl.zeros([M, N], out_dtype, layout=tmem_reg_layout) + acc_tmem.store(acc_temp) + fence_async_shared() + + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + scale_layout: ttgl.constexpr = TensorMemoryScalesLayout() + a_scale_tmem = allocate_tensor_memory(lhs_scale.dtype, lhs_scale.shape, scale_layout) + b_scale_tmem = allocate_tensor_memory(rhs_scale.dtype, rhs_scale.shape, scale_layout) + scale_layout_reg_lhs: ttgl.constexpr = a_scale_tmem.get_reg_layout() + scale_layout_reg_rhs: ttgl.constexpr = b_scale_tmem.get_reg_layout() + lhs_scale = ttgl.convert_layout(lhs_scale, scale_layout_reg_lhs) + rhs_scale = ttgl.convert_layout(rhs_scale, scale_layout_reg_rhs) + a_scale_tmem.store(lhs_scale) + b_scale_tmem.store(rhs_scale) + + tcgen05_mma_scaled(a_smem, b_smem, acc_tmem, a_scale_tmem, b_scale_tmem, lhs_format, rhs_format, use_acc=True) + tcgen05_commit(bar) + mbarrier.wait(bar, phase=0) + mbarrier.invalidate(bar) + out = 
# Defined here (not imported from common) so __globals__ resolves tl_dot to this module's version.
@gluon.jit
def tl_dot_decomposed_block_scales(
    lhs,
    lhs_scale,
    lhs_format,
    rhs,
    rhs_scale,
    rhs_format,
    acc=None,
    fast_math=False,
    lhs_k_pack=True,
    rhs_k_pack=True,
    out_dtype=ttgl.float32,
):
    """Emulate a block-scaled dot by upcasting and scaling the operands, then calling ``tl_dot``.

    When only the right-hand side has a scale, the problem is transposed so
    that the scaled operand becomes the left-hand side; the result is then
    transposed back and, if an accumulator was given, converted to its layout.
    Otherwise each operand is converted to the compute type (fp16 if either
    format is ``"fp16"``, else bf16), multiplied by its broadcast scale, and
    passed to ``tl_dot``.
    """
    if lhs_scale is None and rhs_scale is not None:
        lhs_trans = tl_trans(lhs)
        rhs_trans = tl_trans(rhs)
        if acc is not None:
            orig_layout: ttgl.constexpr = acc.type.layout
            acc = tl_trans(acc)
        # The operands swap roles here, so their k-pack flags must swap too,
        # just like their scales and formats do.
        result = tl_dot_scaled(
            rhs_trans,
            rhs_scale,
            rhs_format,
            lhs_trans,
            lhs_scale,
            lhs_format,
            acc,
            fast_math,
            rhs_k_pack,
            lhs_k_pack,
            out_dtype,
        )
        result = tl_trans(result)
        if acc is not None:
            result = ttgl.convert_layout(result, orig_layout)
        return result
    else:
        ttgl.static_assert(not (not lhs_k_pack or not rhs_k_pack), "TODO: support m/n packed formats")
        compute_type: ttgl.constexpr = (ttgl.float16 if
                                        (lhs_format == "fp16" or rhs_format == "fp16") else ttgl.bfloat16)

        scale_a = tl_dot_decomposed_scale_arg(lhs, lhs_scale, lhs_format, 0, compute_type, fast_math)
        scale_b = tl_dot_decomposed_scale_arg(rhs, rhs_scale, rhs_format, 1, compute_type, fast_math)

        return tl_dot(scale_a, scale_b, acc, out_dtype=out_dtype)
@gluon.jit
def tl_load_tensor_descriptor(desc, offsets):
    """Load one block through a TMA descriptor and return it as a register tensor.

    The block is copied from global to shared memory at ``offsets``, the copy
    is awaited on a barrier, and the shared buffer is then read back in the
    default blocked layout for ``desc.block_shape``.
    """
    smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout)
    bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(bar, count=1)
    # Tell the barrier how many bytes the copy below will deliver.
    mbarrier.expect(bar, desc.block_type.nbytes)
    tma.async_copy_global_to_shared(desc, offsets, bar, smem)
    mbarrier.wait(bar, phase=0)
    mbarrier.invalidate(bar)
    ret_layout: ttgl.constexpr = default_blocked_layout(desc.block_shape, ttgl.num_warps())
    out = smem.load(ret_layout)
    return out
@gluon.jit
def tl_obj_scatter(obj, value, x_offsets, y_offset):
    """Scatter ``value`` through ``obj``.

    For a TMA tensor descriptor, ``value`` is staged in shared memory of shape
    ``[len(x_offsets), block_shape[1]]`` and written out with an async
    scatter. Any other object is delegated to its own ``scatter`` method.
    """
    if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor):
        desc = obj
        # One staged row per scatter index.
        desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]]
        alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout, value)
        fence_async_shared()
        x_offsets_layout: ttgl.constexpr = ttgl.SliceLayout(
            0,
            ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]),
        )
        x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout)
        tma_blackwell.async_scatter(desc, x_offsets, y_offset, alloc)
        # NOTE(review): presumably waits until no TMA stores remain outstanding -- confirm.
        tma.store_wait(0)
    else:
        obj.scatter(value, x_offsets, y_offset)
ttgl.NVMMASharedLayout.get_default_for(block_shape, torch_dtype_to_triton(dtype)) + return gluon.nvidia.hopper.TensorDescriptor(tensor, desc.shape, desc.strides, block_shape, layout) diff --git a/python/triton/tools/triton_to_gluon_translator/slice_kernel.py b/python/triton/tools/triton_to_gluon_translator/slice_kernel.py index 35153777da78..444f5c24cb34 100644 --- a/python/triton/tools/triton_to_gluon_translator/slice_kernel.py +++ b/python/triton/tools/triton_to_gluon_translator/slice_kernel.py @@ -65,6 +65,14 @@ def tensor_descriptor_import(self) -> str: return "from triton.experimental.gluon.language.amd.gfx1250.tdm import tensor_descriptor" return "from triton.experimental.gluon.language.nvidia.hopper.tma import tensor_descriptor" + @property + def helpers_module(self) -> str: + """Return the helpers module path for this target.""" + base = "triton.tools.triton_to_gluon_translator" + if self.is_amd: + return f"{base}.amd_helpers" + return f"{base}.nvidia_helpers" + @dataclass class GlobalVariable: diff --git a/python/triton/tools/triton_to_gluon_translator/translator.py b/python/triton/tools/triton_to_gluon_translator/translator.py index 92affdd4cbfe..12fcae55205a 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator.py +++ b/python/triton/tools/triton_to_gluon_translator/translator.py @@ -125,7 +125,7 @@ def __post_init__(self) -> None: self.imports.add("import triton.experimental.gluon as gluon") self.imports.add("import triton.experimental.gluon.language as gl") - self.imports.add("import triton.tools.triton_to_gluon_translator.translator_helpers as helpers") + self.imports.add(f"import {self.target.helpers_module} as helpers") self.tensor_member_match_fns = ["reshape", "trans", "permute", "split", "reduce", "sum"] diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py index 9d930d19cd6a..91ae89264d1c 100644 --- 
a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py @@ -1,1021 +1,8 @@ # type: ignore +# Backward-compat shim: re-exports from the split helper modules so that +# existing imports (e.g. ``from translator_helpers import convert_host_descriptor``) +# continue to work. -import math - -import triton.language as tl -from triton.experimental import gluon -from triton.experimental.gluon import language as ttgl -from triton.experimental.gluon.language.nvidia.ampere import mma_v2 -from triton.experimental.gluon.language.nvidia.blackwell import ( - TensorMemoryLayout, - TensorMemoryScalesLayout, - allocate_tensor_memory, - tcgen05_commit, - tcgen05_mma, - tcgen05_mma_scaled, -) -from triton.experimental.gluon.language.nvidia.blackwell import tma as tma_blackwell -from triton.experimental.gluon.language.nvidia.hopper import fence_async_shared, mbarrier, tma -from triton.experimental.gluon.language.amd.gfx1250 import wmma as amd_wmma - -# hack to workaround limited dependencies tracking. -# TODO: fix this by pulling imports into the generated file. 
-from triton.language.target_info import current_target # noqa: F401 -from triton.experimental.gluon.language.amd.gfx1250 import tdm as amd_tdm -from triton.experimental.gluon.language.amd.cdna3 import mfma as amd_mfma - - -@gluon.constexpr_function -def tl_dot_mma_sync_layout(shape, num_warps): - rank = len(shape) - assert rank in [ - 2, - 3, - ], "MMA sync only supports 2D shapes or 3D shapes with a batch outer dimension" - if rank == 2: - return ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[num_warps, 1], instr_shape=[16, 8]) - return ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[num_warps, 1, 1], instr_shape=[1, 16, 8]) - - -@gluon.constexpr_function -def tl_dot_mma_sync_k_width(a_ty, b_ty): - a_bitwidth = a_ty.element_ty.primitive_bitwidth - b_bitwidth = b_ty.element_ty.primitive_bitwidth - min_bitwidth = min(a_bitwidth, b_bitwidth) - return max(32 // min_bitwidth, 1) - - -@gluon.jit -def tl_dot_mma_sync(a, b, acc_init=None, input_precision=None, out_dtype=ttgl.float32): - mma_layout: ttgl.constexpr = tl_dot_mma_sync_layout(a.type.shape, ttgl.num_warps()) - k_width: ttgl.constexpr = tl_dot_mma_sync_k_width(a.type, b.type) - a_layout: ttgl.constexpr = ttgl.DotOperandLayout(parent=mma_layout, operand_index=0, k_width=k_width) - b_layout: ttgl.constexpr = ttgl.DotOperandLayout(parent=mma_layout, operand_index=1, k_width=k_width) - a = ttgl.convert_layout(a, a_layout) - b = ttgl.convert_layout(b, b_layout) - if acc_init is not None: - acc = ttgl.convert_layout(acc_init, mma_layout) - else: - acc = ttgl.full([a.shape[0], a.shape[1], b.shape[2]], 0.0, out_dtype, layout=mma_layout) - result = mma_v2(a, b, acc, input_precision) - if acc_init is not None: - layout: ttgl.constexpr = acc_init.type.layout - else: - layout: ttgl.constexpr = default_blocked_layout(result.type.shape, ttgl.num_warps()) - result = ttgl.convert_layout(result, layout) - return result - - -@gluon.constexpr_function -def tl_dot_mmav5_supported(a_ty, b_ty, num_warps, 
input_precision, allow_tf32, max_num_imprecise_acc): - assert max_num_imprecise_acc is None, ("max_num_imprecise_acc only applies to Hopper warp_group_dot") - assert input_precision is None or allow_tf32 is None, ( - "Only one of input_precision and allow_tf32 can be specified") - if input_precision is None and (allow_tf32 or allow_tf32 is None): - input_precision = "tf32" - - M = a_ty.shape[0] - N = b_ty.shape[1] - K = a_ty.shape[1] - min_K = 256 // a_ty.element_ty.primitive_bitwidth - if a_ty.element_ty.is_int() or b_ty.element_ty.is_int(): - return False - if (min(a_ty.element_ty.primitive_bitwidth, b_ty.element_ty.primitive_bitwidth) >= 32 - and input_precision != "tf32"): - return False - return (num_warps in [4, 8] and len(a_ty.shape) == 2 and len(b_ty.shape) == 2 and K >= min_K and M >= 64 - and N >= 16) - - -@gluon.constexpr_function -def get_shared_memory_mma_layout(type, operand_index, allow_transpose, is_fp4_padded=False, force_transpose=False): - if not allow_transpose: - if operand_index == 1: - transposed = True - else: - transposed = False - if force_transpose: - transposed = not transposed - else: - transposed = operand_index == 1 - - shape = type.shape - swizzle_byte_width = 0 - ele_bit_width = type.element_ty.primitive_bitwidth - packing_factor = 2 if is_fp4_padded else 1 - - contig_dim_size_in_byte = ((shape[0] if transposed else shape[1]) * packing_factor * ele_bit_width // 8) - if contig_dim_size_in_byte >= 128 and contig_dim_size_in_byte % 128 == 0: - swizzle_byte_width = 128 - elif contig_dim_size_in_byte >= 64 and contig_dim_size_in_byte % 64 == 0: - swizzle_byte_width = 64 - elif contig_dim_size_in_byte >= 32 and contig_dim_size_in_byte % 32 == 0: - swizzle_byte_width = 32 - else: - swizzle_byte_width = 0 - - flatten_outer_dim = 1 - for dim in shape: - flatten_outer_dim *= dim - if len(shape) < 2 or flatten_outer_dim < 8: - swizzle_byte_width = 0 - return ttgl.NVMMASharedLayout( - swizzle_byte_width=swizzle_byte_width, - 
transposed=transposed, - element_bitwidth=ele_bit_width, - rank=len(shape), - fp4_padded=is_fp4_padded, - ) - - -@gluon.jit -def get_shared_memory_mma_operand(value, operand_index, allow_transpose, is_fp4_padded=False, force_transpose=False): - layout: ttgl.constexpr = get_shared_memory_mma_layout(value.type, operand_index, allow_transpose, is_fp4_padded, - force_transpose) - return ttgl.allocate_shared_memory(value.dtype, value.shape, layout, value) - - -@gluon.jit -def tl_dot_blackwell( - a, - b, - acc=None, - input_precision=None, - allow_tf32=None, - max_num_imprecise_acc=None, - out_dtype=ttgl.float32, -): - M: ttgl.constexpr = a.type.shape[0] - N: ttgl.constexpr = b.type.shape[1] - - allow_transpose = not a.type.element_ty.is_fp32() - a_smem = get_shared_memory_mma_operand(a, 0, allow_transpose) - b_smem = get_shared_memory_mma_operand(b, 1, allow_transpose) - - # MMA instruction shape - m: ttgl.constexpr = 128 if M >= 128 else 64 - n: ttgl.constexpr = 256 if N >= 256 else N - - acc_dtype: ttgl.constexpr = acc.dtype if acc is not None else out_dtype - col_stride: ttgl.constexpr = 32 // acc_dtype.primitive_bitwidth - acc_tmem_layout: ttgl.constexpr = TensorMemoryLayout([m, n], col_stride=col_stride) - acc_tmem = allocate_tensor_memory(acc_dtype, [M, N], acc_tmem_layout) - tmem_reg_layout: ttgl.constexpr = acc_tmem.get_reg_layout() - if acc is not None: - acc_temp = ttgl.convert_layout(acc, tmem_reg_layout) - else: - acc_temp = ttgl.zeros([M, N], out_dtype, layout=tmem_reg_layout) - acc_tmem.store(acc_temp) - fence_async_shared() - bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) - mbarrier.init(bar, count=1) - tcgen05_mma(a_smem, b_smem, acc_tmem, use_acc=True) - tcgen05_commit(bar) - mbarrier.wait(bar, phase=0) - mbarrier.invalidate(bar) - - # Load back from TMEM using a register layout and convert to acc layout - out = acc_tmem.load() - ret_layout: ttgl.constexpr = default_blocked_layout([M, N], ttgl.num_warps()) - out = 
ttgl.convert_layout(out, ret_layout) - return out - - -@gluon.jit -def tl_dot( - a, - b, - acc=None, - input_precision=None, - allow_tf32=None, - max_num_imprecise_acc=None, - out_dtype=ttgl.float32, -): - target: ttgl.constexpr = current_target() - if _is_gfx1250(target): - return tl_dot_wmma(a, b, acc, out_dtype) - elif _is_cdna(target): - return tl_dot_mfma(a, b, acc, out_dtype) - else: - num_warps: ttgl.constexpr = ttgl.num_warps() - if tl_dot_mmav5_supported(a.type, b.type, num_warps, input_precision, allow_tf32, max_num_imprecise_acc): - return tl_dot_blackwell(a, b, acc, input_precision, allow_tf32, max_num_imprecise_acc, out_dtype) - else: - return tl_dot_mma_sync(a, b, acc, input_precision, out_dtype) - - -@gluon.constexpr_function -def tl_dot_scaled_mmav5_supported(a_ty, b_ty, num_warps): - M = a_ty.shape[0] - N = b_ty.shape[1] - K = a_ty.shape[1] - min_K = 256 // a_ty.element_ty.primitive_bitwidth - return (num_warps in [4, 8] and len(a_ty.shape) == 2 and len(b_ty.shape) == 2 and K >= min_K and M >= 128 - and N >= 16) - - -@gluon.constexpr_function -def get_swizzle_byte_width(bitwidth): - swizzle = min(bitwidth, 128) - swizzle = 0 if swizzle < 32 else swizzle - return swizzle - - -@gluon.constexpr_function -def get_int_type(bitwidth): - if bitwidth == 64: - return ttgl.int64 - elif bitwidth == 32: - return ttgl.int32 - elif bitwidth == 16: - return ttgl.int16 - elif bitwidth == 8: - return ttgl.int8 - else: - assert False, f"Unsupported bitwidth: {bitwidth}" - - -@gluon.jit -def tl_dot_decomposed_scale_to_16(scale, compute_type): - large_fp_type: ttgl.constexpr = ttgl.float32 if compute_type == ttgl.float16 else compute_type - int_width: ttgl.constexpr = large_fp_type.primitive_bitwidth - int_type: ttgl.constexpr = get_int_type(int_width) - - zexted = ttgl.cast(scale, int_type) - shift_value: ttgl.constexpr = large_fp_type.fp_mantissa_width - shl_res = zexted << shift_value - scale_fp = ttgl.cast(shl_res, large_fp_type, bitcast=True) - if large_fp_type 
!= compute_type: - scale_fp = ttgl.cast(scale_fp, compute_type) - return scale_fp - - -@gluon.constexpr_function -def tl_dot_get_expand_dims_layout(scale_ty, num_warps, rank): - shape = scale_ty.shape.values + [1] - blocked = default_blocked_layout(shape, num_warps) - slice = ttgl.SliceLayout(rank, blocked) - return slice - - -@gluon.constexpr_function -def tl_dot_get_permute_order(rank, dim): - order = list(range(rank)) - order.insert(dim + 1, rank) - return order - - -@gluon.constexpr_function -def tl_dot_get_reshape_shape(scale_ty, dim): - shape = list(scale_ty.shape.values) - shape.pop() - shape[dim] *= 32 - return shape - - -@gluon.jit -def tl_dot_decomposed_broadcast_scale(scale, dim): - scale_ty: ttgl.constexpr = scale.type - rank: ttgl.constexpr = len(scale_ty.shape) - - num_warps: ttgl.constexpr = ttgl.num_warps() - slice_enc: ttgl.constexpr = tl_dot_get_expand_dims_layout(scale_ty, num_warps, rank) - scale = ttgl.convert_layout(scale, slice_enc) - expand_scale = scale.expand_dims(rank) - broadcast_scale = expand_scale.broadcast_to(scale.type.shape + (32, )) - permute_order: ttgl.constexpr = tl_dot_get_permute_order(rank, dim) - transposed_scale = broadcast_scale.permute(permute_order) - reshape_shape: ttgl.constexpr = tl_dot_get_reshape_shape(broadcast_scale.type, dim) - return transposed_scale.reshape(reshape_shape) - - -@gluon.constexpr_function -def tl_dot_decomposed_get_transposed_order(rank): - assert rank >= 2 - order = list(range(rank - 2)) - order += [rank - 1, rank - 2] - return order - - -@gluon.jit -def tl_dot_decomposed_extend_and_broadcast_scale(v, scale, compute_type, operand_index): - rank: ttgl.constexpr = len(v.type.shape) - k_dim: ttgl.constexpr = rank - 1 if operand_index == 0 else rank - 2 - - if operand_index == 1: - order: ttgl.constexpr = tl_dot_decomposed_get_transposed_order(rank) - scale = ttgl.permute(scale, order) - - scale16 = tl_dot_decomposed_scale_to_16(scale, compute_type) - reshape_scale = 
tl_dot_decomposed_broadcast_scale(scale16, k_dim) - return ttgl.convert_layout(reshape_scale, v.type.layout), scale - - -@gluon.jit -def tl_dot_decomposed_mask_nan(mxfp, scale, fast_math): - ttgl.static_assert(fast_math, "TODO: support non-fast-math") - return mxfp - - -@gluon.jit -def tl_dot_decomposed_scale_arg(v, scale, arg_format, operand_index, compute_type, fast_math): - is_fp4: ttgl.constexpr = arg_format == "e2m1" - rank: ttgl.constexpr = len(v.type.shape) - k_dim: ttgl.constexpr = rank - 1 if operand_index == 0 else rank - 2 - - if is_fp4: - v = ttgl.fp4_to_fp(v, compute_type, k_dim) - else: - v = ttgl.cast(v, compute_type) - if scale is None: - return v - else: - reshape_scale, scale = tl_dot_decomposed_extend_and_broadcast_scale(v, scale, compute_type, operand_index) - mxfp = ttgl.mul(v, reshape_scale) - return tl_dot_decomposed_mask_nan(mxfp, scale, fast_math) - - -@gluon.jit -def tl_dot_scaled( - lhs, - lhs_scale, - lhs_format, - rhs, - rhs_scale, - rhs_format, - acc=None, - fast_math=False, - lhs_k_pack=True, - rhs_k_pack=True, - out_dtype=ttgl.float32, -): - if (_is_nvidia(current_target()) and tl_dot_scaled_mmav5_supported(lhs.type, rhs.type, ttgl.num_warps()) - and lhs_scale is not None and rhs_scale is not None): - return tl_dot_scaled_blackwell( - lhs, - lhs_scale, - lhs_format, - rhs, - rhs_scale, - rhs_format, - acc, - fast_math, - lhs_k_pack, - rhs_k_pack, - out_dtype, - ) - else: - return tl_dot_decomposed_block_scales( - lhs, - lhs_scale, - lhs_format, - rhs, - rhs_scale, - rhs_format, - acc, - fast_math, - lhs_k_pack, - rhs_k_pack, - out_dtype, - ) - - -@gluon.jit -def tl_dot_decomposed_block_scales( - lhs, - lhs_scale, - lhs_format, - rhs, - rhs_scale, - rhs_format, - acc=None, - fast_math=False, - lhs_k_pack=True, - rhs_k_pack=True, - out_dtype=ttgl.float32, -): - if lhs_scale is None and rhs_scale is not None: - lhs_trans = tl_trans(lhs) - rhs_trans = tl_trans(rhs) - if acc is not None: - orig_layout: ttgl.constexpr = acc.type.layout - 
acc = tl_trans(acc) - result = tl_dot_scaled( - rhs_trans, - rhs_scale, - rhs_format, - lhs_trans, - lhs_scale, - lhs_format, - acc, - fast_math, - lhs_k_pack, - rhs_k_pack, - out_dtype, - ) - result = tl_trans(result) - if acc is not None: - result = ttgl.convert_layout(result, orig_layout) - return result - else: - ttgl.static_assert(not (not lhs_k_pack or not rhs_k_pack), "TODO: support m/n packed formats") - compute_type: ttgl.constexpr = (ttgl.float16 if - (lhs_format == "fp16" or rhs_format == "fp16") else ttgl.bfloat16) - - scale_a = tl_dot_decomposed_scale_arg(lhs, lhs_scale, lhs_format, 0, compute_type, fast_math) - scale_b = tl_dot_decomposed_scale_arg(rhs, rhs_scale, rhs_format, 1, compute_type, fast_math) - - return tl_dot(scale_a, scale_b, acc, out_dtype=out_dtype) - - -@gluon.jit -def tl_dot_scaled_blackwell( - lhs, - lhs_scale, - lhs_format, - rhs, - rhs_scale, - rhs_format, - acc=None, - fast_math=False, - lhs_k_pack=True, - rhs_k_pack=True, - out_dtype=ttgl.float32, -): - is_a_fp4: ttgl.constexpr = lhs_format == "e2m1" - is_b_fp4: ttgl.constexpr = rhs_format == "e2m1" - - mixed_prec: ttgl.constexpr = lhs_format != rhs_format - is_a_mixed_prec_fp4: ttgl.constexpr = mixed_prec and is_a_fp4 - is_b_mixed_prec_fp4: ttgl.constexpr = mixed_prec and not is_a_fp4 and is_b_fp4 - - is_mmav5_fp4_padded_a: ttgl.constexpr = is_a_mixed_prec_fp4 or not lhs_k_pack - is_mmav5_fp4_padded_b: ttgl.constexpr = is_b_mixed_prec_fp4 or not rhs_k_pack - - a_smem = get_shared_memory_mma_operand( - lhs, - 0, - allow_transpose=not is_a_fp4, - is_fp4_padded=is_mmav5_fp4_padded_a, - force_transpose=not lhs_k_pack, - ) - b_smem = get_shared_memory_mma_operand( - rhs, - 1, - allow_transpose=not is_b_fp4, - is_fp4_padded=is_mmav5_fp4_padded_b, - force_transpose=not rhs_k_pack, - ) - - M: ttgl.constexpr = lhs.type.shape[0] - N: ttgl.constexpr = rhs.type.shape[1] - - m: ttgl.constexpr = 128 - n: ttgl.constexpr = 256 if N >= 256 else N - - acc_dtype: ttgl.constexpr = acc.dtype if acc 
is not None else out_dtype - col_stride: ttgl.constexpr = 32 // acc_dtype.primitive_bitwidth - acc_tmem_layout: ttgl.constexpr = TensorMemoryLayout([m, n], col_stride=col_stride) - acc_tmem = allocate_tensor_memory(acc_dtype, [M, N], acc_tmem_layout) - tmem_reg_layout: ttgl.constexpr = acc_tmem.get_reg_layout() - if acc is not None: - acc_temp = ttgl.convert_layout(acc, tmem_reg_layout) - else: - acc_temp = ttgl.zeros([M, N], out_dtype, layout=tmem_reg_layout) - acc_tmem.store(acc_temp) - fence_async_shared() - - bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) - mbarrier.init(bar, count=1) - scale_layout: ttgl.constexpr = TensorMemoryScalesLayout() - a_scale_tmem = allocate_tensor_memory(lhs_scale.dtype, lhs_scale.shape, scale_layout) - b_scale_tmem = allocate_tensor_memory(rhs_scale.dtype, rhs_scale.shape, scale_layout) - scale_layout_reg_lhs: ttgl.constexpr = a_scale_tmem.get_reg_layout() - scale_layout_reg_rhs: ttgl.constexpr = b_scale_tmem.get_reg_layout() - lhs_scale = ttgl.convert_layout(lhs_scale, scale_layout_reg_lhs) - rhs_scale = ttgl.convert_layout(rhs_scale, scale_layout_reg_rhs) - a_scale_tmem.store(lhs_scale) - b_scale_tmem.store(rhs_scale) - - tcgen05_mma_scaled(a_smem, b_smem, acc_tmem, a_scale_tmem, b_scale_tmem, lhs_format, rhs_format, use_acc=True) - tcgen05_commit(bar) - mbarrier.wait(bar, phase=0) - mbarrier.invalidate(bar) - # Load back from TMEM using a register layout and convert to acc layout - out = acc_tmem.load() - ret_layout: ttgl.constexpr = default_blocked_layout([M, N], ttgl.num_warps()) - out = ttgl.convert_layout(out, ret_layout) - return out - - -@gluon.constexpr_function -def get_num_threads_per_warp(target=None) -> ttgl.constexpr: - if target is None: - target = current_target() - if target is not None and target.backend == "hip": - gfx_major = int(target.arch[3:-2]) - return ttgl.constexpr(32 if gfx_major >= 10 else 64) - return ttgl.constexpr(32) - - -@gluon.jit -def get_num_threads_per_program(): 
- return ttgl.num_warps() * get_num_threads_per_warp(current_target()) - - -@gluon.constexpr_function -def default_blocked_layout(shape: ttgl.constexpr, num_warps: ttgl.constexpr, target=None) -> ttgl.constexpr: - rank = len(shape) - # 1 element per thread for all dimensions - size_per_thread = [1] * rank - # Distribute threads per warp across dimensions (simple heuristic: last-fastest) - threads_per_warp = [1] * rank - # TODO: pick a better layout based on shape. Using this allows to not have to convert layout when broadcasting but may blow up register pressure. - threads_per_warp[rank - 1] = get_num_threads_per_warp(target) - # Use provided num_warps to distribute warps per CTA (put all on first dim) - warps_per_cta = [1] * rank - warps_per_cta[0] = num_warps - # Natural order [rank-1, rank-2, ..., 0] - order = list(range(rank - 1, -1, -1)) - return ttgl.BlockedLayout( - size_per_thread=size_per_thread, - threads_per_warp=threads_per_warp, - warps_per_cta=warps_per_cta, - order=order, - ) - - -# ---- architecture detection ---- - - -@gluon.constexpr_function -def _is_nvidia(target=None): - return target is None or target.backend == "cuda" - - -@gluon.constexpr_function -def _is_gfx1250(target=None): - return target is not None and target.arch == "gfx1250" - - -@gluon.constexpr_function -def _is_cdna(target=None): - return target is not None and target.arch in ("gfx942", "gfx950") - - -@gluon.constexpr_function -def _cdna_version(target=None): - """Returns 3 for gfx942, 4 for gfx950.""" - return 4 if target is not None and target.arch == "gfx950" else 3 - - -# ---- AMD WMMA layout helpers (gfx1250) ---- - - -@gluon.constexpr_function -def compute_warp_bases(num_warps): - """Distribute warps across M/N: first bit to N, rest to M.""" - n_bits = int(math.log2(num_warps)) - if n_bits == 0: - return [] - warp_bases = [[0, 1]] - for i in range(n_bits - 1): - warp_bases.append([1 << i, 0]) - return warp_bases - - -@gluon.constexpr_function -def get_wmma_layout(shape, 
num_warps): - warp_bases = compute_warp_bases(num_warps) - return ttgl.amd.AMDWMMALayout(3, True, warp_bases, [], [16, 16, 32]) - - -@gluon.constexpr_function -def get_wmma_k_width(a_ty, b_ty): - min_bitwidth = min(a_ty.element_ty.primitive_bitwidth, b_ty.element_ty.primitive_bitwidth) - return max(128 // min_bitwidth, 1) - - -# ---- AMD MFMA layout helpers (cdna3/cdna4) ---- - - -@gluon.constexpr_function -def get_mfma_instr_k(element_bitwidth, target=None): - """K dimension of the MFMA instruction for [32, 32, K].""" - k_bits = 128 if _cdna_version(target) == 3 else 256 - return k_bits // element_bitwidth - - -@gluon.constexpr_function -def get_mfma_layout(num_warps, element_bitwidth, target=None): - instr_k = get_mfma_instr_k(element_bitwidth, target) - return ttgl.amd.AMDMFMALayout( - version=_cdna_version(target), - instr_shape=[32, 32, instr_k], - transposed=True, - warps_per_cta=[num_warps, 1], - ) - - -@gluon.constexpr_function -def get_mfma_k_width(a_ty, b_ty, target=None): - min_bitwidth = min(a_ty.element_ty.primitive_bitwidth, b_ty.element_ty.primitive_bitwidth) - instr_k = get_mfma_instr_k(min_bitwidth, target) - return instr_k // 2 - - -# ---- AMD dot paths ---- - - -@gluon.jit -def tl_dot_wmma(a, b, acc, out_dtype): - """gfx1250 WMMA path.""" - M: ttgl.constexpr = a.type.shape[0] - N: ttgl.constexpr = b.type.shape[1] - num_warps: ttgl.constexpr = ttgl.num_warps() - - wmma_layout: ttgl.constexpr = get_wmma_layout([M, N], num_warps) - k_width: ttgl.constexpr = get_wmma_k_width(a.type, b.type) - a_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=wmma_layout, k_width=k_width) - b_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=1, parent=wmma_layout, k_width=k_width) - - a = ttgl.convert_layout(a, a_layout) - b = ttgl.convert_layout(b, b_layout) - - if acc is not None: - accumulator = ttgl.convert_layout(acc, wmma_layout) - else: - accumulator = ttgl.zeros([M, N], out_dtype, layout=wmma_layout) - - result = amd_wmma(a, 
b, accumulator) - - if acc is not None: - ret_layout: ttgl.constexpr = acc.type.layout - else: - ret_layout: ttgl.constexpr = default_blocked_layout(result.type.shape, num_warps) - return ttgl.convert_layout(result, ret_layout) - - -@gluon.jit -def tl_dot_mfma(a, b, acc, out_dtype): - """CDNA3/CDNA4 MFMA path.""" - M: ttgl.constexpr = a.type.shape[0] - N: ttgl.constexpr = b.type.shape[1] - num_warps: ttgl.constexpr = ttgl.num_warps() - min_bitwidth: ttgl.constexpr = min(a.type.element_ty.primitive_bitwidth, b.type.element_ty.primitive_bitwidth) - target: ttgl.constexpr = current_target() - - mfma_layout: ttgl.constexpr = get_mfma_layout(num_warps, min_bitwidth, target) - k_width: ttgl.constexpr = get_mfma_k_width(a.type, b.type, target) - a_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout, k_width=k_width) - b_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout, k_width=k_width) - - a = ttgl.convert_layout(a, a_layout) - b = ttgl.convert_layout(b, b_layout) - - if acc is not None: - accumulator = ttgl.convert_layout(acc, mfma_layout) - else: - accumulator = ttgl.zeros([M, N], out_dtype, layout=mfma_layout) - - result = amd_mfma(a, b, accumulator) - - if acc is not None: - ret_layout: ttgl.constexpr = acc.type.layout - else: - ret_layout: ttgl.constexpr = default_blocked_layout(result.type.shape, num_warps) - return ttgl.convert_layout(result, ret_layout) - - -# ---- AMD TDM tensor descriptors (gfx1250 only) ---- - - -@gluon.constexpr_function -def get_default_tdm_layout(*block_shape): - block_shape = list(block_shape) - return ttgl.PaddedSharedLayout.with_identity_for( - [[block_shape[-1], 4]], - block_shape, - list(range(len(block_shape) - 1, -1, -1)), - ) - - -@tl.core._aggregate -class AMDTensorDescriptorArgs: - """Wraps a real TDM descriptor alongside the original base pointer. 
- - The base_ptr is needed by gather/scatter to recreate the descriptor with a different - block_shape -- Triton uses block_shape=[1, N] but TDM hardware requires [num_indices, N]. - Shape, strides, and block_shape are read from desc (type metadata gives plain Python ints - for block_shape, tuples for shape/strides).""" - desc: amd_tdm.tensor_descriptor - base_ptr: tl.core.tensor - - -@gluon.jit -def tl_obj_load_amd(desc, offsets): - smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout) - amd_tdm.async_load(desc, offsets, smem) - amd_tdm.async_wait(0) - ret_layout: ttgl.constexpr = default_blocked_layout(desc.block_shape, ttgl.num_warps()) - return smem.load(ret_layout) - - -@gluon.jit -def tl_obj_store_amd(desc, offsets, value): - smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout, value) - amd_tdm.async_store(desc, offsets, smem) - amd_tdm.async_wait(0) - - -# ---- obj dispatch (routes desc.load/store/gather/scatter to TMA or TDM) ---- - - -@gluon.jit -def tl_obj_store(obj, offsets, value): - if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): - return tl_store_tensor_descriptor(obj, offsets, value) - elif isinstance(obj, AMDTensorDescriptorArgs): - tl_obj_store_amd(obj.desc, offsets, value) - elif isinstance(obj, amd_tdm.tensor_descriptor): - tl_obj_store_amd(obj, offsets, value) - else: - return obj.store(offsets, value) - - -@gluon.jit -def tl_obj_load(obj, offsets): - if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): - return tl_load_tensor_descriptor(obj, offsets) - elif isinstance(obj, AMDTensorDescriptorArgs): - return tl_obj_load_amd(obj.desc, offsets) - elif isinstance(obj, amd_tdm.tensor_descriptor): - return tl_obj_load_amd(obj, offsets) - else: - return obj.load(offsets) - - -@gluon.jit -def tl_obj_gather_amd(desc_args, x_offsets, y_offset): - # Triton creates gather descriptors with block_shape=[1, block_n], but TDM hardware - # operates on the full batch, requiring 
block_shape=[num_indices, block_n]. - NUM_IDX: ttgl.constexpr = x_offsets.shape[0] - BLOCK_N: ttgl.constexpr = desc_args.desc.block_shape[1] - smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) - gather_block_shape: ttgl.constexpr = [NUM_IDX, BLOCK_N] - gather_desc = amd_tdm.make_tensor_descriptor(desc_args.base_ptr, desc_args.desc.shape, desc_args.desc.strides, - gather_block_shape, smem_layout) - num_warps: ttgl.constexpr = ttgl.num_warps() - gather_shape: ttgl.constexpr = gather_desc.block_shape - idx_base: ttgl.constexpr = ttgl.BlockedLayout([gather_shape[0], 1], - [1, get_num_threads_per_warp(current_target())], [1, num_warps], - [1, 0]) - idx_layout: ttgl.constexpr = ttgl.SliceLayout(1, idx_base) - x_offsets = ttgl.convert_layout(x_offsets, idx_layout) - alloc = ttgl.allocate_shared_memory(desc_args.desc.dtype, list(gather_shape), smem_layout) - y_off = ttgl.to_tensor(y_offset) - amd_tdm.async_gather(gather_desc, x_offsets, y_off, alloc) - amd_tdm.async_wait(0) - ret_layout: ttgl.constexpr = default_blocked_layout(list(gather_shape), num_warps, current_target()) - out = alloc.load(ret_layout) - return out - - -@gluon.jit -def tl_obj_scatter_amd(desc_args, value, x_offsets, y_offset): - # See tl_obj_gather_amd for why the descriptor is recreated with a different block_shape. 
- NUM_IDX: ttgl.constexpr = x_offsets.shape[0] - BLOCK_N: ttgl.constexpr = desc_args.desc.block_shape[1] - smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) - scatter_block_shape: ttgl.constexpr = [NUM_IDX, BLOCK_N] - scatter_desc = amd_tdm.make_tensor_descriptor(desc_args.base_ptr, desc_args.desc.shape, desc_args.desc.strides, - scatter_block_shape, smem_layout) - num_warps: ttgl.constexpr = ttgl.num_warps() - scatter_shape: ttgl.constexpr = scatter_desc.block_shape - idx_base: ttgl.constexpr = ttgl.BlockedLayout([scatter_shape[0], 1], - [1, get_num_threads_per_warp(current_target())], [1, num_warps], - [1, 0]) - idx_layout: ttgl.constexpr = ttgl.SliceLayout(1, idx_base) - x_offsets = ttgl.convert_layout(x_offsets, idx_layout) - alloc = ttgl.allocate_shared_memory(desc_args.desc.dtype, list(scatter_shape), smem_layout, value) - y_off = ttgl.to_tensor(y_offset) - amd_tdm.async_scatter(scatter_desc, x_offsets, y_off, alloc) - amd_tdm.async_wait(0) - - -@gluon.jit -def tl_obj_gather(obj, x_offsets, y_offset): - if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): - desc = obj - desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] - alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout) - bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) - mbarrier.init(bar, count=1) - x_offsets_layout: ttgl.constexpr = ttgl.SliceLayout( - 0, - ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]), - ) - x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout) - mbarrier.expect(bar, x_offsets.shape[0] * obj.block_type.nbytes) - tma_blackwell.async_gather(desc, x_offsets, y_offset, bar, alloc) - mbarrier.wait(bar, phase=0) - mbarrier.invalidate(bar) - # Load from shared memory into a register tensor using a reasonable default layout - ret_layout: ttgl.constexpr = default_blocked_layout(desc.block_shape, ttgl.num_warps()) - out = alloc.load(ret_layout) 
- return out - elif isinstance(obj, AMDTensorDescriptorArgs): - return tl_obj_gather_amd(obj, x_offsets, y_offset) - else: - return obj.gather(x_offsets, y_offset) - - -@gluon.jit -def tl_obj_scatter(obj, value, x_offsets, y_offset): - if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): - desc = obj - desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] - alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout, value) - fence_async_shared() - x_offsets_layout: ttgl.constexpr = ttgl.SliceLayout( - 0, - ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]), - ) - x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout) - tma_blackwell.async_scatter(desc, x_offsets, y_offset, alloc) - tma.store_wait(0) - elif isinstance(obj, AMDTensorDescriptorArgs): - tl_obj_scatter_amd(obj, value, x_offsets, y_offset) - else: - obj.scatter(value, x_offsets, y_offset) - - -@gluon.jit -def tl_make_tensor_descriptor(base, shape, strides, block_shape, padding_option: ttgl.constexpr = "zero"): - target: ttgl.constexpr = current_target() - if _is_gfx1250(target): - layout: ttgl.constexpr = get_default_tdm_layout(*block_shape) - desc = amd_tdm.make_tensor_descriptor(base, shape, strides, block_shape, layout) - return AMDTensorDescriptorArgs(desc, base) - else: - layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for(block_shape, base.dtype.element_ty) - return tma.make_tensor_descriptor(base, shape, strides, block_shape, layout, padding_option) - - -@gluon.jit -def tl_store_tensor_descriptor(desc, offsets, value): - alloc = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout, value) - fence_async_shared() - tma.async_copy_shared_to_global(desc, offsets, alloc) - tma.store_wait(0) - alloc._keep_alive() - - -@gluon.jit -def tl_load_tensor_descriptor(desc, offsets): - smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout) - bar = 
ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) - mbarrier.init(bar, count=1) - # Issue async copy from global (descriptor) to shared memory and wait for completion - mbarrier.expect(bar, desc.block_type.nbytes) - tma.async_copy_global_to_shared(desc, offsets, bar, smem) - mbarrier.wait(bar, phase=0) - mbarrier.invalidate(bar) - # Load from shared memory into a register tensor using a reasonable default layout - ret_layout: ttgl.constexpr = default_blocked_layout(desc.block_shape, ttgl.num_warps()) - out = smem.load(ret_layout) - return out - - -@gluon.jit -def tl_arange(start: ttgl.constexpr, stop: ttgl.constexpr = None): - layout: ttgl.constexpr = default_blocked_layout([stop - start], ttgl.num_warps()) - return ttgl.arange(start, stop, layout=layout) - - -@gluon.jit -def tl_full(shape, value, dtype=None): - layout: ttgl.constexpr = default_blocked_layout(shape, ttgl.num_warps()) - return ttgl.full(shape, value, dtype, layout=layout) - - -@gluon.jit -def tl_trans(value, *dims): - return value.trans(*dims) - - -@gluon.constexpr_function -def cat_permute_order(rank, dim): - order = list(range(rank)) - order.insert(dim, rank) - return order - - -@gluon.constexpr_function -def cat_result_shape(input_shape, dim): - result_shape = list(input_shape) - result_shape[dim] *= 2 - return result_shape - - -@gluon.jit -def tl_cat(input, other, can_reorder=False, dim=0): - # Join introduces a new minor dim; move it before the concat dim and merge. 
- c = ttgl.join(input, other) - order: ttgl.constexpr = cat_permute_order(len(input.shape), dim) - c = ttgl.permute(c, order) - shape: ttgl.constexpr = cat_result_shape(input.shape, dim) - c = ttgl.reshape(c, shape) - return reset_to_default_layout(c) - - -@gluon.jit -def reset_to_default_layout(value): - ty: ttgl.constexpr = value.type - if isinstance(ty, ttgl.tuple_type): - out = () - for i in ttgl.static_range(len(value)): - r = ttgl.convert_layout(value[i], layout=default_blocked_layout(value[i].type.shape, ttgl.num_warps())) - out = out + (r, ) - return out - elif isinstance(value, ttgl.tensor) and isinstance(value.type, ttgl.distributed_type): - layout: ttgl.constexpr = default_blocked_layout(ty.shape, ttgl.num_warps()) - return ttgl.convert_layout(value, layout=layout) - else: - return value - - -@gluon.constexpr_function -def get_split_src_layout(shape: ttgl.constexpr, num_warps: ttgl.constexpr, target=None) -> ttgl.constexpr: - rank = len(shape) - size_per_thread = [1 if i != rank - 1 else 2 for i in range(rank)] - # Distribute threads per warp across dimensions (simple heuristic: last-fastest) - threads_per_warp = [1 for _ in range(rank)] - remaining_threads = get_num_threads_per_warp(target) - for dim in range(rank - 2, -1, -1): - threads_per_warp[dim] = min(shape[dim], remaining_threads) - remaining_threads = remaining_threads // threads_per_warp[dim] - # Use provided num_warps to distribute warps per CTA (put all on first dim) - warps_per_cta = [1 for _ in range(rank)] - warps_per_cta[0] = num_warps - # Natural order [rank-1, rank-2, ..., 0] - order = list(range(rank - 1, -1, -1)) - return ttgl.BlockedLayout( - size_per_thread=size_per_thread, - threads_per_warp=threads_per_warp, - warps_per_cta=warps_per_cta, - order=order, - ) - - -@gluon.jit -def set_split_src_layout(value): - layout: ttgl.constexpr = get_split_src_layout(value.type.shape, ttgl.num_warps()) - return ttgl.convert_layout(value, layout=layout) - - -def convert_host_descriptor(desc): - 
- def torch_dtype_to_triton(dtype): - import torch - - if dtype == torch.float8_e5m2: - return ttgl.float8e5 - if dtype == torch.float8_e4m3fn: - return ttgl.float8e4nv - return getattr(ttgl, str(dtype).split(".")[1]) - - from triton.tools.tensor_descriptor import TensorDescriptor - - assert isinstance(desc, TensorDescriptor) - block_shape = desc.block_shape - dtype = desc.base.dtype - tensor = desc.base - - target = current_target() - if target is not None and target.backend == "hip" and target.arch == "gfx1250": - layout = get_default_tdm_layout(*block_shape) - return gluon.amd.gfx1250.TensorDescriptor(tensor, list(desc.shape), list(desc.strides), block_shape, layout) - - layout = ttgl.NVMMASharedLayout.get_default_for(block_shape, torch_dtype_to_triton(dtype)) - return gluon.nvidia.hopper.TensorDescriptor(tensor, desc.shape, desc.strides, block_shape, layout) - - -@gluon.constexpr_function -def build_expand_dims_layout(shape, expand_dims, num_warps): - if isinstance(shape, ttgl.tuple): - shape = shape.values - assert isinstance(shape, list), (f"expected shape to be a list, got {shape} which is {type(shape)}") - parent_shape = list(shape) - for dim in expand_dims: - parent_shape.insert(dim, 1) - layout = default_blocked_layout(parent_shape, num_warps) - for dim in reversed(expand_dims): - layout = ttgl.SliceLayout(dim=dim, parent=layout) - return layout - - -@gluon.jit -def convert_to_expand_dims_layout(value, expand_dims: list[int]): - layout: ttgl.constexpr = build_expand_dims_layout(value.shape, expand_dims, ttgl.num_warps()) - return ttgl.convert_layout(value, layout) +from triton.tools.triton_to_gluon_translator.common_helpers import * # noqa: F401,F403 +from triton.tools.triton_to_gluon_translator.nvidia_helpers import * # noqa: F401,F403 +from triton.tools.triton_to_gluon_translator.amd_helpers import * # noqa: F401,F403 From 2d05693b1e0cfe1a78047462d07a8b3a2a11a9b5 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Wed, 15 Apr 2026 00:38:17 -0700 
Subject: [PATCH 23/26] pre-commit run --- .../triton/tools/triton_to_gluon_translator/amd_helpers.py | 1 - .../tools/triton_to_gluon_translator/common_helpers.py | 4 ---- .../tools/triton_to_gluon_translator/nvidia_helpers.py | 5 ++--- 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/python/triton/tools/triton_to_gluon_translator/amd_helpers.py b/python/triton/tools/triton_to_gluon_translator/amd_helpers.py index 6b78f2c5bee7..091051f1f538 100644 --- a/python/triton/tools/triton_to_gluon_translator/amd_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/amd_helpers.py @@ -18,7 +18,6 @@ tl_trans, ) - # ---- architecture detection ---- diff --git a/python/triton/tools/triton_to_gluon_translator/common_helpers.py b/python/triton/tools/triton_to_gluon_translator/common_helpers.py index 6630438d5eb4..7442a9287332 100644 --- a/python/triton/tools/triton_to_gluon_translator/common_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/common_helpers.py @@ -1,8 +1,5 @@ # type: ignore -import math - -import triton.language as tl from triton.experimental import gluon from triton.experimental.gluon import language as ttgl @@ -10,7 +7,6 @@ # TODO: fix this by pulling imports into the generated file. 
from triton.language.target_info import current_target # noqa: F401 - # ---- layout utilities ---- diff --git a/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py b/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py index e005f0e64896..381451eced92 100644 --- a/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py @@ -22,7 +22,6 @@ tl_trans, ) - # ---- NVIDIA MMA sync (Ampere) ---- @@ -358,8 +357,8 @@ def tl_dot_scaled( rhs_k_pack=True, out_dtype=ttgl.float32, ): - if (tl_dot_scaled_mmav5_supported(lhs.type, rhs.type, ttgl.num_warps()) - and lhs_scale is not None and rhs_scale is not None): + if (tl_dot_scaled_mmav5_supported(lhs.type, rhs.type, ttgl.num_warps()) and lhs_scale is not None + and rhs_scale is not None): return tl_dot_scaled_blackwell( lhs, lhs_scale, From 6c2eb1ff7cc8a570b8626a3b34a856e9b166ec86 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Wed, 15 Apr 2026 00:38:17 -0700 Subject: [PATCH 24/26] Split translator_helpers into target-specific helper modules Split the monolithic translator_helpers.py into: - target.py: TranslatorTarget enum (hardware abstraction) - common_helpers.py: vendor-neutral utilities (layouts, portable ops) - nvidia_helpers.py: NVIDIA-specific helpers (TMA, mbarrier, Blackwell) - amd_helpers.py: AMD-specific helpers (TDM, WMMA, MFMA) Each target module re-exports common helpers via star import so the generated kernel sees a single unified `helpers` namespace. The TranslatorTarget.helpers_module property selects which module to import, so translated kernels no longer pull in unrelated hardware modules. translator_helpers.py is kept as a backward-compat re-export shim. 
Made-with: Cursor --- .../test/unit/tools/test_triton_to_gluon.py | 2 +- .../slice_kernel.py | 45 +----------------- .../triton_to_gluon_translator/target.py | 46 +++++++++++++++++++ .../triton_to_gluon_translator/translator.py | 2 +- 4 files changed, 49 insertions(+), 46 deletions(-) create mode 100644 python/triton/tools/triton_to_gluon_translator/target.py diff --git a/python/test/unit/tools/test_triton_to_gluon.py b/python/test/unit/tools/test_triton_to_gluon.py index feac1d39ee90..474708b0b249 100644 --- a/python/test/unit/tools/test_triton_to_gluon.py +++ b/python/test/unit/tools/test_triton_to_gluon.py @@ -8,7 +8,7 @@ from triton.tools.triton_to_gluon_translator.translator import convert_triton_to_gluon from triton.tools.triton_to_gluon_translator.translator_helpers import convert_host_descriptor -from triton.tools.triton_to_gluon_translator.slice_kernel import TranslatorTarget +from triton.tools.triton_to_gluon_translator.target import TranslatorTarget from triton._internal_testing import is_blackwell, is_hopper_or_newer, is_cuda, is_hip_cdna4, is_hip_gfx1250, is_hip_cdna3_or_newer from triton.language.target_info import current_target diff --git a/python/triton/tools/triton_to_gluon_translator/slice_kernel.py b/python/triton/tools/triton_to_gluon_translator/slice_kernel.py index 444f5c24cb34..cd40f4921047 100644 --- a/python/triton/tools/triton_to_gluon_translator/slice_kernel.py +++ b/python/triton/tools/triton_to_gluon_translator/slice_kernel.py @@ -13,7 +13,6 @@ from collections import OrderedDict from collections.abc import Sequence from dataclasses import dataclass, field -from enum import Enum from pathlib import Path from types import BuiltinFunctionType, FunctionType, ModuleType from typing import Any, Callable, TypeAlias @@ -27,53 +26,11 @@ from triton.tools.triton_to_gluon_translator.ordered_set import ordered_set from triton.tools.triton_to_gluon_translator.scoped_dict import scoped_dict from 
triton.tools.triton_to_gluon_translator.stable_toposort import stable_toposort +from triton.tools.triton_to_gluon_translator.target import TranslatorTarget logger = logging.getLogger(__name__) -class TranslatorTarget(str, Enum): - """Target architecture for the Triton-to-Gluon translator. - - Known targets are listed as explicit members for discoverability. - Unknown ``gfx*`` strings are accepted via ``_missing_()`` so that - new AMD architectures work without adding an enum member. - """ - - NVIDIA = "nvidia" - # AMD targets currently exercised by the translator test suite: - GFX1250 = "gfx1250" - GFX942 = "gfx942" - GFX950 = "gfx950" - - @classmethod - def _missing_(cls, value: object) -> "TranslatorTarget | None": - """Allow any ``gfx*`` string as a valid AMD target.""" - if isinstance(value, str) and value.startswith("gfx"): - obj = str.__new__(cls, value) - obj._value_ = value - return obj - return None - - @property - def is_amd(self) -> bool: - return self != TranslatorTarget.NVIDIA - - @property - def tensor_descriptor_import(self) -> str: - """Return the import statement for the target's tensor descriptor module.""" - if self.is_amd: - return "from triton.experimental.gluon.language.amd.gfx1250.tdm import tensor_descriptor" - return "from triton.experimental.gluon.language.nvidia.hopper.tma import tensor_descriptor" - - @property - def helpers_module(self) -> str: - """Return the helpers module path for this target.""" - base = "triton.tools.triton_to_gluon_translator" - if self.is_amd: - return f"{base}.amd_helpers" - return f"{base}.nvidia_helpers" - - @dataclass class GlobalVariable: name: str diff --git a/python/triton/tools/triton_to_gluon_translator/target.py b/python/triton/tools/triton_to_gluon_translator/target.py new file mode 100644 index 000000000000..f4c6c03dde1e --- /dev/null +++ b/python/triton/tools/triton_to_gluon_translator/target.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from enum import Enum + + +class 
TranslatorTarget(str, Enum): + """Target architecture for the Triton-to-Gluon translator. + + Known targets are listed as explicit members for discoverability. + Unknown ``gfx*`` strings are accepted via ``_missing_()`` so that + new AMD architectures work without adding an enum member. + """ + + NVIDIA = "nvidia" + # AMD targets currently exercised by the translator test suite: + GFX1250 = "gfx1250" + GFX942 = "gfx942" + GFX950 = "gfx950" + + @classmethod + def _missing_(cls, value: object) -> "TranslatorTarget | None": + """Allow any ``gfx*`` string as a valid AMD target.""" + if isinstance(value, str) and value.startswith("gfx"): + obj = str.__new__(cls, value) + obj._value_ = value + return obj + return None + + @property + def is_amd(self) -> bool: + return self != TranslatorTarget.NVIDIA + + @property + def tensor_descriptor_import(self) -> str: + """Return the import statement for the target's tensor descriptor module.""" + if self.is_amd: + return "from triton.experimental.gluon.language.amd.gfx1250.tdm import tensor_descriptor" + return "from triton.experimental.gluon.language.nvidia.hopper.tma import tensor_descriptor" + + @property + def helpers_module(self) -> str: + """Return the helpers module path for this target.""" + base = "triton.tools.triton_to_gluon_translator" + if self.is_amd: + return f"{base}.amd_helpers" + return f"{base}.nvidia_helpers" diff --git a/python/triton/tools/triton_to_gluon_translator/translator.py b/python/triton/tools/triton_to_gluon_translator/translator.py index 12fcae55205a..4e09e4679036 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator.py +++ b/python/triton/tools/triton_to_gluon_translator/translator.py @@ -14,7 +14,6 @@ GlobalValue, ReferenceRewriter, RewriteFn, - TranslatorTarget, add_sugar_rewrites, find_references, get_base_value, @@ -22,6 +21,7 @@ mangle_reference_names, parse_expr, ) +from triton.tools.triton_to_gluon_translator.target import TranslatorTarget from 
triton.tools.triton_to_gluon_translator.stable_toposort import stable_toposort From 96f20746aefed9781d38d4a91372c07637347105 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Wed, 15 Apr 2026 12:49:42 -0700 Subject: [PATCH 25/26] Remove isinstance dispatch from NVIDIA obj helpers The gluon JIT compiler evaluates all branches at compile time. With only NVIDIA types in scope, the else branch fails because tensor_descriptor lacks .store()/.load() methods. Since NVIDIA descriptors are always TMA type, call the TMA functions directly. Made-with: Cursor --- .../nvidia_helpers.py | 72 ++++++++----------- 1 file changed, 30 insertions(+), 42 deletions(-) diff --git a/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py b/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py index 381451eced92..8791e12df1ab 100644 --- a/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py @@ -425,60 +425,48 @@ def tl_load_tensor_descriptor(desc, offsets): @gluon.jit def tl_obj_store(obj, offsets, value): - if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): - return tl_store_tensor_descriptor(obj, offsets, value) - else: - return obj.store(offsets, value) + tl_store_tensor_descriptor(obj, offsets, value) @gluon.jit def tl_obj_load(obj, offsets): - if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): - return tl_load_tensor_descriptor(obj, offsets) - else: - return obj.load(offsets) + return tl_load_tensor_descriptor(obj, offsets) @gluon.jit def tl_obj_gather(obj, x_offsets, y_offset): - if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): - desc = obj - desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] - alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout) - bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) - mbarrier.init(bar, count=1) - x_offsets_layout: ttgl.constexpr = 
ttgl.SliceLayout( - 0, - ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]), - ) - x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout) - mbarrier.expect(bar, x_offsets.shape[0] * obj.block_type.nbytes) - tma_blackwell.async_gather(desc, x_offsets, y_offset, bar, alloc) - mbarrier.wait(bar, phase=0) - mbarrier.invalidate(bar) - ret_layout: ttgl.constexpr = default_blocked_layout(desc.block_shape, ttgl.num_warps()) - out = alloc.load(ret_layout) - return out - else: - return obj.gather(x_offsets, y_offset) + desc = obj + desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] + alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout) + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + x_offsets_layout: ttgl.constexpr = ttgl.SliceLayout( + 0, + ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]), + ) + x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout) + mbarrier.expect(bar, x_offsets.shape[0] * obj.block_type.nbytes) + tma_blackwell.async_gather(desc, x_offsets, y_offset, bar, alloc) + mbarrier.wait(bar, phase=0) + mbarrier.invalidate(bar) + ret_layout: ttgl.constexpr = default_blocked_layout(desc.block_shape, ttgl.num_warps()) + out = alloc.load(ret_layout) + return out @gluon.jit def tl_obj_scatter(obj, value, x_offsets, y_offset): - if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): - desc = obj - desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] - alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout, value) - fence_async_shared() - x_offsets_layout: ttgl.constexpr = ttgl.SliceLayout( - 0, - ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]), - ) - x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout) - tma_blackwell.async_scatter(desc, x_offsets, y_offset, alloc) - 
tma.store_wait(0) - else: - obj.scatter(value, x_offsets, y_offset) + desc = obj + desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] + alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout, value) + fence_async_shared() + x_offsets_layout: ttgl.constexpr = ttgl.SliceLayout( + 0, + ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]), + ) + x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout) + tma_blackwell.async_scatter(desc, x_offsets, y_offset, alloc) + tma.store_wait(0) # ---- NVIDIA host-side descriptor ---- From 1e47994d9ce2a46bd1f90f175d8605cd362261fd Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Wed, 15 Apr 2026 13:07:14 -0700 Subject: [PATCH 26/26] Remove translator_helpers.py backward-compat shim Update test imports to use the target-specific helper modules directly and delete the re-export shim. Use lazy import for convert_host_descriptor in tests so the correct target module is loaded at runtime. 
Made-with: Cursor --- python/test/unit/tools/test_slice_kernel.py | 2 +- python/test/unit/tools/test_triton_to_gluon.py | 17 +++++++++++++---- .../translator_helpers.py | 8 -------- 3 files changed, 14 insertions(+), 13 deletions(-) delete mode 100644 python/triton/tools/triton_to_gluon_translator/translator_helpers.py diff --git a/python/test/unit/tools/test_slice_kernel.py b/python/test/unit/tools/test_slice_kernel.py index d25ca96618a4..e2033a42cdbc 100644 --- a/python/test/unit/tools/test_slice_kernel.py +++ b/python/test/unit/tools/test_slice_kernel.py @@ -302,7 +302,7 @@ def test_slice_kernel_public_imports(): from triton.tools.triton_to_gluon_translator.slice_kernel import slice_kernel as new_slice_kernel from triton.tools.triton_to_gluon_translator.translator import translate_paths from triton.tools.triton_to_gluon_translator.translator import convert_triton_to_gluon - from triton.tools.triton_to_gluon_translator.translator_helpers import convert_host_descriptor + from triton.tools.triton_to_gluon_translator.nvidia_helpers import convert_host_descriptor assert callable(new_slice_kernel) assert callable(translate_paths) diff --git a/python/test/unit/tools/test_triton_to_gluon.py b/python/test/unit/tools/test_triton_to_gluon.py index 474708b0b249..1cc65d07b2d3 100644 --- a/python/test/unit/tools/test_triton_to_gluon.py +++ b/python/test/unit/tools/test_triton_to_gluon.py @@ -7,12 +7,21 @@ from triton.tools.tensor_descriptor import TensorDescriptor from triton.tools.triton_to_gluon_translator.translator import convert_triton_to_gluon -from triton.tools.triton_to_gluon_translator.translator_helpers import convert_host_descriptor from triton.tools.triton_to_gluon_translator.target import TranslatorTarget from triton._internal_testing import is_blackwell, is_hopper_or_newer, is_cuda, is_hip_cdna4, is_hip_gfx1250, is_hip_cdna3_or_newer from triton.language.target_info import current_target +def _convert_host_descriptor(desc): + """Import and call the 
target-appropriate convert_host_descriptor.""" + target = current_target() + if target is not None and target.backend == "hip": + from triton.tools.triton_to_gluon_translator.amd_helpers import convert_host_descriptor + else: + from triton.tools.triton_to_gluon_translator.nvidia_helpers import convert_host_descriptor + return convert_host_descriptor(desc) + + def convert_kernel(kernel, kernel_name, tmp_path, target=None): if target is None: t = current_target() @@ -182,7 +191,7 @@ def test_triton_to_gluon_descriptor_roundtrip(tmp_path): grid = (1, ) block_shape = [M, N] desc = TensorDescriptor(y, y.shape, y.stride(), block_shape) - gluon_desc = convert_host_descriptor(desc) + gluon_desc = _convert_host_descriptor(desc) kernel[grid](gluon_desc, M, N, 1.0) y_ref = torch.zeros((M, N), device="cuda", dtype=torch.float16) @@ -208,8 +217,8 @@ def test_triton_to_gluon_descriptor_load_roundtrip(tmp_path): block_shape = [M, N] in_desc = TensorDescriptor(x, x.shape, x.stride(), block_shape) - gluon_desc = convert_host_descriptor(in_desc) - out_desc = convert_host_descriptor(TensorDescriptor(y, y.shape, y.stride(), block_shape)) + gluon_desc = _convert_host_descriptor(in_desc) + out_desc = _convert_host_descriptor(TensorDescriptor(y, y.shape, y.stride(), block_shape)) kernel[grid](gluon_desc, out_desc, M, N) y_ref = torch.zeros((M, N), device="cuda", dtype=torch.float16) diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py b/python/triton/tools/triton_to_gluon_translator/translator_helpers.py deleted file mode 100644 index 91ae89264d1c..000000000000 --- a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py +++ /dev/null @@ -1,8 +0,0 @@ -# type: ignore -# Backward-compat shim: re-exports from the split helper modules so that -# existing imports (e.g. ``from translator_helpers import convert_host_descriptor``) -# continue to work. 
- -from triton.tools.triton_to_gluon_translator.common_helpers import * # noqa: F401,F403 -from triton.tools.triton_to_gluon_translator.nvidia_helpers import * # noqa: F401,F403 -from triton.tools.triton_to_gluon_translator.amd_helpers import * # noqa: F401,F403