diff --git a/docs/programming-guide/chapter-3/debugging.rst b/docs/programming-guide/chapter-3/debugging.rst index 31e92d2822eb..c470363c6409 100644 --- a/docs/programming-guide/chapter-3/debugging.rst +++ b/docs/programming-guide/chapter-3/debugging.rst @@ -70,8 +70,6 @@ The interpreter has several known limitations: ptr = tl.load(ptr) x = tl.load(ptr) -- Unlike the compilation mode, a scalar in interpreter mode is treated as a simple float or integer but not as a 0-d tensor. This means it lacks tensor attributes such as :code:`x.dtype`. A workaround is to explicitly convert the scalar to a tensor using :code:`tl.to_tensor(x)`, where :code:`x` is the scalar. - ---------------------------- Using Third-party Tools ---------------------------- diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 41f47bff8121..45066956b3d6 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1702,6 +1702,27 @@ def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): assert torch.all(output == ref) +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_store_constant_default_dtype(num_ctas, device): + """Tests that an integer constant exposes a default dtype via value.dtype and is stored as 1""" + + @triton.jit + def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + value = 1 + output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype) + tl.store(output_ptr + offsets, output, mask=mask) + + block_size = 128 + ref = torch.ones([block_size], dtype=getattr(torch, 'int32'), device=device) + output = torch.zeros([block_size], dtype=getattr(torch, 'int32'), device=device) + kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) + + assert torch.all(output == ref) + + def test_load_store_same_ptr(device): @triton.jit() @@ -5334,12 +5355,12 @@ def test_tl_range(device): 
torch.testing.assert_close(ref_out, c, rtol=1e-2, atol=1e-1) else: torch.testing.assert_close(ref_out, c, rtol=1e-3, atol=1e-3) - if device in ['cuda']: - capability = torch.cuda.get_device_capability() - if capability[0] >= 8: - ptx = pgm.asm['ptx'] - # check that the loop got pipelined with the right number of stages. - assert 'cp.async.wait_group 0x6' in ptx + if device in ['cuda']: + capability = torch.cuda.get_device_capability() + if capability[0] >= 8: + ptx = pgm.asm['ptx'] + # check that the loop got pipelined with the right number of stages. + assert 'cp.async.wait_group 0x6' in ptx @triton.jit(noinline=True) diff --git a/python/test/unit/language/test_line_info.py b/python/test/unit/language/test_line_info.py index 6421c7309b18..b00e10d4b83a 100644 --- a/python/test/unit/language/test_line_info.py +++ b/python/test/unit/language/test_line_info.py @@ -118,9 +118,8 @@ def check_file_lines(file_lines, file_name, lineno, should_contain=True): should_contain: whether the file name and line number should be in the file_lines """ for file, line in file_lines: - if lineno == -1: - if file_name in file: - return True + if lineno == -1 and file_name in file: + return True if file_name in file and str(lineno) in line: return should_contain return not should_contain @@ -169,3 +168,35 @@ def test_line_info(func: str): elif func == "dot_combine": assert (check_file_lines(file_lines, "test_line_info.py", 65)) assert (check_file_lines(file_lines, "test_line_info.py", 66, should_contain=False)) + + +def is_interpreter(): + import os + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +@pytest.mark.interpreter +@pytest.mark.parametrize("func", func_types) +def test_line_info_interpreter(func: str): + if not is_interpreter(): + pytest.skip("interpreter is not enabled") + + kernel = None + expected_offset = 0 + if func == "single": + kernel = kernel_single + expected_offset = 12 + elif func == "call": + kernel = kernel_call + expected_offset = 25 + elif func == 
"call_noinline": + kernel = kernel_call_noinline + expected_offset = 41 + elif func == "autotune": + kernel = kernel_autotune.fn + expected_offset = 52 + elif func == "dot_combine": + kernel = kernel_dot_combine + expected_offset = 62 + kernel._rewrite_ast() + assert kernel.ast_transformer.offset == expected_offset diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 917aef2a20d7..a642e53da4ac 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -10,7 +10,7 @@ from .._C.libtriton import ir from ..language import constexpr, tensor, str_to_ty from ..language.core import _unwrap_if_constexpr -from ..runtime.jit import _normalize_ty +from ..runtime.jit import _normalize_ty, get_jit_fn_file_line # ideally we wouldn't need any runtime component from ..runtime import JITFunction from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) @@ -73,24 +73,6 @@ def _check_fn_args(node, fn, args): ) -def _get_fn_file_line(fn): - base_fn = fn - while not isinstance(base_fn, JITFunction): - base_fn = base_fn.fn - file_name = base_fn.fn.__code__.co_filename - lines, begin_line = inspect.getsourcelines(base_fn.fn) - # Match the following pattern: - # @triton.autotune(...) <- foo.__code__.co_firstlineno - # @triton.heuristics(...) 
- # @triton.jit - # def foo(...): <- this line is the first line - for idx, line in enumerate(lines): - if line.strip().startswith("def "): - begin_line += idx - break - return file_name, begin_line - - _condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels @@ -1059,7 +1041,7 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): prototype = language.function_type([], arg_types) gscope = fn.__globals__ # If the callee is not set, we use the same debug setting as the caller - file_name, begin_line = _get_fn_file_line(fn) + file_name, begin_line = get_jit_fn_file_line(fn) debug = self.debug if fn.debug is None else fn.debug generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, @@ -1282,7 +1264,7 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns): all_constants = constants.copy() all_constants.update(new_constants) arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] - file_name, begin_line = _get_fn_file_line(fn) + file_name, begin_line = get_jit_fn_file_line(fn) prototype = language.function_type([], arg_types) generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, diff --git a/python/triton/language/core.py b/python/triton/language/core.py index e4c047f99477..f3b17e4d9c47 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -115,7 +115,7 @@ def to_tensor(x, _builder=None): return _to_tensor(x, _builder) -def _to_tensor(x, builder): +def _to_tensor(x, builder, check_type: bool = True): if isinstance(x, bool): return tensor(builder.get_int1(x), int1) # Note: compile-time const integers are represented by unsigned values @@ -129,7 +129,7 @@ def _to_tensor(x, builder): elif 2**63 <= x < 2**64: return tensor(builder.get_uint64(x), 
uint64) else: - raise RuntimeError(f'Nonrepresentable integer {x}.') + raise ValueError(f'Nonrepresentable integer {x}.') elif isinstance(x, float): min_float32 = 2**-126 max_float32 = (2 - 2**-23) * 2**127 @@ -146,7 +146,9 @@ def _to_tensor(x, builder): return _to_tensor(x.value, builder) elif isinstance(x, tensor): return x - assert False, f"cannot convert {x} of type {type(x)} to tensor" + if check_type: + raise TypeError(f"cannot convert {x} of type {type(x)} to tensor") + return x # ----------------------- diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index cd0a12f25fab..26e476ebebec 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -1,3 +1,5 @@ +import ast +import textwrap import inspect from typing import Tuple @@ -1094,30 +1096,94 @@ def __call__(self, *args_dev, **kwargs): self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst) +class ASTTransformer(ast.NodeTransformer): + + def __init__(self) -> None: + self.offset = 0 + + def visit_Assign(self, node): + names = [] + for target in node.targets: + names += [self.visit(target)] + if len(names) > 1: + raise ValueError("Multiple assignments are not supported") + # Modify the assignment x = value to + # x = triton.language.core._to_tensor(value, interpreter_builder, False) + node.value = ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Attribute(value=ast.Name(id='triton', ctx=ast.Load()), attr='language', ctx=ast.Load()), + attr='core', ctx=ast.Load()), attr='_to_tensor', ctx=ast.Load()), + args=[node.value, ast.Name(id='interpreter_builder', ctx=ast.Load()), + ast.Constant(value=False)], keywords=[]) + return node + + def generic_visit(self, node): + # Adjust the begin line number of the node + if hasattr(node, 'lineno') and node.lineno is not None: + node.lineno += self.offset + if hasattr(node, 'end_lineno') and node.end_lineno is not None: + node.end_lineno += self.offset + return 
super().generic_visit(node) + + class InterpretedFunction: + rewritted_fn = {} + ast_transformer = ASTTransformer() - def __init__(self, fn) -> None: + def __init__(self, fn, **kwargs) -> None: self.fn = fn def run(*args, **kwargs): grid = kwargs["grid"] - return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs) + fn = self._rewrite_ast() + return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs) self.run = run + self.kwargs = kwargs signature = inspect.signature(fn) self.arg_names = [v.name for v in signature.parameters.values()] + def _rewrite_ast(self): + if self.fn in self.rewritted_fn: + return self.rewritted_fn[self.fn] + # If an exception is raised, the function has no source code available + # (e.g., it was dynamically generated), so we cannot rewrite it and just return the original function + try: + lines, lineno = inspect.getsourcelines(self.fn) + except Exception: + self.rewritted_fn[self.fn] = self.fn + return self.fn + from .jit import get_jit_fn_file_line, JITFunction + filename, lineno = get_jit_fn_file_line(JITFunction(self.fn)) + src = ''.join(lines) + src = textwrap.dedent(src) + parsed_ast = ast.parse(src) + self.ast_transformer.offset = lineno + transformed_ast = self.ast_transformer.visit(parsed_ast) + transformed_ast = ast.fix_missing_locations(transformed_ast) + compiled_code = compile(transformed_ast, filename=filename, mode='exec') + local_namespace = {**self.kwargs} + if self.fn.__name__ in local_namespace: + raise ValueError(f"Function name {self.fn.__name__} is reserved") + exec(compiled_code, globals(), local_namespace) + fn = local_namespace[self.fn.__name__].fn + self.rewritted_fn[self.fn] = fn + return fn + @property def __name__(self): return self.fn.__name__ def __getitem__(self, grid): - return GridExecutor(self.fn, self.arg_names, grid) + fn = self._rewrite_ast() + return GridExecutor(fn, self.arg_names, grid) def __call__(self, *args, **kwargs): # This is a device function call _patch_lang(self.fn) + 
fn = self._rewrite_ast() try: - return self.fn(*args, **kwargs) + return fn(*args, **kwargs) except Exception as e: raise InterpreterError(repr(e)) from e diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 9e7a90cf5090..0e24a5e74d7f 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -847,7 +847,8 @@ def decorator(fn: T) -> JITFunction[T]: assert callable(fn) if os.getenv("TRITON_INTERPRET", "0") == "1": from .interpreter import InterpretedFunction - return InterpretedFunction(fn) + return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize, debug=debug, + noinline=noinline, repr=repr, launch_metadata=launch_metadata) else: return JITFunction( fn, @@ -935,3 +936,21 @@ def reinterpret(tensor, dtype): return TensorWrapper(tensor, dtype) else: raise TypeError(f"Cannot reinterpret a {type(tensor)}.") + + +def get_jit_fn_file_line(fn): + base_fn = fn + while not isinstance(base_fn, JITFunction): + base_fn = base_fn.fn + file_name = base_fn.fn.__code__.co_filename + lines, begin_line = inspect.getsourcelines(base_fn.fn) + # Match the following pattern: + # @triton.autotune(...) <- foo.__code__.co_firstlineno + # @triton.heuristics(...) + # @triton.jit + # def foo(...): <- this line is the first line + for idx, line in enumerate(lines): + if line.strip().startswith("def "): + begin_line += idx + break + return file_name, begin_line