diff --git a/python/test/unit/tools/test_slice_kernel.py b/python/test/unit/tools/test_slice_kernel.py index d25ca96618a4..e2033a42cdbc 100644 --- a/python/test/unit/tools/test_slice_kernel.py +++ b/python/test/unit/tools/test_slice_kernel.py @@ -302,7 +302,7 @@ def test_slice_kernel_public_imports(): from triton.tools.triton_to_gluon_translator.slice_kernel import slice_kernel as new_slice_kernel from triton.tools.triton_to_gluon_translator.translator import translate_paths from triton.tools.triton_to_gluon_translator.translator import convert_triton_to_gluon - from triton.tools.triton_to_gluon_translator.translator_helpers import convert_host_descriptor + from triton.tools.triton_to_gluon_translator.nvidia_helpers import convert_host_descriptor assert callable(new_slice_kernel) assert callable(translate_paths) diff --git a/python/test/unit/tools/test_triton_to_gluon.py b/python/test/unit/tools/test_triton_to_gluon.py index 697da0dc5ebd..1cc65d07b2d3 100644 --- a/python/test/unit/tools/test_triton_to_gluon.py +++ b/python/test/unit/tools/test_triton_to_gluon.py @@ -7,12 +7,26 @@ from triton.tools.tensor_descriptor import TensorDescriptor from triton.tools.triton_to_gluon_translator.translator import convert_triton_to_gluon -from triton.tools.triton_to_gluon_translator.translator_helpers import convert_host_descriptor -from triton._internal_testing import is_blackwell, is_hopper_or_newer, is_cuda +from triton.tools.triton_to_gluon_translator.target import TranslatorTarget +from triton._internal_testing import is_blackwell, is_hopper_or_newer, is_cuda, is_hip_cdna4, is_hip_gfx1250, is_hip_cdna3_or_newer +from triton.language.target_info import current_target -def convert_kernel(kernel, kernel_name, tmp_path): - converted = convert_triton_to_gluon([kernel]) +def _convert_host_descriptor(desc): + """Import and call the target-appropriate convert_host_descriptor.""" + target = current_target() + if target is not None and target.backend == "hip": + from triton.tools.triton_to_gluon_translator.amd_helpers import convert_host_descriptor + else: + from triton.tools.triton_to_gluon_translator.nvidia_helpers import convert_host_descriptor + return convert_host_descriptor(desc) + + +def convert_kernel(kernel, kernel_name, tmp_path, target=None): + if target is None: + t = current_target() + target = TranslatorTarget("nvidia" if t.backend == "cuda" else t.arch) + converted = convert_triton_to_gluon([kernel], target=target) # Write converted kernel to a file so @gluon.jit can retrieve source mod_path = tmp_path / "converted_kernel.py" @@ -36,7 +50,6 @@ def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr): tl.store(out_ptr + offsets, x + y) -@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA") def test_simple_kernel(tmp_path): kernel = convert_kernel(add_kernel, "add_kernel", tmp_path) @@ -70,9 +83,9 @@ def matmul_tile_kernel(a_ptr, b_ptr, c_ptr, BLOCK_M: tl.constexpr, BLOCK_N: tl.c impl_matmul_tile_kernel(a_ptr, b_ptr, c_ptr, BLOCK_M, BLOCK_N, BLOCK_K) -@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell") def test_triton_to_gluon_dot_minimal(tmp_path): - # Convert directly from the Triton kernel object + if not (is_blackwell() or is_hip_cdna3_or_newer() or is_hip_gfx1250()): + pytest.skip("Requires Blackwell, CDNA3+, or gfx1250") kernel = convert_kernel(matmul_tile_kernel, "matmul_tile_kernel", tmp_path) M, N, K = 128, 128, 128 a = torch.randn((M, K), device="cuda", dtype=torch.float16) @@ -131,8 +144,9 @@ def matmul_kernel( # @pytest.mark.parametrize("dtype_dst_str", ["float32"]) @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES", [(128, 128, 64, 1)]) @pytest.mark.parametrize("NUM_WARPS", [4]) -@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell") def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, NUM_WARPS, tmp_path): + if not (is_blackwell() or is_hip_cdna4() or is_hip_gfx1250()): + pytest.skip("Requires Blackwell, CDNA4, or gfx1250") device = "cuda" M, N, K = 1024, 512, 256 torch.manual_seed(42) @@ -161,8 +175,15 @@ def descriptor_store_kernel(desc, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, desc.store([0, 0], tile) -@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer") +def _skip_unless_descriptor_target(): + if is_cuda() and not is_hopper_or_newer(): + pytest.skip("Requires Hopper+") + elif not is_cuda() and not is_hip_gfx1250(): + pytest.skip("Requires descriptor support") + + def test_triton_to_gluon_descriptor_roundtrip(tmp_path): + _skip_unless_descriptor_target() kernel = convert_kernel(descriptor_store_kernel, "descriptor_store_kernel", tmp_path) M = N = 64 @@ -170,7 +191,7 @@ def test_triton_to_gluon_descriptor_roundtrip(tmp_path): grid = (1, ) block_shape = [M, N] desc = TensorDescriptor(y, y.shape, y.stride(), block_shape) - gluon_desc = convert_host_descriptor(desc) + gluon_desc = _convert_host_descriptor(desc) kernel[grid](gluon_desc, M, N, 1.0) y_ref = torch.zeros((M, N), device="cuda", dtype=torch.float16) @@ -185,8 +206,8 @@ def descriptor_copy_kernel(in_desc, out_desc, BLOCK_M: tl.constexpr, BLOCK_N: tl out_desc.store([0, 0], tile) -@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer") def test_triton_to_gluon_descriptor_load_roundtrip(tmp_path): + _skip_unless_descriptor_target() kernel = convert_kernel(descriptor_copy_kernel, "descriptor_copy_kernel", tmp_path) M = N = 64 @@ -196,8 +217,8 @@ def test_triton_to_gluon_descriptor_load_roundtrip(tmp_path): block_shape = [M, N] in_desc = TensorDescriptor(x, x.shape, x.stride(), block_shape) - gluon_desc = convert_host_descriptor(in_desc) - out_desc = convert_host_descriptor(TensorDescriptor(y, y.shape, y.stride(), block_shape)) + gluon_desc = _convert_host_descriptor(in_desc) + out_desc = _convert_host_descriptor(TensorDescriptor(y, y.shape, y.stride(), block_shape)) kernel[grid](gluon_desc, out_desc, M, N) y_ref = torch.zeros((M, N), device="cuda", dtype=torch.float16) @@ -224,8 +245,8 @@ def make_tensor_descriptor_copy_kernel(x_ptr, y_ptr, M, N, BLOCK_M: tl.constexpr out_desc.store([0, 0], tile) -@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer") def test_triton_to_gluon_make_tensor_descriptor(tmp_path, with_allocator): + _skip_unless_descriptor_target() kernel = convert_kernel(make_tensor_descriptor_copy_kernel, "make_tensor_descriptor_copy_kernel", tmp_path) M = N = 64 @@ -258,7 +279,6 @@ def reshape_trans_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr, @pytest.mark.parametrize("TRANS_KIND", ["trans_method", "tl_trans_separate", "tl_trans_tuple", "tl_trans"]) -@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA") def test_triton_reshape_trans(tmp_path, TRANS_KIND): kernel = convert_kernel(reshape_trans_kernel, "reshape_trans_kernel", tmp_path) @@ -289,7 +309,6 @@ def split_kernel(x_ptr, out_ptr): tl.store(p, a) -@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA") def test_split(tmp_path): kernel = convert_kernel(split_kernel, "split_kernel", tmp_path) @@ -339,7 +358,6 @@ def reduce_to_scalar_kernel(out_ptr): tl.store(out_ptr, x) -@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA") def test_reduce_to_scalar(tmp_path): kernel = convert_kernel(reduce_to_scalar_kernel, "reduce_to_scalar_kernel", tmp_path) grid = (1, ) @@ -403,7 +421,6 @@ def atomic_add_kernel(out_ptr, BLOCK: tl.constexpr): tl.atomic_add(out_ptr + idx, idx, mask=scalar_mask, sem="release", scope="cta") -@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA") def test_atomic_add(tmp_path): kernel = convert_kernel(atomic_add_kernel, "atomic_add_kernel", tmp_path) @@ -414,3 +431,58 @@ def test_atomic_add(tmp_path): out = torch.zeros((block, ), device="cuda") kernel[(1, )](out, BLOCK=block) torch.testing.assert_close(out, ref) + + +# ---- additional op coverage ---- + + +@triton.jit +def cat_kernel(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr): + offs = tl.arange(0, BLOCK) + x = tl.load(x_ptr + offs) + y = tl.load(y_ptr + offs) + z = tl.cat(x, y, can_reorder=True) + tl.store(out_ptr + tl.arange(0, 2 * BLOCK), z) + + +def test_cat(tmp_path): + kernel = convert_kernel(cat_kernel, "cat_kernel", tmp_path) + + BLOCK = 256 + x = torch.randn(BLOCK, device="cuda", dtype=torch.float32) + y = torch.randn(BLOCK, device="cuda", dtype=torch.float32) + out = torch.empty(2 * BLOCK, device="cuda", dtype=torch.float32) + kernel[(1, )](x, y, out, BLOCK) + + ref = torch.empty_like(out) + cat_kernel[(1, )](x, y, ref, BLOCK) + torch.testing.assert_close(sorted(out.cpu()), sorted(ref.cpu()), atol=0, rtol=0) + + +@triton.jit +def gather_scatter_roundtrip_kernel(out_ptr, in_ptr, idx_ptr, X: tl.constexpr, Y: tl.constexpr, BLOCK_X: tl.constexpr, + BLOCK_Y: tl.constexpr): + idx = tl.load(idx_ptr + tl.arange(0, BLOCK_X)) + in_desc = tl.make_tensor_descriptor(in_ptr, [X, Y], [Y, 1], [1, BLOCK_Y]) + out_desc = tl.make_tensor_descriptor(out_ptr, [X, Y], [Y, 1], [1, BLOCK_Y]) + data = in_desc.gather(idx, 0) + out_desc.scatter(data, idx, 0) + + +# TODO: parametrize over _descriptor_targets once NVIDIA gather/scatter translation is supported. +# The translator currently routes gather/scatter to AMD-specific helpers for AMD targets only. +@pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires gfx1250") +def test_gather_scatter_roundtrip(tmp_path): + kernel = convert_kernel(gather_scatter_roundtrip_kernel, "gather_scatter_roundtrip_kernel", tmp_path, + target=TranslatorTarget.GFX1250) + + X, Y, BLOCK_X, BLOCK_Y = 64, 64, 8, 64 + inp = torch.arange(X * Y, device="cuda", dtype=torch.float16).reshape(X, Y) + idx = torch.tensor([0, 2, 4, 6, 1, 3, 5, 7], device="cuda", dtype=torch.int32) + out = torch.zeros((X, Y), device="cuda", dtype=torch.float16) + kernel[(1, )](out, inp, idx, X, Y, BLOCK_X, BLOCK_Y) + + expected = torch.zeros_like(out) + for i, row in enumerate(idx.tolist()): + expected[row] = inp[row] + torch.testing.assert_close(out, expected, atol=0, rtol=0) diff --git a/python/triton/_internal_testing.py b/python/triton/_internal_testing.py index 8c52bb81d877..ba32f860e002 100644 --- a/python/triton/_internal_testing.py +++ b/python/triton/_internal_testing.py @@ -102,6 +102,10 @@ def is_hip_gfx1250(): return target is not None and target.backend == 'hip' and 'gfx1250' in target.arch +def is_hip_cdna3_or_newer(): + return is_hip_cdna3() or is_hip_cdna4() + + def is_hip_cdna(): return is_hip_cdna2() or is_hip_cdna3() or is_hip_cdna4() diff --git a/python/triton/tools/triton_to_gluon_translator/amd_helpers.py b/python/triton/tools/triton_to_gluon_translator/amd_helpers.py new file mode 100644 index 000000000000..091051f1f538 --- /dev/null +++ b/python/triton/tools/triton_to_gluon_translator/amd_helpers.py @@ -0,0 +1,410 @@ +# type: ignore + +import math + +import triton.language as tl +from triton.experimental import gluon +from triton.experimental.gluon import language as ttgl +from triton.experimental.gluon.language.amd.gfx1250 import wmma as amd_wmma +from triton.experimental.gluon.language.amd.gfx1250 import tdm as amd_tdm +from triton.experimental.gluon.language.amd.cdna3 import mfma as amd_mfma +from triton.language.target_info import current_target + +from triton.tools.triton_to_gluon_translator.common_helpers import * # noqa: F401,F403 +from triton.tools.triton_to_gluon_translator.common_helpers import ( + default_blocked_layout, + get_num_threads_per_warp, + tl_dot_decomposed_scale_arg, + tl_trans, +) + +# ---- architecture detection ---- + + +@gluon.constexpr_function +def _is_gfx1250(target=None): + return target is not None and target.arch == "gfx1250" + + +@gluon.constexpr_function +def _is_cdna(target=None): + return target is not None and target.arch in ("gfx942", "gfx950") + + +@gluon.constexpr_function +def _cdna_version(target=None): + """Returns 3 for gfx942, 4 for gfx950.""" + return 4 if target is not None and target.arch == "gfx950" else 3 + + +# ---- AMD WMMA layout helpers (gfx1250) ---- + + +@gluon.constexpr_function +def compute_warp_bases(num_warps): + """Distribute warps across M/N: first bit to N, rest to M.""" + n_bits = int(math.log2(num_warps)) + if n_bits == 0: + return [] + warp_bases = [[0, 1]] + for i in range(n_bits - 1): + warp_bases.append([1 << i, 0]) + return warp_bases + + +@gluon.constexpr_function +def get_wmma_layout(shape, num_warps): + warp_bases = compute_warp_bases(num_warps) + return ttgl.amd.AMDWMMALayout(3, True, warp_bases, [], [16, 16, 32]) + + +@gluon.constexpr_function +def get_wmma_k_width(a_ty, b_ty): + min_bitwidth = min(a_ty.element_ty.primitive_bitwidth, b_ty.element_ty.primitive_bitwidth) + return max(128 // min_bitwidth, 1) + + +# ---- AMD MFMA layout helpers (cdna3/cdna4) ---- + + +@gluon.constexpr_function +def get_mfma_instr_k(element_bitwidth, target=None): + """K dimension of the MFMA instruction for [32, 32, K].""" + k_bits = 128 if _cdna_version(target) == 3 else 256 + return k_bits // element_bitwidth + + +@gluon.constexpr_function +def get_mfma_layout(num_warps, element_bitwidth, target=None): + instr_k = get_mfma_instr_k(element_bitwidth, target) + return ttgl.amd.AMDMFMALayout( + version=_cdna_version(target), + instr_shape=[32, 32, instr_k], + transposed=True, + warps_per_cta=[num_warps, 1], + ) + + +@gluon.constexpr_function +def get_mfma_k_width(a_ty, b_ty, target=None): + min_bitwidth = min(a_ty.element_ty.primitive_bitwidth, b_ty.element_ty.primitive_bitwidth) + instr_k = get_mfma_instr_k(min_bitwidth, target) + return instr_k // 2 + + +# ---- AMD dot paths ---- + + +@gluon.jit +def tl_dot_wmma(a, b, acc, out_dtype): + """gfx1250 WMMA path.""" + M: ttgl.constexpr = a.type.shape[0] + N: ttgl.constexpr = b.type.shape[1] + num_warps: ttgl.constexpr = ttgl.num_warps() + + wmma_layout: ttgl.constexpr = get_wmma_layout([M, N], num_warps) + k_width: ttgl.constexpr = get_wmma_k_width(a.type, b.type) + a_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=wmma_layout, k_width=k_width) + b_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=1, parent=wmma_layout, k_width=k_width) + + a = ttgl.convert_layout(a, a_layout) + b = ttgl.convert_layout(b, b_layout) + + if acc is not None: + accumulator = ttgl.convert_layout(acc, wmma_layout) + else: + accumulator = ttgl.zeros([M, N], out_dtype, layout=wmma_layout) + + result = amd_wmma(a, b, accumulator) + + if acc is not None: + ret_layout: ttgl.constexpr = acc.type.layout + else: + ret_layout: ttgl.constexpr = default_blocked_layout(result.type.shape, num_warps) + return ttgl.convert_layout(result, ret_layout) + + +@gluon.jit +def tl_dot_mfma(a, b, acc, out_dtype): + """CDNA3/CDNA4 MFMA path.""" + M: ttgl.constexpr = a.type.shape[0] + N: ttgl.constexpr = b.type.shape[1] + num_warps: ttgl.constexpr = ttgl.num_warps() + min_bitwidth: ttgl.constexpr = min(a.type.element_ty.primitive_bitwidth, b.type.element_ty.primitive_bitwidth) + target: ttgl.constexpr = current_target() + + mfma_layout: ttgl.constexpr = get_mfma_layout(num_warps, min_bitwidth, target) + k_width: ttgl.constexpr = get_mfma_k_width(a.type, b.type, target) + a_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout, k_width=k_width) + b_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout, k_width=k_width) + + a = ttgl.convert_layout(a, a_layout) + b = ttgl.convert_layout(b, b_layout) + + if acc is not None: + accumulator = ttgl.convert_layout(acc, mfma_layout) + else: + accumulator = ttgl.zeros([M, N], out_dtype, layout=mfma_layout) + + result = amd_mfma(a, b, accumulator) + + if acc is not None: + ret_layout: ttgl.constexpr = acc.type.layout + else: + ret_layout: ttgl.constexpr = default_blocked_layout(result.type.shape, num_warps) + return ttgl.convert_layout(result, ret_layout) + + +# ---- AMD dot dispatch ---- + + +@gluon.jit +def tl_dot( + a, + b, + acc=None, + input_precision=None, + allow_tf32=None, + max_num_imprecise_acc=None, + out_dtype=ttgl.float32, +): + target: ttgl.constexpr = current_target() + if _is_gfx1250(target): + return tl_dot_wmma(a, b, acc, out_dtype) + elif _is_cdna(target): + return tl_dot_mfma(a, b, acc, out_dtype) + + +# Defined here (not imported from common) so __globals__ resolves tl_dot to this module's version. +@gluon.jit +def tl_dot_decomposed_block_scales( + lhs, + lhs_scale, + lhs_format, + rhs, + rhs_scale, + rhs_format, + acc=None, + fast_math=False, + lhs_k_pack=True, + rhs_k_pack=True, + out_dtype=ttgl.float32, +): + if lhs_scale is None and rhs_scale is not None: + lhs_trans = tl_trans(lhs) + rhs_trans = tl_trans(rhs) + if acc is not None: + orig_layout: ttgl.constexpr = acc.type.layout + acc = tl_trans(acc) + result = tl_dot_scaled( + rhs_trans, + rhs_scale, + rhs_format, + lhs_trans, + lhs_scale, + lhs_format, + acc, + fast_math, + lhs_k_pack, + rhs_k_pack, + out_dtype, + ) + result = tl_trans(result) + if acc is not None: + result = ttgl.convert_layout(result, orig_layout) + return result + else: + ttgl.static_assert(not (not lhs_k_pack or not rhs_k_pack), "TODO: support m/n packed formats") + compute_type: ttgl.constexpr = (ttgl.float16 if + (lhs_format == "fp16" or rhs_format == "fp16") else ttgl.bfloat16) + + scale_a = tl_dot_decomposed_scale_arg(lhs, lhs_scale, lhs_format, 0, compute_type, fast_math) + scale_b = tl_dot_decomposed_scale_arg(rhs, rhs_scale, rhs_format, 1, compute_type, fast_math) + + return tl_dot(scale_a, scale_b, acc, out_dtype=out_dtype) + + +@gluon.jit +def tl_dot_scaled( + lhs, + lhs_scale, + lhs_format, + rhs, + rhs_scale, + rhs_format, + acc=None, + fast_math=False, + lhs_k_pack=True, + rhs_k_pack=True, + out_dtype=ttgl.float32, +): + return tl_dot_decomposed_block_scales( + lhs, + lhs_scale, + lhs_format, + rhs, + rhs_scale, + rhs_format, + acc, + fast_math, + lhs_k_pack, + rhs_k_pack, + out_dtype, + ) + + +# ---- AMD TDM tensor descriptors (gfx1250 only) ---- + + +@gluon.constexpr_function +def get_default_tdm_layout(*block_shape): + block_shape = list(block_shape) + return ttgl.PaddedSharedLayout.with_identity_for( + [[block_shape[-1], 4]], + block_shape, + list(range(len(block_shape) - 1, -1, -1)), + ) + + +@tl.core._aggregate +class AMDTensorDescriptorArgs: + """Wraps a real TDM descriptor alongside the original base pointer. + + The base_ptr is needed by gather/scatter to recreate the descriptor with a different + block_shape -- Triton uses block_shape=[1, N] but TDM hardware requires [num_indices, N]. + Shape, strides, and block_shape are read from desc (type metadata gives plain Python ints + for block_shape, tuples for shape/strides).""" + desc: amd_tdm.tensor_descriptor + base_ptr: tl.core.tensor + + +@gluon.jit +def tl_make_tensor_descriptor(base, shape, strides, block_shape, padding_option: ttgl.constexpr = "zero"): + ttgl.static_assert(_is_gfx1250(current_target()), "tl_make_tensor_descriptor requires gfx1250 target") + layout: ttgl.constexpr = get_default_tdm_layout(*block_shape) + desc = amd_tdm.make_tensor_descriptor(base, shape, strides, block_shape, layout) + return AMDTensorDescriptorArgs(desc, base) + + +# ---- AMD obj dispatch ---- + + +@gluon.jit +def tl_obj_load_amd(desc, offsets): + smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout) + amd_tdm.async_load(desc, offsets, smem) + amd_tdm.async_wait(0) + ret_layout: ttgl.constexpr = default_blocked_layout(desc.block_shape, ttgl.num_warps()) + return smem.load(ret_layout) + + +@gluon.jit +def tl_obj_store_amd(desc, offsets, value): + smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout, value) + amd_tdm.async_store(desc, offsets, smem) + amd_tdm.async_wait(0) + + +@gluon.jit +def tl_obj_store(obj, offsets, value): + if isinstance(obj, AMDTensorDescriptorArgs): + tl_obj_store_amd(obj.desc, offsets, value) + elif isinstance(obj, amd_tdm.tensor_descriptor): + tl_obj_store_amd(obj, offsets, value) + else: + return obj.store(offsets, value) + + +@gluon.jit +def tl_obj_load(obj, offsets): + if isinstance(obj, AMDTensorDescriptorArgs): + return tl_obj_load_amd(obj.desc, offsets) + elif isinstance(obj, amd_tdm.tensor_descriptor): + return tl_obj_load_amd(obj, offsets) + else: + return obj.load(offsets) + + +@gluon.jit +def tl_obj_gather_amd(desc_args, x_offsets, y_offset): + NUM_IDX: ttgl.constexpr = x_offsets.shape[0] + BLOCK_N: ttgl.constexpr = desc_args.desc.block_shape[1] + smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) + gather_block_shape: ttgl.constexpr = [NUM_IDX, BLOCK_N] + gather_desc = amd_tdm.make_tensor_descriptor(desc_args.base_ptr, desc_args.desc.shape, desc_args.desc.strides, + gather_block_shape, smem_layout) + num_warps: ttgl.constexpr = ttgl.num_warps() + gather_shape: ttgl.constexpr = gather_desc.block_shape + idx_base: ttgl.constexpr = ttgl.BlockedLayout([gather_shape[0], 1], + [1, get_num_threads_per_warp(current_target())], [1, num_warps], + [1, 0]) + idx_layout: ttgl.constexpr = ttgl.SliceLayout(1, idx_base) + x_offsets = ttgl.convert_layout(x_offsets, idx_layout) + alloc = ttgl.allocate_shared_memory(desc_args.desc.dtype, list(gather_shape), smem_layout) + y_off = ttgl.to_tensor(y_offset) + amd_tdm.async_gather(gather_desc, x_offsets, y_off, alloc) + amd_tdm.async_wait(0) + ret_layout: ttgl.constexpr = default_blocked_layout(list(gather_shape), num_warps, current_target()) + out = alloc.load(ret_layout) + return out + + +@gluon.jit +def tl_obj_scatter_amd(desc_args, value, x_offsets, y_offset): + NUM_IDX: ttgl.constexpr = x_offsets.shape[0] + BLOCK_N: ttgl.constexpr = desc_args.desc.block_shape[1] + smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) + scatter_block_shape: ttgl.constexpr = [NUM_IDX, BLOCK_N] + scatter_desc = amd_tdm.make_tensor_descriptor(desc_args.base_ptr, desc_args.desc.shape, desc_args.desc.strides, + scatter_block_shape, smem_layout) + num_warps: ttgl.constexpr = ttgl.num_warps() + scatter_shape: ttgl.constexpr = scatter_desc.block_shape + idx_base: ttgl.constexpr = ttgl.BlockedLayout([scatter_shape[0], 1], + [1, get_num_threads_per_warp(current_target())], [1, num_warps], + [1, 0]) + idx_layout: ttgl.constexpr = ttgl.SliceLayout(1, idx_base) + x_offsets = ttgl.convert_layout(x_offsets, idx_layout) + alloc = ttgl.allocate_shared_memory(desc_args.desc.dtype, list(scatter_shape), smem_layout, value) + y_off = ttgl.to_tensor(y_offset) + amd_tdm.async_scatter(scatter_desc, x_offsets, y_off, alloc) + amd_tdm.async_wait(0) + + +@gluon.jit +def tl_obj_gather(obj, x_offsets, y_offset): + if isinstance(obj, AMDTensorDescriptorArgs): + return tl_obj_gather_amd(obj, x_offsets, y_offset) + else: + return obj.gather(x_offsets, y_offset) + + +@gluon.jit +def tl_obj_scatter(obj, value, x_offsets, y_offset): + if isinstance(obj, AMDTensorDescriptorArgs): + tl_obj_scatter_amd(obj, value, x_offsets, y_offset) + else: + obj.scatter(value, x_offsets, y_offset) + + +# ---- AMD host-side descriptor ---- + + +def convert_host_descriptor(desc): + + def torch_dtype_to_triton(dtype): + import torch + + if dtype == torch.float8_e5m2: + return ttgl.float8e5 + if dtype == torch.float8_e4m3fn: + return ttgl.float8e4nv + return getattr(ttgl, str(dtype).split(".")[1]) + + from triton.tools.tensor_descriptor import TensorDescriptor + + assert isinstance(desc, TensorDescriptor) + block_shape = desc.block_shape + tensor = desc.base + + layout = get_default_tdm_layout(*block_shape) + return gluon.amd.gfx1250.TensorDescriptor(tensor, list(desc.shape), list(desc.strides), block_shape, layout) diff --git a/python/triton/tools/triton_to_gluon_translator/common_helpers.py b/python/triton/tools/triton_to_gluon_translator/common_helpers.py new file mode 100644 index 000000000000..7442a9287332 --- /dev/null +++ b/python/triton/tools/triton_to_gluon_translator/common_helpers.py @@ -0,0 +1,273 @@ +# type: ignore + +from triton.experimental import gluon +from triton.experimental.gluon import language as ttgl + +# hack to workaround limited dependencies tracking. +# TODO: fix this by pulling imports into the generated file. +from triton.language.target_info import current_target # noqa: F401 + +# ---- layout utilities ---- + + +@gluon.constexpr_function +def get_num_threads_per_warp(target=None) -> ttgl.constexpr: + if target is None: + target = current_target() + if target is not None and target.backend == "hip": + gfx_major = int(target.arch[3:-2]) + return ttgl.constexpr(32 if gfx_major >= 10 else 64) + return ttgl.constexpr(32) + + +@gluon.jit +def get_num_threads_per_program(): + return ttgl.num_warps() * get_num_threads_per_warp(current_target()) + + +@gluon.constexpr_function +def default_blocked_layout(shape: ttgl.constexpr, num_warps: ttgl.constexpr, target=None) -> ttgl.constexpr: + rank = len(shape) + size_per_thread = [1] * rank + threads_per_warp = [1] * rank + # TODO: pick a better layout based on shape. Using this allows to not have to convert layout when broadcasting but may blow up register pressure. + threads_per_warp[rank - 1] = get_num_threads_per_warp(target) + warps_per_cta = [1] * rank + warps_per_cta[0] = num_warps + order = list(range(rank - 1, -1, -1)) + return ttgl.BlockedLayout( + size_per_thread=size_per_thread, + threads_per_warp=threads_per_warp, + warps_per_cta=warps_per_cta, + order=order, + ) + + +@gluon.constexpr_function +def get_swizzle_byte_width(bitwidth): + swizzle = min(bitwidth, 128) + swizzle = 0 if swizzle < 32 else swizzle + return swizzle + + +@gluon.constexpr_function +def get_int_type(bitwidth): + if bitwidth == 64: + return ttgl.int64 + elif bitwidth == 32: + return ttgl.int32 + elif bitwidth == 16: + return ttgl.int16 + elif bitwidth == 8: + return ttgl.int8 + else: + assert False, f"Unsupported bitwidth: {bitwidth}" + + +# ---- portable ops ---- + + +@gluon.jit +def tl_arange(start: ttgl.constexpr, stop: ttgl.constexpr = None): + layout: ttgl.constexpr = default_blocked_layout([stop - start], ttgl.num_warps()) + return ttgl.arange(start, stop, layout=layout) + + +@gluon.jit +def tl_full(shape, value, dtype=None): + layout: ttgl.constexpr = default_blocked_layout(shape, ttgl.num_warps()) + return ttgl.full(shape, value, dtype, layout=layout) + + +@gluon.jit +def tl_trans(value, *dims): + return value.trans(*dims) + + +@gluon.constexpr_function +def cat_permute_order(rank, dim): + order = list(range(rank)) + order.insert(dim, rank) + return order + + +@gluon.constexpr_function +def cat_result_shape(input_shape, dim): + result_shape = list(input_shape) + result_shape[dim] *= 2 + return result_shape + + +@gluon.jit +def tl_cat(input, other, can_reorder=False, dim=0): + c = ttgl.join(input, other) + order: ttgl.constexpr = cat_permute_order(len(input.shape), dim) + c = ttgl.permute(c, order) + shape: ttgl.constexpr = cat_result_shape(input.shape, dim) + c = ttgl.reshape(c, shape) + return reset_to_default_layout(c) + + +@gluon.jit +def reset_to_default_layout(value): + ty: ttgl.constexpr = value.type + if isinstance(ty, ttgl.tuple_type): + out = () + for i in ttgl.static_range(len(value)): + r = ttgl.convert_layout(value[i], layout=default_blocked_layout(value[i].type.shape, ttgl.num_warps())) + out = out + (r, ) + return out + elif isinstance(value, ttgl.tensor) and isinstance(value.type, ttgl.distributed_type): + layout: ttgl.constexpr = default_blocked_layout(ty.shape, ttgl.num_warps()) + return ttgl.convert_layout(value, layout=layout) + else: + return value + + +@gluon.constexpr_function +def get_split_src_layout(shape: ttgl.constexpr, num_warps: ttgl.constexpr, target=None) -> ttgl.constexpr: + rank = len(shape) + size_per_thread = [1 if i != rank - 1 else 2 for i in range(rank)] + threads_per_warp = [1 for _ in range(rank)] + remaining_threads = get_num_threads_per_warp(target) + for dim in range(rank - 2, -1, -1): + threads_per_warp[dim] = min(shape[dim], remaining_threads) + remaining_threads = remaining_threads // threads_per_warp[dim] + warps_per_cta = [1 for _ in range(rank)] + warps_per_cta[0] = num_warps + order = list(range(rank - 1, -1, -1)) + return ttgl.BlockedLayout( + size_per_thread=size_per_thread, + threads_per_warp=threads_per_warp, + warps_per_cta=warps_per_cta, + order=order, + ) + + +@gluon.jit +def set_split_src_layout(value): + layout: ttgl.constexpr = get_split_src_layout(value.type.shape, ttgl.num_warps()) + return ttgl.convert_layout(value, layout=layout) + + +@gluon.constexpr_function +def build_expand_dims_layout(shape, expand_dims, num_warps): + if isinstance(shape, ttgl.tuple): + shape = shape.values + assert isinstance(shape, list), (f"expected shape to be a list, got {shape} which is {type(shape)}") + parent_shape = list(shape) + for dim in expand_dims: + parent_shape.insert(dim, 1) + layout = default_blocked_layout(parent_shape, num_warps) + for dim in reversed(expand_dims): + layout = ttgl.SliceLayout(dim=dim, parent=layout) + return layout + + +@gluon.jit +def convert_to_expand_dims_layout(value, expand_dims: list[int]): + layout: ttgl.constexpr = build_expand_dims_layout(value.shape, expand_dims, ttgl.num_warps()) + return ttgl.convert_layout(value, layout) + + +# ---- dot-scaled sub-helpers (vendor-neutral) ---- + + +@gluon.jit +def tl_dot_decomposed_scale_to_16(scale, compute_type): + large_fp_type: ttgl.constexpr = ttgl.float32 if compute_type == ttgl.float16 else compute_type + int_width: ttgl.constexpr = large_fp_type.primitive_bitwidth + int_type: ttgl.constexpr = get_int_type(int_width) + + zexted = ttgl.cast(scale, int_type) + shift_value: ttgl.constexpr = large_fp_type.fp_mantissa_width + shl_res = zexted << shift_value + scale_fp = ttgl.cast(shl_res, large_fp_type, bitcast=True) + if large_fp_type != compute_type: + scale_fp = ttgl.cast(scale_fp, compute_type) + return scale_fp + + +@gluon.constexpr_function +def tl_dot_get_expand_dims_layout(scale_ty, num_warps, rank): + shape = scale_ty.shape.values + [1] + blocked = default_blocked_layout(shape, num_warps) + slice = ttgl.SliceLayout(rank, blocked) + return slice + + +@gluon.constexpr_function +def tl_dot_get_permute_order(rank, dim): + order = list(range(rank)) + order.insert(dim + 1, rank) + return order + + +@gluon.constexpr_function +def tl_dot_get_reshape_shape(scale_ty, dim): + shape = list(scale_ty.shape.values) + shape.pop() + shape[dim] *= 32 + return shape + + +@gluon.jit +def tl_dot_decomposed_broadcast_scale(scale, dim): + scale_ty: ttgl.constexpr = scale.type + rank: ttgl.constexpr = len(scale_ty.shape) + + num_warps: ttgl.constexpr = ttgl.num_warps() + slice_enc: ttgl.constexpr = tl_dot_get_expand_dims_layout(scale_ty, num_warps, rank) + scale = ttgl.convert_layout(scale, slice_enc) + expand_scale = scale.expand_dims(rank) + broadcast_scale = expand_scale.broadcast_to(scale.type.shape + (32, )) + permute_order: ttgl.constexpr = tl_dot_get_permute_order(rank, dim) + transposed_scale = broadcast_scale.permute(permute_order) + reshape_shape: ttgl.constexpr = tl_dot_get_reshape_shape(broadcast_scale.type, dim) + return transposed_scale.reshape(reshape_shape) + + +@gluon.constexpr_function +def tl_dot_decomposed_get_transposed_order(rank): + assert rank >= 2 + order = list(range(rank - 2)) + order += [rank - 1, rank - 2] + return order + + +@gluon.jit +def tl_dot_decomposed_extend_and_broadcast_scale(v, scale, compute_type, operand_index): + rank: ttgl.constexpr = len(v.type.shape) + k_dim: ttgl.constexpr = rank - 1 if operand_index == 0 else rank - 2 + + if operand_index == 1: + order: ttgl.constexpr = tl_dot_decomposed_get_transposed_order(rank) + scale = ttgl.permute(scale, order) + + scale16 = tl_dot_decomposed_scale_to_16(scale, compute_type) + reshape_scale = tl_dot_decomposed_broadcast_scale(scale16, k_dim) + return ttgl.convert_layout(reshape_scale, v.type.layout), scale + + +@gluon.jit +def tl_dot_decomposed_mask_nan(mxfp, scale, fast_math): + ttgl.static_assert(fast_math, "TODO: support non-fast-math") + return mxfp + + +@gluon.jit +def tl_dot_decomposed_scale_arg(v, scale, arg_format, operand_index, compute_type, fast_math): + is_fp4: ttgl.constexpr = arg_format == "e2m1" + rank: ttgl.constexpr = len(v.type.shape) + k_dim: ttgl.constexpr = rank - 1 if operand_index == 0 else rank - 2 + + if is_fp4: + v = ttgl.fp4_to_fp(v, compute_type, k_dim) + else: + v = ttgl.cast(v, compute_type) + if scale is None: + return v + else: + reshape_scale, scale = tl_dot_decomposed_extend_and_broadcast_scale(v, scale, compute_type, operand_index) + mxfp = ttgl.mul(v, reshape_scale) + return tl_dot_decomposed_mask_nan(mxfp, scale, fast_math) diff --git a/python/triton/tools/triton_to_gluon_translator/inline_helpers.py b/python/triton/tools/triton_to_gluon_translator/inline_helpers.py index 93f91a8ecfef..235add27f869 100644 --- a/python/triton/tools/triton_to_gluon_translator/inline_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/inline_helpers.py @@ -1,25 +1,42 @@ +_torch_dtype_to_triton_def = R""" +def _torch_dtype_to_triton(dtype): + import torch + + if dtype == torch.float8_e5m2: + return gl.float8e5 + if dtype == torch.float8_e4m3fn: + return gl.float8e4nv + return getattr(gl, str(dtype).split(".")[1]) +""" + defs: dict[str, str] = { "convert_host_descriptor": - R""" + _torch_dtype_to_triton_def + R""" def convert_host_descriptor(desc): - def torch_dtype_to_triton(dtype): - import torch - - if dtype == torch.float8_e5m2: - return gl.float8e5 - if dtype == torch.float8_e4m3fn: - return gl.float8e4nv - return getattr(gl, str(dtype).split(".")[1]) - from triton.tools.tensor_descriptor import TensorDescriptor assert isinstance(desc, TensorDescriptor) block_shape = desc.block_shape dtype = desc.base.dtype tensor = desc.base - layout = gl.NVMMASharedLayout.get_default_for(block_shape, torch_dtype_to_triton(dtype)) + layout = gl.NVMMASharedLayout.get_default_for(block_shape, _torch_dtype_to_triton(dtype)) return gluon.nvidia.hopper.TensorDescriptor( tensor, desc.shape, desc.strides, block_shape, layout ) -""" +""", + "convert_host_descriptor_amd": + _torch_dtype_to_triton_def + R""" +def convert_host_descriptor(desc): + from triton.tools.tensor_descriptor import TensorDescriptor + + assert isinstance(desc, TensorDescriptor) + block_shape = desc.block_shape + dtype = desc.base.dtype + layout = gl.PaddedSharedLayout.with_identity_for( + [[block_shape[-1], 4]], list(block_shape), [1, 0] + ) + return gluon.amd.gfx1250.TensorDescriptor( + desc.base, list(desc.shape), list(desc.strides), block_shape, layout + ) +""", } diff --git a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py b/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py similarity index 55% rename from python/triton/tools/triton_to_gluon_translator/translator_helpers.py rename to python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py index bdf94145f654..8791e12df1ab 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py @@ -1,7 +1,5 @@ # type: ignore -from typing import Any - from triton.experimental import gluon from triton.experimental.gluon import language as ttgl from triton.experimental.gluon.language.nvidia.ampere import mma_v2 @@ -16,9 +14,15 @@ from triton.experimental.gluon.language.nvidia.blackwell import tma as tma_blackwell from triton.experimental.gluon.language.nvidia.hopper import fence_async_shared, mbarrier, tma -# hack to workaround limited dependencies tracking. -# TODO: fix this by pulling imports into the generated file. -from triton.language.target_info import current_target # noqa: F401 +from triton.tools.triton_to_gluon_translator.common_helpers import * # noqa: F401,F403 +from triton.tools.triton_to_gluon_translator.common_helpers import ( + default_blocked_layout, + get_num_threads_per_warp, + tl_dot_decomposed_scale_arg, + tl_trans, +) + +# ---- NVIDIA MMA sync (Ampere) ---- @gluon.constexpr_function @@ -62,6 +66,9 @@ def tl_dot_mma_sync(a, b, acc_init=None, input_precision=None, out_dtype=ttgl.fl return result +# ---- NVIDIA Blackwell dot ---- + + @gluon.constexpr_function def tl_dot_mmav5_supported(a_ty, b_ty, num_warps, input_precision, allow_tf32, max_num_imprecise_acc): assert max_num_imprecise_acc is None, ("max_num_imprecise_acc only applies to Hopper warp_group_dot") @@ -148,7 +155,6 @@ def tl_dot_blackwell( a_smem = get_shared_memory_mma_operand(a, 0, allow_transpose) b_smem = get_shared_memory_mma_operand(b, 1, allow_transpose) - # MMA instruction shape m: ttgl.constexpr = 128 if M >= 128 else 64 n: ttgl.constexpr = 256 if N >= 256 else N @@ -170,13 +176,15 @@ def tl_dot_blackwell( mbarrier.wait(bar, phase=0) mbarrier.invalidate(bar) - # Load back from TMEM using a register layout and convert to acc layout out = acc_tmem.load() ret_layout: ttgl.constexpr = default_blocked_layout([M, N], ttgl.num_warps()) out = ttgl.convert_layout(out, ret_layout) return out +# ---- NVIDIA dot dispatch ---- + + @gluon.jit def tl_dot( a, @@ -194,6 +202,9 @@ def tl_dot( return tl_dot_mma_sync(a, b, acc, input_precision, out_dtype) +# ---- NVIDIA dot-scaled ---- + + @gluon.constexpr_function def tl_dot_scaled_mmav5_supported(a_ty, b_ty, num_warps): M = a_ty.shape[0] @@ -204,220 +215,6 @@ def tl_dot_scaled_mmav5_supported(a_ty, b_ty, num_warps): and N >= 16) -@gluon.constexpr_function -def get_swizzle_byte_width(bitwidth): - swizzle = min(bitwidth, 128) - swizzle = 0 if swizzle < 32 else swizzle - return swizzle - - -@gluon.constexpr_function -def get_int_type(bitwidth): - if bitwidth == 64: - return ttgl.int64 - elif bitwidth == 32: - return ttgl.int32 - elif bitwidth == 16: - return ttgl.int16 - elif bitwidth == 8: - return ttgl.int8 - else: - assert False, f"Unsupported bitwidth: {bitwidth}" - - -@gluon.jit -def tl_dot_decomposed_scale_to_16(scale, compute_type): - large_fp_type: ttgl.constexpr = ttgl.float32 if compute_type == ttgl.float16 else compute_type - int_width: ttgl.constexpr = large_fp_type.primitive_bitwidth - int_type: ttgl.constexpr = get_int_type(int_width) - - zexted = ttgl.cast(scale, int_type) - shift_value: ttgl.constexpr = large_fp_type.fp_mantissa_width - shl_res = zexted << shift_value - scale_fp = ttgl.cast(shl_res, large_fp_type, bitcast=True) - if large_fp_type != compute_type: - scale_fp = ttgl.cast(scale_fp, compute_type) - return scale_fp - - -@gluon.constexpr_function -def tl_dot_get_expand_dims_layout(scale_ty, num_warps, rank): - shape = scale_ty.shape.values + [1] - blocked = default_blocked_layout(shape, num_warps) - slice = ttgl.SliceLayout(rank, blocked) - return slice - - -@gluon.constexpr_function -def tl_dot_get_permute_order(rank, dim): - order = list(range(rank)) - order.insert(dim + 1, rank) - return order - - -@gluon.constexpr_function -def tl_dot_get_reshape_shape(scale_ty, dim): - shape = list(scale_ty.shape.values) - shape.pop() - shape[dim] *= 32 - return shape - - -@gluon.jit -def tl_dot_decomposed_broadcast_scale(scale, dim): - scale_ty: ttgl.constexpr = scale.type - rank: ttgl.constexpr = len(scale_ty.shape) - - num_warps: ttgl.constexpr = ttgl.num_warps() - slice_enc: ttgl.constexpr = tl_dot_get_expand_dims_layout(scale_ty, num_warps, rank) - scale = ttgl.convert_layout(scale, slice_enc) - expand_scale = scale.expand_dims(rank) - broadcast_scale = expand_scale.broadcast_to(scale.type.shape + (32, )) - permute_order: ttgl.constexpr = tl_dot_get_permute_order(rank, dim) - transposed_scale = broadcast_scale.permute(permute_order) - reshape_shape: ttgl.constexpr = tl_dot_get_reshape_shape(broadcast_scale.type, dim) - return transposed_scale.reshape(reshape_shape) - - -@gluon.constexpr_function -def tl_dot_decomposed_get_transposed_order(rank): - assert rank >= 2 - order = list(range(rank - 2)) - order += [rank - 1, rank - 2] - return order - - -@gluon.jit -def tl_dot_decomposed_extend_and_broadcast_scale(v, scale, compute_type, operand_index): - rank: ttgl.constexpr = len(v.type.shape) - k_dim: ttgl.constexpr = rank - 1 if operand_index == 0 else rank - 2 - - if operand_index == 1: - order: ttgl.constexpr = tl_dot_decomposed_get_transposed_order(rank) - scale = ttgl.permute(scale, order) - - scale16 = tl_dot_decomposed_scale_to_16(scale, compute_type) - reshape_scale = tl_dot_decomposed_broadcast_scale(scale16, k_dim) - return ttgl.convert_layout(reshape_scale, v.type.layout), scale - - -@gluon.jit -def tl_dot_decomposed_mask_nan(mxfp, scale, fast_math): - ttgl.static_assert(fast_math, "TODO: support non-fast-math") - return mxfp - - -@gluon.jit -def tl_dot_decomposed_scale_arg(v, scale, arg_format, operand_index, compute_type, fast_math): - is_fp4: ttgl.constexpr = arg_format == "e2m1" - rank: ttgl.constexpr = len(v.type.shape) - k_dim: ttgl.constexpr = rank - 1 if operand_index == 0 else rank - 2 - - if is_fp4: - v = ttgl.fp4_to_fp(v, compute_type, k_dim) - else: - v = ttgl.cast(v, compute_type) - if scale is None: - return v - else: - reshape_scale, scale = tl_dot_decomposed_extend_and_broadcast_scale(v, scale, compute_type, operand_index) - mxfp = ttgl.mul(v, reshape_scale) - return tl_dot_decomposed_mask_nan(mxfp, scale, fast_math) - - -@gluon.jit -def tl_dot_scaled( - lhs, - lhs_scale, - lhs_format, - rhs, - rhs_scale, - rhs_format, - acc=None, - fast_math=False, - lhs_k_pack=True, - rhs_k_pack=True, - out_dtype=ttgl.float32, -): - if (tl_dot_scaled_mmav5_supported(lhs.type, rhs.type, ttgl.num_warps()) and lhs_scale is not None - and rhs_scale is not None): - return tl_dot_scaled_blackwell( - lhs, - lhs_scale, - lhs_format, - rhs, - rhs_scale, - rhs_format, - acc, - fast_math, - lhs_k_pack, - rhs_k_pack, - out_dtype, - ) - else: - return tl_dot_decomposed_block_scales( - lhs, - lhs_scale, - lhs_format, - rhs, - rhs_scale, - rhs_format, - acc, - fast_math, - lhs_k_pack, - rhs_k_pack, - out_dtype, - ) - - -@gluon.jit -def tl_dot_decomposed_block_scales( - lhs, - lhs_scale, - lhs_format, - rhs, - rhs_scale, - rhs_format, - acc=None, - fast_math=False, - lhs_k_pack=True, - rhs_k_pack=True, - out_dtype=ttgl.float32, -): - if lhs_scale is None and rhs_scale is not None: - lhs_trans = tl_trans(lhs) - rhs_trans = tl_trans(rhs) - if acc is not None: - orig_layout: ttgl.constexpr = acc.type.layout - acc = tl_trans(acc) - result = tl_dot_scaled( - rhs_trans, - rhs_scale, - rhs_format, - lhs_trans, - lhs_scale, - lhs_format, - acc, - fast_math, - lhs_k_pack, - rhs_k_pack, - out_dtype, - ) - result = tl_trans(result) - if acc is not None: - result = ttgl.convert_layout(result, orig_layout) - return result - else: - ttgl.static_assert(not (not lhs_k_pack or not rhs_k_pack), "TODO: support m/n packed formats") - compute_type: ttgl.constexpr = (ttgl.float16 if - (lhs_format == "fp16" or rhs_format == "fp16") else ttgl.bfloat16) - - scale_a = tl_dot_decomposed_scale_arg(lhs, lhs_scale, lhs_format, 0, compute_type, fast_math) - scale_b = tl_dot_decomposed_scale_arg(rhs, rhs_scale, rhs_format, 1, compute_type, fast_math) - - return tl_dot(scale_a, scale_b, acc, out_dtype=out_dtype) - - @gluon.jit def tl_dot_scaled_blackwell( lhs, @@ -491,106 +288,107 @@ def tl_dot_scaled_blackwell( tcgen05_commit(bar) mbarrier.wait(bar, phase=0) mbarrier.invalidate(bar) - # Load back from TMEM using a register layout and convert to acc layout out = acc_tmem.load() ret_layout: ttgl.constexpr = default_blocked_layout([M, N], ttgl.num_warps()) out = ttgl.convert_layout(out, ret_layout) return out -@gluon.constexpr_function -def get_num_threads_per_warp() -> ttgl.constexpr: - return ttgl.constexpr(32) - - -@gluon.jit -def get_num_threads_per_program(): - return ttgl.num_warps() * get_num_threads_per_warp() - - -@gluon.constexpr_function -def default_blocked_layout(shape: ttgl.constexpr, num_warps: ttgl.constexpr) -> ttgl.constexpr: - rank = len(shape) - # 1 element per thread for all dimensions - size_per_thread = [1] * rank - # Distribute 32 threads per warp across dimensions (simple heuristic: last-fastest) - threads_per_warp = [1] * rank - # TODO: pick a better layout based on shape. Using this allows to not have to convert layout when broadcasting but may blow up register pressure. - threads_per_warp[rank - 1] = get_num_threads_per_warp() - # remaining_threads = get_num_threads_per_warp() - # for dim in range(rank - 1, -1, -1): - # threads_per_warp[dim] = min(remaining_threads, shape[dim]) - # remaining_threads = remaining_threads // threads_per_warp[dim] - # Use provided num_warps to distribute warps per CTA (put all on first dim) - warps_per_cta = [1] * rank - warps_per_cta[0] = num_warps - # Natural order [rank-1, rank-2, ..., 0] - order = list(range(rank - 1, -1, -1)) - return ttgl.BlockedLayout( - size_per_thread=size_per_thread, - threads_per_warp=threads_per_warp, - warps_per_cta=warps_per_cta, - order=order, - ) - - +# Defined here (not imported from common) so __globals__ resolves tl_dot to this module's version. @gluon.jit -def tl_obj_store(obj, offsets, value): - if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): - return tl_store_tensor_descriptor(obj, offsets, value) +def tl_dot_decomposed_block_scales( + lhs, + lhs_scale, + lhs_format, + rhs, + rhs_scale, + rhs_format, + acc=None, + fast_math=False, + lhs_k_pack=True, + rhs_k_pack=True, + out_dtype=ttgl.float32, +): + if lhs_scale is None and rhs_scale is not None: + lhs_trans = tl_trans(lhs) + rhs_trans = tl_trans(rhs) + if acc is not None: + orig_layout: ttgl.constexpr = acc.type.layout + acc = tl_trans(acc) + result = tl_dot_scaled( + rhs_trans, + rhs_scale, + rhs_format, + lhs_trans, + lhs_scale, + lhs_format, + acc, + fast_math, + lhs_k_pack, + rhs_k_pack, + out_dtype, + ) + result = tl_trans(result) + if acc is not None: + result = ttgl.convert_layout(result, orig_layout) + return result else: - return obj.store(offsets, value) + ttgl.static_assert(not (not lhs_k_pack or not rhs_k_pack), "TODO: support m/n packed formats") + compute_type: ttgl.constexpr = (ttgl.float16 if + (lhs_format == "fp16" or rhs_format == "fp16") else ttgl.bfloat16) + scale_a = tl_dot_decomposed_scale_arg(lhs, lhs_scale, lhs_format, 0, compute_type, fast_math) + scale_b = tl_dot_decomposed_scale_arg(rhs, rhs_scale, rhs_format, 1, compute_type, fast_math) -@gluon.jit -def tl_obj_load(obj, offsets): - if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): - return tl_load_tensor_descriptor(obj, offsets) - else: - return obj.load(offsets) + return tl_dot(scale_a, scale_b, acc, out_dtype=out_dtype) @gluon.jit -def tl_obj_gather(obj, x_offsets, y_offset): - if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): - desc = obj - desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] - alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout) - bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) - mbarrier.init(bar, count=1) - x_offsets_layout: ttgl.constexpr = ttgl.SliceLayout( - 0, - ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]), +def tl_dot_scaled( + lhs, + lhs_scale, + lhs_format, + rhs, + rhs_scale, + rhs_format, + acc=None, + fast_math=False, + lhs_k_pack=True, + rhs_k_pack=True, + out_dtype=ttgl.float32, +): + if (tl_dot_scaled_mmav5_supported(lhs.type, rhs.type, ttgl.num_warps()) and lhs_scale is not None + and rhs_scale is not None): + return tl_dot_scaled_blackwell( + lhs, + lhs_scale, + lhs_format, + rhs, + rhs_scale, + rhs_format, + acc, + fast_math, + lhs_k_pack, + rhs_k_pack, + out_dtype, ) - x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout) - mbarrier.expect(bar, x_offsets.shape[0] * obj.block_type.nbytes) - tma_blackwell.async_gather(desc, x_offsets, y_offset, bar, alloc) - mbarrier.wait(bar, phase=0) - mbarrier.invalidate(bar) - # Load from shared memory into a register tensor using a reasonable default layout - ret_layout: ttgl.constexpr = default_blocked_layout(desc.block_shape, ttgl.num_warps()) - out = alloc.load(ret_layout) - return out else: - return obj.gather(x_offsets, y_offset) + return tl_dot_decomposed_block_scales( + lhs, + lhs_scale, + lhs_format, + rhs, + rhs_scale, + rhs_format, + acc, + fast_math, + lhs_k_pack, + rhs_k_pack, + out_dtype, + ) -@gluon.jit -def tl_obj_scatter(obj, value, x_offsets, y_offset): - if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): - desc = obj - desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] - alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout, value) - fence_async_shared() - x_offsets_layout: ttgl.constexpr = ttgl.SliceLayout( - 0, - ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]), - ) - x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout) - tma_blackwell.async_scatter(desc, x_offsets, y_offset, alloc) - tma.store_wait(0) - else: - obj.scatter(value, x_offsets, y_offset) +# ---- NVIDIA TMA tensor descriptors ---- @gluon.jit @@ -613,102 +411,65 @@ def tl_load_tensor_descriptor(desc, offsets): smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout) bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) mbarrier.init(bar, count=1) - # Issue async copy from global (descriptor) to shared memory and wait for completion mbarrier.expect(bar, desc.block_type.nbytes) tma.async_copy_global_to_shared(desc, offsets, bar, smem) mbarrier.wait(bar, phase=0) mbarrier.invalidate(bar) - # Load from shared memory into a register tensor using a reasonable default layout ret_layout: ttgl.constexpr = default_blocked_layout(desc.block_shape, ttgl.num_warps()) out = smem.load(ret_layout) return out -@gluon.jit -def tl_arange(start: ttgl.constexpr, stop: ttgl.constexpr = None): - layout: ttgl.constexpr = default_blocked_layout([stop - start], ttgl.num_warps()) - return ttgl.arange(start, stop, layout=layout) +# ---- NVIDIA obj dispatch ---- @gluon.jit -def tl_full(shape, value, dtype=None): - layout: ttgl.constexpr = default_blocked_layout(shape, ttgl.num_warps()) - return ttgl.full(shape, value, dtype, layout=layout) +def tl_obj_store(obj, offsets, value): + tl_store_tensor_descriptor(obj, offsets, value) @gluon.jit -def tl_trans(value, *dims): - return value.trans(*dims) - - -@gluon.constexpr_function -def cat_permute_order(rank, dim): - order = list(range(rank)) - order.insert(dim, rank) - return order - - -@gluon.constexpr_function -def cat_result_shape(input_shape, dim): - result_shape = list(input_shape) - result_shape[dim] *= 2 - return result_shape +def tl_obj_load(obj, offsets): + return tl_load_tensor_descriptor(obj, offsets) @gluon.jit -def tl_cat(input, other, can_reorder=False, dim=0): - # Join introduces a new minor dim; move it before the concat dim and merge. - c = ttgl.join(input, other) - order: ttgl.constexpr = cat_permute_order(len(input.shape), dim) - c = ttgl.permute(c, order) - shape: ttgl.constexpr = cat_result_shape(input.shape, dim) - c = ttgl.reshape(c, shape) - return reset_to_default_layout(c) +def tl_obj_gather(obj, x_offsets, y_offset): + desc = obj + desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] + alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout) + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + x_offsets_layout: ttgl.constexpr = ttgl.SliceLayout( + 0, + ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]), + ) + x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout) + mbarrier.expect(bar, x_offsets.shape[0] * obj.block_type.nbytes) + tma_blackwell.async_gather(desc, x_offsets, y_offset, bar, alloc) + mbarrier.wait(bar, phase=0) + mbarrier.invalidate(bar) + ret_layout: ttgl.constexpr = default_blocked_layout(desc.block_shape, ttgl.num_warps()) + out = alloc.load(ret_layout) + return out @gluon.jit -def reset_to_default_layout(value): - ty: ttgl.constexpr = value.type - if isinstance(ty, ttgl.tuple_type): - out = () - for i in ttgl.static_range(len(value)): - r = ttgl.convert_layout(value[i], layout=default_blocked_layout(value[i].type.shape, ttgl.num_warps())) - out = out + (r, ) - return out - elif isinstance(value, ttgl.tensor) and isinstance(value.type, ttgl.distributed_type): - layout: ttgl.constexpr = default_blocked_layout(ty.shape, ttgl.num_warps()) - return ttgl.convert_layout(value, layout=layout) - else: - return value - - -@gluon.constexpr_function -def get_split_src_layout(shape: ttgl.constexpr, num_warps: ttgl.constexpr) -> ttgl.constexpr: - rank = len(shape) - size_per_thread = [1 if i != rank - 1 else 2 for i in range(rank)] - # Distribute 32 threads per warp across dimensions (simple heuristic: last-fastest) - threads_per_warp = [1 for _ in range(rank)] - remaining_threads = get_num_threads_per_warp() - for dim in range(rank - 2, -1, -1): - threads_per_warp[dim] = min(shape[dim], remaining_threads) - remaining_threads = remaining_threads // threads_per_warp[dim] - # Use provided num_warps to distribute warps per CTA (put all on first dim) - warps_per_cta = [1 for _ in range(rank)] - warps_per_cta[0] = num_warps - # Natural order [rank-1, rank-2, ..., 0] - order = list(range(rank - 1, -1, -1)) - return ttgl.BlockedLayout( - size_per_thread=size_per_thread, - threads_per_warp=threads_per_warp, - warps_per_cta=warps_per_cta, - order=order, +def tl_obj_scatter(obj, value, x_offsets, y_offset): + desc = obj + desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] + alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout, value) + fence_async_shared() + x_offsets_layout: ttgl.constexpr = ttgl.SliceLayout( + 0, + ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]), ) + x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout) + tma_blackwell.async_scatter(desc, x_offsets, y_offset, alloc) + tma.store_wait(0) -@gluon.jit -def set_split_src_layout(value): - layout: ttgl.constexpr = get_split_src_layout(value.type.shape, ttgl.num_warps()) - return ttgl.convert_layout(value, layout=layout) +# ---- NVIDIA host-side descriptor ---- def convert_host_descriptor(desc): @@ -728,25 +489,6 @@ def torch_dtype_to_triton(dtype): block_shape = desc.block_shape dtype = desc.base.dtype tensor = desc.base + layout = ttgl.NVMMASharedLayout.get_default_for(block_shape, torch_dtype_to_triton(dtype)) return gluon.nvidia.hopper.TensorDescriptor(tensor, desc.shape, desc.strides, block_shape, layout) - - -@gluon.constexpr_function -def build_expand_dims_layout(shape, expand_dims, num_warps): - if isinstance(shape, ttgl.tuple): - shape = shape.values - assert isinstance(shape, list), (f"expected shape to be a list, got {shape} which is {type(shape)}") - parent_shape = list(shape) - for dim in expand_dims: - parent_shape.insert(dim, 1) - layout = default_blocked_layout(parent_shape, num_warps) - for dim in reversed(expand_dims): - layout = ttgl.SliceLayout(dim=dim, parent=layout) - return layout - - -@gluon.jit -def convert_to_expand_dims_layout(value, expand_dims: list[int]) -> Any: - layout: ttgl.constexpr = build_expand_dims_layout(value.shape, expand_dims, ttgl.num_warps()) - return ttgl.convert_layout(value, layout) diff --git a/python/triton/tools/triton_to_gluon_translator/slice_kernel.py b/python/triton/tools/triton_to_gluon_translator/slice_kernel.py index 81af3a89fe3d..cd40f4921047 100644 --- a/python/triton/tools/triton_to_gluon_translator/slice_kernel.py +++ b/python/triton/tools/triton_to_gluon_translator/slice_kernel.py @@ -26,6 +26,7 @@ from triton.tools.triton_to_gluon_translator.ordered_set import ordered_set from triton.tools.triton_to_gluon_translator.scoped_dict import scoped_dict from triton.tools.triton_to_gluon_translator.stable_toposort import stable_toposort +from triton.tools.triton_to_gluon_translator.target import TranslatorTarget logger = logging.getLogger(__name__) @@ -521,6 +522,7 @@ class SliceRewriter(ReferenceRewriter): translate_to_gluon: bool = False inline_helpers: ordered_set[str] = field(default_factory=ordered_set[str]) cvt_context: list[bool] = field(default_factory=lambda: [False]) + target: TranslatorTarget = TranslatorTarget.NVIDIA def __post_init__(self) -> None: # Special rules for sugaring imports. @@ -562,7 +564,7 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.AST: self.imports.add("import triton.experimental.gluon._runtime as gluon_runtime") new_node = parse_expr("gluon_runtime.GluonJITFunction") elif value is tl.tensor_descriptor: - self.imports.add("from triton.experimental.gluon.language.nvidia.hopper.tma import tensor_descriptor") + self.imports.add(self.target.tensor_descriptor_import) new_node = ast.Name(id="tensor_descriptor", ctx=ast.Load()) return new_node @@ -709,6 +711,7 @@ def slice_kernel( leaf_paths: list[str] | None = None, translate_to_gluon: bool = False, ignored_decorator_matchers: Sequence[DecoratorMatcher] | None = None, + target: TranslatorTarget = TranslatorTarget.NVIDIA, ) -> str: base_values: list[GlobalValue] = [get_base_value(root_path) for root_path in root_paths] base_value_ids: set[int] = set() @@ -736,7 +739,7 @@ def filter(value: ModuleType | GlobalValue) -> bool: ignored_decorator_matchers=ignored_decorator_matchers, ) jit_functions = [fn for fn in jit_functions if not fn.original_value.is_gluon()] - converted_functions = translate_kernels(jit_functions) + converted_functions = translate_kernels(jit_functions, target=target) module_file = tempfile.NamedTemporaryFile(delete=False, prefix="translated_", suffix=".py") module_path = Path(module_file.name) module_path.write_text(converted_functions) @@ -782,6 +785,7 @@ def filter(value: ModuleType | GlobalValue) -> bool: ignored_decorator_matchers=tuple(ignored_decorator_matchers or ()), translate_to_gluon=translate_to_gluon, inline_helpers=inline_helpers, + target=target, ) tree = rewriter.visit(tree) source = ast.unparse(tree) @@ -807,6 +811,7 @@ def slice_kernel_from_trace( translate_to_gluon: bool, extra_modules: dict[str, str], ignored_decorator_matchers: Sequence[DecoratorMatcher] | None = None, + target: TranslatorTarget = TranslatorTarget.NVIDIA, ) -> str: module_remap: dict[str, str] = {} for name, path in extra_modules.items(): @@ -831,6 +836,7 @@ def slice_kernel_from_trace( leaf_paths=sorted(leaf_paths), translate_to_gluon=translate_to_gluon, ignored_decorator_matchers=ignored_decorator_matchers, + target=target, ) fn_name = lambda path: path.split(":")[1] @@ -853,6 +859,7 @@ def main( translate_to_gluon: bool = False, output_path: str = "/tmp/reference.py", ignored_decorator_matchers: Sequence[DecoratorMatcher] | None = None, + target: TranslatorTarget = TranslatorTarget.NVIDIA, ) -> None: output = slice_kernel( root_paths, @@ -861,6 +868,7 @@ def main( leaf_paths, translate_to_gluon, ignored_decorator_matchers, + target=target, ) with open(output_path, "w") as f: f.write(output) diff --git a/python/triton/tools/triton_to_gluon_translator/target.py b/python/triton/tools/triton_to_gluon_translator/target.py new file mode 100644 index 000000000000..f4c6c03dde1e --- /dev/null +++ b/python/triton/tools/triton_to_gluon_translator/target.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from enum import Enum + + +class TranslatorTarget(str, Enum): + """Target architecture for the Triton-to-Gluon translator. + + Known targets are listed as explicit members for discoverability. + Unknown ``gfx*`` strings are accepted via ``_missing_()`` so that + new AMD architectures work without adding an enum member. + """ + + NVIDIA = "nvidia" + # AMD targets currently exercised by the translator test suite: + GFX1250 = "gfx1250" + GFX942 = "gfx942" + GFX950 = "gfx950" + + @classmethod + def _missing_(cls, value: object) -> "TranslatorTarget | None": + """Allow any ``gfx*`` string as a valid AMD target.""" + if isinstance(value, str) and value.startswith("gfx"): + obj = str.__new__(cls, value) + obj._value_ = value + return obj + return None + + @property + def is_amd(self) -> bool: + return self != TranslatorTarget.NVIDIA + + @property + def tensor_descriptor_import(self) -> str: + """Return the import statement for the target's tensor descriptor module.""" + if self.is_amd: + return "from triton.experimental.gluon.language.amd.gfx1250.tdm import tensor_descriptor" + return "from triton.experimental.gluon.language.nvidia.hopper.tma import tensor_descriptor" + + @property + def helpers_module(self) -> str: + """Return the helpers module path for this target.""" + base = "triton.tools.triton_to_gluon_translator" + if self.is_amd: + return f"{base}.amd_helpers" + return f"{base}.nvidia_helpers" diff --git a/python/triton/tools/triton_to_gluon_translator/translator.py b/python/triton/tools/triton_to_gluon_translator/translator.py index 4d7837931879..4e09e4679036 100644 --- a/python/triton/tools/triton_to_gluon_translator/translator.py +++ b/python/triton/tools/triton_to_gluon_translator/translator.py @@ -21,6 +21,7 @@ mangle_reference_names, parse_expr, ) +from triton.tools.triton_to_gluon_translator.target import TranslatorTarget from triton.tools.triton_to_gluon_translator.stable_toposort import stable_toposort @@ -108,6 +109,7 @@ def add_expr_rewrites(rewrites: list[RewriteFn]) -> None: @dataclass class Translator(ReferenceRewriter): tensor_member_match_fns: list[str] = field(default_factory=list) + target: TranslatorTarget = TranslatorTarget.NVIDIA def __post_init__(self) -> None: import triton @@ -123,7 +125,7 @@ def __post_init__(self) -> None: self.imports.add("import triton.experimental.gluon as gluon") self.imports.add("import triton.experimental.gluon.language as gl") - self.imports.add("import triton.tools.triton_to_gluon_translator.translator_helpers as helpers") + self.imports.add(f"import {self.target.helpers_module} as helpers") self.tensor_member_match_fns = ["reshape", "trans", "permute", "split", "reduce", "sum"] @@ -198,7 +200,7 @@ def visit_Subscript(self, node: ast.Subscript) -> ast.AST: return self.generic_visit(node) -def translate_kernels(kernels: list[GlobalValue]) -> str: +def translate_kernels(kernels: list[GlobalValue], target: TranslatorTarget = TranslatorTarget.NVIDIA) -> str: def filter(value: ModuleType | GlobalValue) -> bool: if isinstance(value, ModuleType): @@ -240,6 +242,7 @@ def filter(value: ModuleType | GlobalValue) -> bool: imports, filter, value_remap={}, + target=target, ) tree = rewriter.visit(tree) source = ast.unparse(tree) @@ -250,12 +253,12 @@ def filter(value: ModuleType | GlobalValue) -> bool: return output -def translate_paths(kernel_paths: list[str]) -> str: +def translate_paths(kernel_paths: list[str], target: TranslatorTarget = TranslatorTarget.NVIDIA) -> str: kernels = [get_base_value(kernel_path) for kernel_path in kernel_paths] - return translate_kernels(kernels) + return translate_kernels(kernels, target=target) -def convert_triton_to_gluon(src: list[JITCallable]) -> str: +def convert_triton_to_gluon(src: list[JITCallable], target: TranslatorTarget = TranslatorTarget.NVIDIA) -> str: kernels = [ GlobalValue.wrap( kernel, @@ -263,11 +266,11 @@ def convert_triton_to_gluon(src: list[JITCallable]) -> str: lambda: builtins, ) for kernel in src ] - return translate_kernels(kernels) + return translate_kernels(kernels, target=target) -def main(kernels: list[str], output_path: str) -> None: - output = translate_paths(kernels) +def main(kernels: list[str], output_path: str, target: TranslatorTarget = TranslatorTarget.NVIDIA) -> None: + output = translate_paths(kernels, target=target) with open(output_path, "w") as f: f.write(output) @@ -276,8 +279,9 @@ def _main_cli() -> None: parser = argparse.ArgumentParser(description="Translate Triton kernels to Gluon source.") parser.add_argument("kernels", nargs="+", help="Kernel symbols in module.path:object format.") parser.add_argument("--output-path", required=True, help="Path to write the translated source.") + parser.add_argument("--target", default="nvidia", help="Target architecture (e.g. nvidia, amd_gfx1250).") args = parser.parse_args() - main(args.kernels, args.output_path) + main(args.kernels, args.output_path, target=TranslatorTarget(args.target)) if __name__ == "__main__":