From d97c57ae373da6a90233d42ea980ddb951870d7f Mon Sep 17 00:00:00 2001 From: Jokeren Date: Tue, 25 Jun 2024 19:06:54 -0400 Subject: [PATCH 1/8] Update --- python/test/unit/language/test_core.py | 12 ++-- python/test/unit/language/test_line_info.py | 37 ++++++++++- python/triton/compiler/code_generator.py | 72 ++++++--------------- python/triton/language/core.py | 8 ++- python/triton/runtime/interpreter.py | 64 +++++++++++++++++- python/triton/runtime/jit.py | 30 +++++++-- 6 files changed, 152 insertions(+), 71 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index f453e6bdbb79..d898ca22483c 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5365,12 +5365,12 @@ def test_tl_range(device): torch.testing.assert_close(ref_out, c, rtol=1e-2, atol=1e-1) else: torch.testing.assert_close(ref_out, c, rtol=1e-3, atol=1e-3) - if device in ['cuda']: - capability = torch.cuda.get_device_capability() - if capability[0] >= 8: - ptx = pgm.asm['ptx'] - # check that the loop got pipelined with the right number of stages. - assert 'cp.async.wait_group 0x6' in ptx + if device in ['cuda']: + capability = torch.cuda.get_device_capability() + if capability[0] >= 8: + ptx = pgm.asm['ptx'] + # check that the loop got pipelined with the right number of stages. + assert 'cp.async.wait_group 0x6' in ptx @triton.jit(noinline=True) diff --git a/python/test/unit/language/test_line_info.py b/python/test/unit/language/test_line_info.py index 6421c7309b18..b00e10d4b83a 100644 --- a/python/test/unit/language/test_line_info.py +++ b/python/test/unit/language/test_line_info.py @@ -118,9 +118,8 @@ def check_file_lines(file_lines, file_name, lineno, should_contain=True): should_contain: whether the file name and line number should be in the file_lines """ for file, line in file_lines: - if lineno == -1: - if file_name in file: - return True + if lineno == -1 and file_name in file: + return True if file_name in file and str(lineno) in line: return should_contain return not should_contain @@ -169,3 +168,35 @@ def test_line_info(func: str): elif func == "dot_combine": assert (check_file_lines(file_lines, "test_line_info.py", 65)) assert (check_file_lines(file_lines, "test_line_info.py", 66, should_contain=False)) + + +def is_interpreter(): + import os + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +@pytest.mark.interpreter +@pytest.mark.parametrize("func", func_types) +def test_line_info_interpreter(func: str): + if not is_interpreter(): + pytest.skip("interpreter is not enabled") + + kernel = None + expected_offset = 0 + if func == "single": + kernel = kernel_single + expected_offset = 12 + elif func == "call": + kernel = kernel_call + expected_offset = 25 + elif func == "call_noinline": + kernel = kernel_call_noinline + expected_offset = 41 + elif func == "autotune": + kernel = kernel_autotune.fn + expected_offset = 52 + elif func == "dot_combine": + kernel = kernel_dot_combine + expected_offset = 62 + kernel._rewrite_ast() + assert kernel.ast_transformer.offset == expected_offset diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index a1444f932436..bc67faa09676 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -10,7 +10,7 @@ from .._C.libtriton import ir from ..language import constexpr, tensor, str_to_ty from ..language.core import _unwrap_if_constexpr -from ..runtime.jit import _normalize_ty +from ..runtime.jit import _normalize_ty, _get_fn_file_line # ideally we wouldn't need any runtime component from ..runtime import JITFunction from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) @@ -32,7 +32,7 @@ def mangle_ty(ty): return f'{elt}S{shape}S' if ty.is_void(): return 'V' - assert False, "Unsupported type" + raise TypeError(f'Unsupported type {ty}') def mangle_fn(name, arg_tys, constants): @@ -73,24 +73,6 @@ def _check_fn_args(node, fn, args): ) -def _get_fn_file_line(fn): - base_fn = fn - while not isinstance(base_fn, JITFunction): - base_fn = base_fn.fn - file_name = base_fn.fn.__code__.co_filename - lines, begin_line = inspect.getsourcelines(base_fn.fn) - # Match the following pattern: - # @triton.autotune(...) <- foo.__code__.co_firstlineno - # @triton.heuristics(...) - # @triton.jit - # def foo(...): <- this line is the first line - for idx, line in enumerate(lines): - if line.strip().startswith("def "): - begin_line += idx - break - return file_name, begin_line - - _condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels @@ -121,10 +103,7 @@ def __init__(self, gscope): self.gscope = gscope def _visit_stmts(self, body) -> bool: - for s in body: - if self.visit(s): - return True - return False + return any(self.visit(s) for s in body) def _visit_function(self, fn) -> bool: # Currently we only support JITFunctions defined in the global scope @@ -160,7 +139,7 @@ def visit_Attribute(self, node: ast.Attribute) -> bool: return self.visit(node.value) def visit_Name(self, node: ast.Name) -> bool: - if type(node.ctx) == ast.Store: + if type(node.ctx) is ast.Store: return False if node.id in self.gscope: fn = self.gscope[node.id] @@ -226,7 +205,7 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n self.function_ret_types = {} if function_types is None else function_types self.prototype = prototype self.gscope = gscope - self.lscope = dict() + self.lscope = {} self.attributes = attributes self.constants = constants self.jit_fn = jit_fn @@ -281,19 +260,11 @@ def global_lookup(name: str, absent): # The high-level rule is that only constexpr globals are allowed. # But actually a bunch of other things, such as module imports, are # technically Python globals. We have to allow these too! - if (val is absent # - or name in self.builtin_namespace # - or type(val) == ModuleType # - or isinstance(val, JITFunction) # - or getattr(val, "__triton_builtin__", False) # - or getattr(val, "__module__", "").startswith("triton.language") # - or isinstance(val, language.dtype) # - or self._is_constexpr_global(name) # - # Allow accesses to globals while visiting an ast.arg - # because you should be able to do - # @triton.jit def fn(x: tl.constexpr = GLOBAL): ... - or self.visiting_arg_default_value # - or os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1"): + if any(val is absent, name in self.builtin_namespace, + type(val) is ModuleType, isinstance(val, JITFunction), getattr(val, "__triton_builtin__", False), + getattr(val, "__module__", "").startswith("triton.language"), isinstance(val, language.dtype), + self._is_constexpr_global(name), self.visiting_arg_default_value, + os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1"): return val raise NameError( textwrap.dedent(f"""\ @@ -418,7 +389,7 @@ def visit_FunctionDef(self, node): entry = self.fn.add_entry_block() arg_values = [] idx = 0 - for i, arg_name in enumerate(arg_names): + for i in range(len(arg_names)): if i in self.constants: cst = self.constants[i] if not _is_constexpr(cst): @@ -514,7 +485,7 @@ def visit_AugAssign(self, node): return self.dereference_name(name) def visit_Name(self, node): - if type(node.ctx) == ast.Store: + if type(node.ctx) is ast.Store: return node.id return self.dereference_name(node.id) @@ -770,9 +741,9 @@ def visit_Compare(self, node): rhs = self.visit(node.comparators[0]) lhs_value = _unwrap_if_constexpr(lhs) rhs_value = _unwrap_if_constexpr(rhs) - if type(node.ops[0]) == ast.Is: + if type(node.ops[0]) is ast.Is: return constexpr(lhs_value is rhs_value) - if type(node.ops[0]) == ast.IsNot: + if type(node.ops[0]) is ast.IsNot: return constexpr(lhs_value is not rhs_value) method_name = self._method_name_for_comp_op.get(type(node.ops[0])) if method_name is None: @@ -1048,7 +1019,7 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): args = [args[name] for name in fn.arg_names] args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args] # generate function def - attributes = dict() + attributes = {} constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] constants = {i: args[i] for i in constexprs} # generate call @@ -1098,14 +1069,14 @@ def visit_Call(self, node): kws = dict(self.visit(keyword) for keyword in node.keywords) args = [self.visit(arg) for arg in node.args] - if fn is language.core.device_assert: # TODO: this should not be so hardcoded - if not self.debug: - return + # TODO: this should not be so hardcoded + if fn is language.core.device_assert and not self.debug: + return if isinstance(fn, JITFunction): _check_fn_args(node, fn, args) return self.call_JitFunction(fn, args, kws) if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn): - extra_kwargs = dict(_builder=self.builder) + extra_kwargs = {"_builder": self.builder} sig = inspect.signature(fn) if '_generator' in sig.parameters: extra_kwargs['_generator'] = self @@ -1154,9 +1125,8 @@ def visit_Str(self, node): def visit_Attribute(self, node): lhs = self.visit(node.value) - if _is_triton_tensor(lhs): - if node.attr == "T": - return language.semantic.permute(lhs, (1, 0), builder=self.builder) + if _is_triton_tensor(lhs) and node.attr == "T": + return language.semantic.permute(lhs, (1, 0), builder=self.builder) return getattr(lhs, node.attr) def visit_Expr(self, node): diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 4d17a7d5a01f..7e0027b0d850 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -115,7 +115,7 @@ def to_tensor(x, _builder=None): return _to_tensor(x, _builder) -def _to_tensor(x, builder): +def _to_tensor(x, builder, check_type: bool = True): if isinstance(x, bool): return tensor(builder.get_int1(x), int1) # Note: compile-time const integers are represented by unsigned values @@ -129,7 +129,7 @@ def _to_tensor(x, builder): elif 2**63 <= x < 2**64: return tensor(builder.get_uint64(x), uint64) else: - raise RuntimeError(f'Nonrepresentable integer {x}.') + raise ValueError(f'Nonrepresentable integer {x}.') elif isinstance(x, float): min_float32 = 2**-126 max_float32 = (2 - 2**-23) * 2**127 @@ -146,7 +146,9 @@ def _to_tensor(x, builder): return _to_tensor(x.value, builder) elif isinstance(x, tensor): return x - assert False, f"cannot convert {x} of type {type(x)} to tensor" + if check_type: + raise TypeError(f"cannot convert {x} of type {type(x)} to tensor") + return x # ----------------------- diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index cd0a12f25fab..734231a206cd 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -1,3 +1,5 @@ +import ast +import textwrap import inspect from typing import Tuple @@ -1094,28 +1096,86 @@ def __call__(self, *args_dev, **kwargs): self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst) +class ASTTransformer(ast.NodeTransformer): + + def __init__(self) -> None: + self.offset = 0 + + def visit_Assign(self, node): + names = [] + for target in node.targets: + names += [self.visit(target)] + if len(names) > 1: + raise InterpreterError("Multiple assignments are not supported") + # Modify the assignment x = value to + # triton.core.language._to_tensor(value, interpreter_builder, False) + node.value = ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Attribute(value=ast.Name(id='triton', ctx=ast.Load()), attr='language', ctx=ast.Load()), + attr='core', ctx=ast.Load()), attr='_to_tensor', ctx=ast.Load()), + args=[node.value, ast.Name(id='interpreter_builder', ctx=ast.Load()), + ast.Constant(value=False)], keywords=[]) + return node + + def generic_visit(self, node): + if hasattr(node, 'lineno') and node.lineno is not None: + node.lineno += self.offset + if hasattr(node, 'end_lineno') and node.end_lineno is not None: + node.end_lineno += self.offset + return super().generic_visit(node) + + class InterpretedFunction: + rewritted_fn = {} + ast_transformer = ASTTransformer() - def __init__(self, fn) -> None: + def __init__(self, fn, **kwargs) -> None: self.fn = fn def run(*args, **kwargs): grid = kwargs["grid"] - return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs) + self._rewrite_ast() + return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs) self.run = run + self.kwargs = kwargs signature = inspect.signature(fn) self.arg_names = [v.name for v in signature.parameters.values()] + def _rewrite_ast(self): + if self.fn in self.rewritted_fn: + return + # If exception is raise, it means the function does not have source code available, + # e.g., dynamically generated functions, we cannot rewrite it and return the original function + try: + lines, lineno = inspect.getsourcelines(self.fn) + except Exception: + self.rewritted_fn[self.fn] = self.fn + return + from .jit import _get_fn_file_line, JITFunction + filename, lineno = _get_fn_file_line(JITFunction(self.fn)) + src = ''.join(lines) + src = textwrap.dedent(src) + parsed_ast = ast.parse(src) + self.ast_transformer.offset = lineno + transformed_ast = self.ast_transformer.visit(parsed_ast) + transformed_ast = ast.fix_missing_locations(transformed_ast) + compiled_code = compile(transformed_ast, filename=filename, mode='exec') + exec(compiled_code, self.fn.__globals__, self.kwargs) + self.rewritted_fn[self.fn] = self.fn + @property def __name__(self): return self.fn.__name__ def __getitem__(self, grid): + self._rewrite_ast() return GridExecutor(self.fn, self.arg_names, grid) def __call__(self, *args, **kwargs): # This is a device function call + self._rewrite_ast() _patch_lang(self.fn) try: return self.fn(*args, **kwargs) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index eca402e5bb43..8cc53c4065f8 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -98,7 +98,7 @@ def _update_hash(self, func): self.hasher.update(func_key.encode("utf-8")) def visit_Name(self, node): - if type(node.ctx) == ast.Store: + if type(node.ctx) is ast.Store: return node.id if node.id in self.local_names: @@ -117,12 +117,11 @@ def visit_Name(self, node): and not self.visiting_arg_default_value # It would be pretty evil if someone did `import x` and then # `x = blah`. - and type(val) != ModuleType + and type(val) is not ModuleType # It would be pretty evil if we used function `foo` inside of # `bar` and then someone did `foo = baz`. and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) # - and node.id not in self.supported_python_builtins # - ): + and node.id not in self.supported_python_builtins): self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals) self._update_hash(val) @@ -650,7 +649,7 @@ def run(self, *args, grid, warmup, **kwargs): # Check that used global values have not changed. not_present = object() - for (name, globals_dict_id), (val, globals_dict) in self.used_global_vals.items(): + for (name, _), (val, globals_dict) in self.used_global_vals.items(): if (newVal := globals_dict.get(name, not_present)) != val: raise RuntimeError( f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}") @@ -848,7 +847,8 @@ def decorator(fn: T) -> JITFunction[T]: assert callable(fn) if os.getenv("TRITON_INTERPRET", "0") == "1": from .interpreter import InterpretedFunction - return InterpretedFunction(fn) + return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize, debug=debug, + noinline=noinline, repr=repr, launch_metadata=launch_metadata) else: return JITFunction( fn, @@ -936,3 +936,21 @@ def reinterpret(tensor, dtype): return TensorWrapper(tensor, dtype) else: raise TypeError(f"Cannot reinterpret a {type(tensor)}.") + + +def _get_fn_file_line(fn): + base_fn = fn + while not isinstance(base_fn, JITFunction): + base_fn = base_fn.fn + file_name = base_fn.fn.__code__.co_filename + lines, begin_line = inspect.getsourcelines(base_fn.fn) + # Match the following pattern: + # @triton.autotune(...) <- foo.__code__.co_firstlineno + # @triton.heuristics(...) + # @triton.jit + # def foo(...): <- this line is the first line + for idx, line in enumerate(lines): + if line.strip().startswith("def "): + begin_line += idx + break + return file_name, begin_line From 933bb6f9cf0a0fa440dfd8f34db9b52927a83e6e Mon Sep 17 00:00:00 2001 From: Jokeren Date: Tue, 25 Jun 2024 19:09:52 -0400 Subject: [PATCH 2/8] Update --- python/triton/runtime/interpreter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index 734231a206cd..df8437bd723d 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -1106,7 +1106,7 @@ def visit_Assign(self, node): for target in node.targets: names += [self.visit(target)] if len(names) > 1: - raise InterpreterError("Multiple assignments are not supported") + raise ValueError("Multiple assignments are not supported") # Modify the assignment x = value to # triton.core.language._to_tensor(value, interpreter_builder, False) node.value = ast.Call( @@ -1119,6 +1119,7 @@ def visit_Assign(self, node): return node def generic_visit(self, node): + # Adjust the begin line number of the node if hasattr(node, 'lineno') and node.lineno is not None: node.lineno += self.offset if hasattr(node, 'end_lineno') and node.end_lineno is not None: From 8aca057817e12031f943d0fbd176d28a40555c61 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Tue, 25 Jun 2024 19:37:58 -0400 Subject: [PATCH 3/8] Update --- python/test/unit/language/test_core.py | 21 +++++++++++++++++++++ python/triton/runtime/interpreter.py | 25 ++++++++++++++----------- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index d898ca22483c..554811339890 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1702,6 +1702,27 @@ def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): assert torch.all(output == ref) +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_store_constant_default_dtype(num_ctas, device): + """Tests that boolean True is stored as 1""" + + @triton.jit + def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + value = 1 + output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype) + tl.store(output_ptr + offsets, output, mask=mask) + + block_size = 128 + ref = torch.ones([block_size], dtype=getattr(torch, 'int32'), device=device) + output = torch.zeros([block_size], dtype=getattr(torch, 'int32'), device=device) + kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) + + assert torch.all(output == ref) + + def test_load_store_same_ptr(device): @triton.jit() diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index df8437bd723d..aa36290984fc 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -1136,7 +1136,7 @@ def __init__(self, fn, **kwargs) -> None: def run(*args, **kwargs): grid = kwargs["grid"] - self._rewrite_ast() + fn = self._rewrite_ast() return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs) self.run = run @@ -1146,14 +1146,14 @@ def run(*args, **kwargs): def _rewrite_ast(self): if self.fn in self.rewritted_fn: - return + return self.rewritted_fn[self.fn] # If exception is raise, it means the function does not have source code available, - # e.g., dynamically generated functions, we cannot rewrite it and return the original function + # e.g., dynamically generated functions, we cannot rewrite it so just return the original function try: lines, lineno = inspect.getsourcelines(self.fn) except Exception: self.rewritted_fn[self.fn] = self.fn - return + return self.fn from .jit import _get_fn_file_line, JITFunction filename, lineno = _get_fn_file_line(JITFunction(self.fn)) src = ''.join(lines) @@ -1163,22 +1163,25 @@ def _rewrite_ast(self): transformed_ast = self.ast_transformer.visit(parsed_ast) transformed_ast = ast.fix_missing_locations(transformed_ast) compiled_code = compile(transformed_ast, filename=filename, mode='exec') - exec(compiled_code, self.fn.__globals__, self.kwargs) - self.rewritted_fn[self.fn] = self.fn + local_namespace = {**self.kwargs} + exec(compiled_code, globals(), local_namespace) + fn = local_namespace[self.fn.__name__].fn + self.rewritted_fn[self.fn] = fn + return fn @property def __name__(self): return self.fn.__name__ def __getitem__(self, grid): - self._rewrite_ast() - return GridExecutor(self.fn, self.arg_names, grid) + fn = self._rewrite_ast() + return GridExecutor(fn, self.arg_names, grid) def __call__(self, *args, **kwargs): # This is a device function call - self._rewrite_ast() - _patch_lang(self.fn) + fn = self._rewrite_ast() + _patch_lang(fn) try: - return self.fn(*args, **kwargs) + return fn(*args, **kwargs) except Exception as e: raise InterpreterError(repr(e)) from e From 7e757495cb28a9d7060f37f61ff5e361544a0abb Mon Sep 17 00:00:00 2001 From: Jokeren Date: Tue, 25 Jun 2024 19:38:55 -0400 Subject: [PATCH 4/8] Update --- python/triton/compiler/code_generator.py | 6 +++--- python/triton/runtime/interpreter.py | 4 ++-- python/triton/runtime/jit.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index bc67faa09676..e1b48fa3000b 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -10,7 +10,7 @@ from .._C.libtriton import ir from ..language import constexpr, tensor, str_to_ty from ..language.core import _unwrap_if_constexpr -from ..runtime.jit import _normalize_ty, _get_fn_file_line +from ..runtime.jit import _normalize_ty, get_fn_file_line # ideally we wouldn't need any runtime component from ..runtime import JITFunction from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) @@ -1032,7 +1032,7 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): prototype = language.function_type([], arg_types) gscope = fn.__globals__ # If the callee is not set, we use the same debug setting as the caller - file_name, begin_line = _get_fn_file_line(fn) + file_name, begin_line = get_fn_file_line(fn) debug = self.debug if fn.debug is None else fn.debug generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, @@ -1255,7 +1255,7 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns): all_constants = constants.copy() all_constants.update(new_constants) arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] - file_name, begin_line = _get_fn_file_line(fn) + file_name, begin_line = get_fn_file_line(fn) prototype = language.function_type([], arg_types) generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index aa36290984fc..01f61d6e5fd4 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -1154,8 +1154,8 @@ def _rewrite_ast(self): except Exception: self.rewritted_fn[self.fn] = self.fn return self.fn - from .jit import _get_fn_file_line, JITFunction - filename, lineno = _get_fn_file_line(JITFunction(self.fn)) + from .jit import get_fn_file_line, JITFunction + filename, lineno = get_fn_file_line(JITFunction(self.fn)) src = ''.join(lines) src = textwrap.dedent(src) parsed_ast = ast.parse(src) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 8cc53c4065f8..193a9f81267a 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -938,7 +938,7 @@ def reinterpret(tensor, dtype): raise TypeError(f"Cannot reinterpret a {type(tensor)}.") -def _get_fn_file_line(fn): +def get_fn_file_line(fn): base_fn = fn while not isinstance(base_fn, JITFunction): base_fn = base_fn.fn From efbb972c16869ca8321eb156c786c662b375b568 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Tue, 25 Jun 2024 19:39:35 -0400 Subject: [PATCH 5/8] Update --- python/triton/compiler/code_generator.py | 6 +++--- python/triton/runtime/interpreter.py | 4 ++-- python/triton/runtime/jit.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index e1b48fa3000b..cfcb45d28d15 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -10,7 +10,7 @@ from .._C.libtriton import ir from ..language import constexpr, tensor, str_to_ty from ..language.core import _unwrap_if_constexpr -from ..runtime.jit import _normalize_ty, get_fn_file_line +from ..runtime.jit import _normalize_ty, get_jit_fn_file_line # ideally we wouldn't need any runtime component from ..runtime import JITFunction from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) @@ -1032,7 +1032,7 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): prototype = language.function_type([], arg_types) gscope = fn.__globals__ # If the callee is not set, we use the same debug setting as the caller - file_name, begin_line = get_fn_file_line(fn) + file_name, begin_line = get_jit_fn_file_line(fn) debug = self.debug if fn.debug is None else fn.debug generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, @@ -1255,7 +1255,7 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns): all_constants = constants.copy() all_constants.update(new_constants) arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] - file_name, begin_line = get_fn_file_line(fn) + file_name, begin_line = get_jit_fn_file_line(fn) prototype = language.function_type([], arg_types) generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index 01f61d6e5fd4..dafe9367bf67 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -1154,8 +1154,8 @@ def _rewrite_ast(self): except Exception: self.rewritted_fn[self.fn] = self.fn return self.fn - from .jit import get_fn_file_line, JITFunction - filename, lineno = get_fn_file_line(JITFunction(self.fn)) + from .jit import get_jit_fn_file_line, JITFunction + filename, lineno = get_jit_fn_file_line(JITFunction(self.fn)) src = ''.join(lines) src = textwrap.dedent(src) parsed_ast = ast.parse(src) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 193a9f81267a..0e24a5e74d7f 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -938,7 +938,7 @@ def reinterpret(tensor, dtype): raise TypeError(f"Cannot reinterpret a {type(tensor)}.") -def get_fn_file_line(fn): +def get_jit_fn_file_line(fn): base_fn = fn while not isinstance(base_fn, JITFunction): base_fn = base_fn.fn From c63625c54e6d92a6040052f6d0cf0811ab4e4280 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Tue, 25 Jun 2024 21:28:38 -0400 Subject: [PATCH 6/8] Update --- docs/programming-guide/chapter-3/debugging.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/programming-guide/chapter-3/debugging.rst b/docs/programming-guide/chapter-3/debugging.rst index 31e92d2822eb..c470363c6409 100644 --- a/docs/programming-guide/chapter-3/debugging.rst +++ b/docs/programming-guide/chapter-3/debugging.rst @@ -70,8 +70,6 @@ The interpreter has several known limitations: ptr = tl.load(ptr) x = tl.load(ptr) -- Unlike the compilation mode, a scalar in interpreter mode is treated as a simple float or integer but not as a 0-d tensor. This means it lacks tensor attributes such as :code:`x.dtype`. A workaround is to explicitly convert the scalar to a tensor using :code:`tl.to_tensor(x)`, where :code:`x` is the scalar. - ---------------------------- Using Third-party Tools ---------------------------- From 2ec75451c5ca8c8df82101553219ab3e54ec3dab Mon Sep 17 00:00:00 2001 From: Jokeren Date: Wed, 26 Jun 2024 15:26:29 -0400 Subject: [PATCH 7/8] Update --- python/triton/runtime/interpreter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index dafe9367bf67..7b45ae14203c 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -1179,8 +1179,8 @@ def __getitem__(self, grid): def __call__(self, *args, **kwargs): # This is a device function call + _patch_lang(self.fn) fn = self._rewrite_ast() - _patch_lang(fn) try: return fn(*args, **kwargs) except Exception as e: From 44ff44390dbefb8ebcfcdbbc19b4c346535bcf0f Mon Sep 17 00:00:00 2001 From: Jokeren Date: Wed, 26 Jun 2024 19:40:42 -0400 Subject: [PATCH 8/8] Update --- python/triton/runtime/interpreter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index 7b45ae14203c..26e476ebebec 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -1164,6 +1164,8 @@ def _rewrite_ast(self): transformed_ast = ast.fix_missing_locations(transformed_ast) compiled_code = compile(transformed_ast, filename=filename, mode='exec') local_namespace = {**self.kwargs} + if self.fn.__name__ in local_namespace: + raise ValueError(f"Function name {self.fn.__name__} is reserved") exec(compiled_code, globals(), local_namespace) fn = local_namespace[self.fn.__name__].fn self.rewritted_fn[self.fn] = fn