diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py index 8dada655c4b7..fbca966c9c93 100644 --- a/python/test/gluon/test_frontend.py +++ b/python/test/gluon/test_frontend.py @@ -1874,6 +1874,44 @@ def test_atomic_rmw(): ttgl.atomic_add(offset + ptr, val, mask=scalar_mask, sem="acquire", scope="cta") +@filecheck_test +@gluon.jit +def test_atomic_rmw_scalar_masks(): + # CHECK-LABEL: test_atomic_rmw_scalar_masks + BLOCK: ttgl.constexpr = 128 + x = ttgl.full([BLOCK], 0, ttgl.int64) + ptr = x.cast(ttgl.pointer_type(ttgl.int32), bitcast=True) + offs = ttgl.arange(0, BLOCK) + ptrs = ptr + offs + val = ttgl.full([BLOCK], 1, ttgl.int32) + mask = offs >= 0 + scalar_mask = True + constexpr_value: ttgl.constexpr = 1 + constexpr_mask: ttgl.constexpr = True + + # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu + ttgl.atomic_add(ptrs, val, mask=mask) + # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu + ttgl.atomic_add(ptrs, 1, mask=True) + # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu + ttgl.atomic_add(ptrs, constexpr_value, mask=constexpr_mask) + # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu + ttgl.atomic_add(ptrs, val, mask=scalar_mask) + + # CHECK: {{.*}} = tt.atomic_rmw exch, acq_rel, gpu + ttgl.atomic_xchg(ptrs, 1, mask=True) + # CHECK: {{.*}} = tt.atomic_rmw max, acq_rel, gpu + ttgl.atomic_max(ptrs, 1, mask=True) + # CHECK: {{.*}} = tt.atomic_rmw min, acq_rel, gpu + ttgl.atomic_min(ptrs, 1, mask=True) + # CHECK: {{.*}} = tt.atomic_rmw and, acq_rel, gpu + ttgl.atomic_and(ptrs, 1, mask=True) + # CHECK: {{.*}} = tt.atomic_rmw or, acq_rel, gpu + ttgl.atomic_or(ptrs, 1, mask=True) + # CHECK: {{.*}} = tt.atomic_rmw xor, acq_rel, gpu + ttgl.atomic_xor(ptrs, 1, mask=True) + + @filecheck_test @gluon.jit def test_atomic_cas(): diff --git a/python/test/unit/language/test_frontend.py b/python/test/unit/language/test_frontend.py index 38a6f7786e41..70938f14a677 100644 --- a/python/test/unit/language/test_frontend.py +++ b/python/test/unit/language/test_frontend.py @@ -629,6 +629,43 @@ def test(): run_parser(test) +@filecheck_test +@triton.jit +def test_atomic_scalar_masks(): + # CHECK-LABEL: test_atomic_scalar_masks + BLOCK: tl.constexpr = 128 + ptr = tl.full((BLOCK, ), 0, tl.int64).to(tl.pointer_type(tl.int32), bitcast=True) + offs = tl.arange(0, BLOCK) + ptrs = ptr + offs + val = tl.full((BLOCK, ), 1, tl.int32) + mask = offs >= 0 + scalar_mask = True + constexpr_value: tl.constexpr = 1 + constexpr_mask: tl.constexpr = True + + # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu + tl.atomic_add(ptrs, val, mask=mask) + # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu + tl.atomic_add(ptrs, 1, mask=True) + # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu + tl.atomic_add(ptrs, constexpr_value, mask=constexpr_mask) + # CHECK: {{.*}} = tt.atomic_rmw add, acq_rel, gpu + tl.atomic_add(ptrs, val, mask=scalar_mask) + + # CHECK: {{.*}} = tt.atomic_rmw exch, acq_rel, gpu + tl.atomic_xchg(ptrs, 1, mask=True) + # CHECK: {{.*}} = tt.atomic_rmw max, acq_rel, gpu + tl.atomic_max(ptrs, 1, mask=True) + # CHECK: {{.*}} = tt.atomic_rmw min, acq_rel, gpu + tl.atomic_min(ptrs, 1, mask=True) + # CHECK: {{.*}} = tt.atomic_rmw and, acq_rel, gpu + tl.atomic_and(ptrs, 1, mask=True) + # CHECK: {{.*}} = tt.atomic_rmw or, acq_rel, gpu + tl.atomic_or(ptrs, 1, mask=True) + # CHECK: {{.*}} = tt.atomic_rmw xor, acq_rel, gpu + tl.atomic_xor(ptrs, 1, mask=True) + + @pytest.mark.interpreter def test_return_promotion(): diff --git a/python/test/unit/tools/test_slice_kernel.py b/python/test/unit/tools/test_slice_kernel.py index e2033a42cdbc..0923dec9d667 100644 --- a/python/test/unit/tools/test_slice_kernel.py +++ b/python/test/unit/tools/test_slice_kernel.py @@ -9,7 +9,7 @@ import pytest from triton.tools.triton_to_gluon_translator.ordered_set import ordered_set -from triton.tools.triton_to_gluon_translator.slice_kernel import get_reference, slice_kernel +from triton.tools.triton_to_gluon_translator.slice_kernel import RewriteSpec, get_reference, slice_kernel from triton.tools.triton_to_gluon_translator.stable_toposort import stable_toposort @@ -154,7 +154,7 @@ def matcher(context, cur_module, decorator): top = slice_kernel( [f"{mod('kernel_mod')}:kernel_top"], ["triton", "torch"], - ignored_decorator_matchers=[matcher], + rewrite_spec=RewriteSpec(ignored_decorator_matchers=[matcher]), ) assert "@keep()" in top assert "@mock_kernel" not in top @@ -162,7 +162,7 @@ def matcher(context, cur_module, decorator): bottom = slice_kernel( [f"{mod('kernel_mod')}:kernel_bottom"], ["triton", "torch"], - ignored_decorator_matchers=[matcher], + rewrite_spec=RewriteSpec(ignored_decorator_matchers=[matcher]), ) assert "@keep()" not in bottom assert "@mock_kernel" not in bottom diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 5938f912d24f..635645cde3cd 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1178,6 +1178,7 @@ def _broadcast_ptr_val_mask(self, ptr, val, mask): if mask is None: ptr, val = self.broadcast_tensors(ptr, val) else: + mask = self.to_tensor(mask) ptr, val, mask = self.broadcast_tensors(ptr, val, mask) if ptr_shape != ptr.shape: raise ValueError(f"Expected pointer argument to have shape {ptr.shape} but got {ptr_shape}") @@ -1266,6 +1267,8 @@ def atom_red_typechecking_impl(self, ptr: TensorTy, val: TensorTy, mask: TensorT mask_ty = ptr.type.with_element_ty(tl.int1) mask_ir = self.builder.create_splat(mask_ty.to_ir(self.builder), mask_ir) mask = self.tensor(mask_ir, mask_ty) + elif not mask.type.scalar.is_bool(): + raise ValueError("Mask must have boolean scalar type") return ptr, val, mask def _signbit(self, x: TensorTy) -> TensorTy: diff --git a/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py b/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py index 8791e12df1ab..09245179e5de 100644 --- a/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py +++ b/python/triton/tools/triton_to_gluon_translator/nvidia_helpers.py @@ -420,22 +420,8 @@ def tl_load_tensor_descriptor(desc, offsets): return out -# ---- NVIDIA obj dispatch ---- - - -@gluon.jit -def tl_obj_store(obj, offsets, value): - tl_store_tensor_descriptor(obj, offsets, value) - - -@gluon.jit -def tl_obj_load(obj, offsets): - return tl_load_tensor_descriptor(obj, offsets) - - @gluon.jit -def tl_obj_gather(obj, x_offsets, y_offset): - desc = obj +def tl_gather_tensor_descriptor(desc, x_offsets, y_offset): desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout) bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) @@ -445,7 +431,7 @@ def tl_obj_gather(obj, x_offsets, y_offset): ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]), ) x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout) - mbarrier.expect(bar, x_offsets.shape[0] * obj.block_type.nbytes) + mbarrier.expect(bar, x_offsets.shape[0] * desc.block_type.nbytes) tma_blackwell.async_gather(desc, x_offsets, y_offset, bar, alloc) mbarrier.wait(bar, phase=0) mbarrier.invalidate(bar) @@ -455,8 +441,7 @@ def tl_obj_gather(obj, x_offsets, y_offset): @gluon.jit -def tl_obj_scatter(obj, value, x_offsets, y_offset): - desc = obj +def tl_scatter_tensor_descriptor(desc, value, x_offsets, y_offset): desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout, value) fence_async_shared() @@ -469,6 +454,41 @@ def tl_obj_scatter(obj, value, x_offsets, y_offset): tma.store_wait(0) +# ---- NVIDIA obj dispatch ---- + + +@gluon.jit +def tl_obj_store(obj, offsets, value): + if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): + return tl_store_tensor_descriptor(obj, offsets, value) + else: + return obj.store(offsets, value) + + +@gluon.jit +def tl_obj_load(obj, offsets): + if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): + return tl_load_tensor_descriptor(obj, offsets) + else: + return obj.load(offsets) + + +@gluon.jit +def tl_obj_gather(obj, x_offsets, y_offset): + if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): + return tl_gather_tensor_descriptor(obj, x_offsets, y_offset) + else: + return obj.gather(x_offsets, y_offset) + + +@gluon.jit +def tl_obj_scatter(obj, value, x_offsets, y_offset): + if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): + return tl_scatter_tensor_descriptor(obj, value, x_offsets, y_offset) + else: + return obj.scatter(value, x_offsets, y_offset) + + # ---- NVIDIA host-side descriptor ---- diff --git a/python/triton/tools/triton_to_gluon_translator/slice_kernel.py b/python/triton/tools/triton_to_gluon_translator/slice_kernel.py index cd40f4921047..650378f0b858 100644 --- a/python/triton/tools/triton_to_gluon_translator/slice_kernel.py +++ b/python/triton/tools/triton_to_gluon_translator/slice_kernel.py @@ -121,6 +121,13 @@ def mangle_source(self, source: str, mangled_name: str) -> str: FilterFn = Callable[[ModuleType | GlobalValue], bool] DecoratorMatcher: TypeAlias = Callable[[scoped_dict[str, Any], ModuleType, ast.expr], bool] +AnnotationRewriter: TypeAlias = Callable[[scoped_dict[str, Any], ModuleType, ast.Subscript], ast.expr | None] + + +@dataclass +class RewriteSpec: + ignored_decorator_matchers: Sequence[DecoratorMatcher] = field(default_factory=tuple) + annotation_rewriters: Sequence[AnnotationRewriter] = field(default_factory=tuple) def get_assign_target(stmt: ast.Assign | ast.AnnAssign) -> ast.Name | None: @@ -240,9 +247,9 @@ def is_ignored_decorator( context: scoped_dict[str, Any], cur_module: ModuleType, decorator: ast.expr, - ignored_decorator_matchers: Sequence[DecoratorMatcher], + rewrite_spec: RewriteSpec, ) -> bool: - return any(matcher(context, cur_module, decorator) for matcher in ignored_decorator_matchers) + return any(matcher(context, cur_module, decorator) for matcher in rewrite_spec.ignored_decorator_matchers) @dataclass @@ -266,7 +273,7 @@ class ReferenceScanner(ast.NodeVisitor): queue: list[GlobalValue] value_remap: dict[int, GlobalValue] filter: FilterFn - ignored_decorator_matchers: Sequence[DecoratorMatcher] = field(default_factory=tuple) + rewrite_spec: RewriteSpec = field(default_factory=RewriteSpec) edges: ordered_set[int] = field(default_factory=ordered_set[int]) @@ -311,7 +318,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: self.context, self.cur_module, decorator, - self.ignored_decorator_matchers, + self.rewrite_spec, ) ] args = node.args @@ -430,7 +437,7 @@ class ReferenceRewriter(ast.NodeTransformer): imports: ordered_set[str] filter: FilterFn value_remap: dict[int, GlobalValue] - ignored_decorator_matchers: Sequence[DecoratorMatcher] = field(default_factory=tuple) + rewrite_spec: RewriteSpec = field(default_factory=RewriteSpec) rewrites: list[RewriteFn] = field(default_factory=list) @@ -486,6 +493,13 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.AST: value, rel_module, name = ref return self.process_reference(node, name, value, rel_module) + def visit_Subscript(self, node: ast.Subscript) -> ast.AST: + for rewriter in self.rewrite_spec.annotation_rewriters: + replacement = rewriter(self.context, self.cur_module, node) + if replacement is not None: + return ast.copy_location(self.visit(replacement), node) + return self.generic_visit(node) + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: args = node.args with self.context.scope(): @@ -593,7 +607,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: self.context, self.cur_module, decorator, - self.ignored_decorator_matchers, + self.rewrite_spec, ): new_decorators = [] continue @@ -621,12 +635,12 @@ def find_references( base_values: list[GlobalValue], filter: FilterFn, value_remap: dict[int, GlobalValue], - ignored_decorator_matchers: Sequence[DecoratorMatcher] | None = None, + rewrite_spec: RewriteSpec | None = None, ) -> tuple[OrderedDict[int, Reference], dict[int, ordered_set[int]]]: references: OrderedDict[int, Reference] = OrderedDict() queue: list[GlobalValue] = [] graph: dict[int, ordered_set[int]] = {} - ignored_decorator_matchers = tuple(ignored_decorator_matchers or ()) + rewrite_spec = rewrite_spec or RewriteSpec() for base_value in base_values: base_value = value_remap.get(base_value.id, base_value) @@ -647,7 +661,7 @@ def find_references( queue, value_remap, filter, - ignored_decorator_matchers=ignored_decorator_matchers, + rewrite_spec=rewrite_spec, ) tree = value.parse_ast() scanner.visit(tree) @@ -675,7 +689,7 @@ def mangle_reference_names(references: OrderedDict[int, Reference], filter: Filt def find_jit_functions( base_values: list[GlobalValue], filter: FilterFn, - ignored_decorator_matchers: Sequence[DecoratorMatcher] | None = None, + rewrite_spec: RewriteSpec | None = None, ) -> list[GlobalValue]: def new_filter(value: ModuleType | GlobalValue) -> bool: @@ -687,7 +701,7 @@ def new_filter(value: ModuleType | GlobalValue) -> bool: base_values, new_filter, value_remap={}, - ignored_decorator_matchers=ignored_decorator_matchers, + rewrite_spec=rewrite_spec, ) return [ reference.value for reference in references.values() if isinstance(reference.value.original_value, JITFunction) @@ -710,9 +724,10 @@ def slice_kernel( include_below: list[str] | None = None, leaf_paths: list[str] | None = None, translate_to_gluon: bool = False, - ignored_decorator_matchers: Sequence[DecoratorMatcher] | None = None, + rewrite_spec: RewriteSpec | None = None, target: TranslatorTarget = TranslatorTarget.NVIDIA, ) -> str: + rewrite_spec = rewrite_spec or RewriteSpec() base_values: list[GlobalValue] = [get_base_value(root_path) for root_path in root_paths] base_value_ids: set[int] = set() for leaf_path in leaf_paths or []: @@ -736,7 +751,7 @@ def filter(value: ModuleType | GlobalValue) -> bool: jit_functions = find_jit_functions( base_values, filter, - ignored_decorator_matchers=ignored_decorator_matchers, + rewrite_spec=rewrite_spec, ) jit_functions = [fn for fn in jit_functions if not fn.original_value.is_gluon()] converted_functions = translate_kernels(jit_functions, target=target) @@ -753,7 +768,7 @@ def filter(value: ModuleType | GlobalValue) -> bool: base_values, filter, value_remap, - ignored_decorator_matchers=ignored_decorator_matchers, + rewrite_spec=rewrite_spec, ) mangle_reference_names(references, filter) @@ -782,7 +797,7 @@ def filter(value: ModuleType | GlobalValue) -> bool: imports, filter, value_remap, - ignored_decorator_matchers=tuple(ignored_decorator_matchers or ()), + rewrite_spec=rewrite_spec, translate_to_gluon=translate_to_gluon, inline_helpers=inline_helpers, target=target, @@ -810,7 +825,7 @@ def slice_kernel_from_trace( trace: list[dict[str, list[str]]], translate_to_gluon: bool, extra_modules: dict[str, str], - ignored_decorator_matchers: Sequence[DecoratorMatcher] | None = None, + rewrite_spec: RewriteSpec | None = None, target: TranslatorTarget = TranslatorTarget.NVIDIA, ) -> str: module_remap: dict[str, str] = {} @@ -835,7 +850,7 @@ def slice_kernel_from_trace( leaf_modules=["triton", "torch", "ki.spo"], leaf_paths=sorted(leaf_paths), translate_to_gluon=translate_to_gluon, - ignored_decorator_matchers=ignored_decorator_matchers, + rewrite_spec=rewrite_spec, target=target, ) @@ -858,16 +873,14 @@ def main( leaf_paths: list[str] | None = None, translate_to_gluon: bool = False, output_path: str = "/tmp/reference.py", - ignored_decorator_matchers: Sequence[DecoratorMatcher] | None = None, target: TranslatorTarget = TranslatorTarget.NVIDIA, ) -> None: output = slice_kernel( - root_paths, - leaf_modules, - include_below, - leaf_paths, - translate_to_gluon, - ignored_decorator_matchers, + root_paths=root_paths, + leaf_modules=leaf_modules, + include_below=include_below, + leaf_paths=leaf_paths, + translate_to_gluon=translate_to_gluon, target=target, ) with open(output_path, "w") as f: