diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 802a78f9ed7a..2ce6d4ed1ba7 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -226,7 +226,7 @@ jobs: - name: Install pip dependencies run: | python3 -m pip install --upgrade pip - python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit + python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-forked pytest-xdist lit - name: Install Triton env: TRITON_BUILD_WITH_CCACHE: "true" @@ -250,8 +250,9 @@ jobs: echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1 fi cd python/test/unit - python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py + python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py python3 -m pytest -s -n 8 language/test_subprocess.py + python3 -m pytest -s -n 8 test_debug.py --forked # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py # Run hopper/test_flashattention.py separately to avoid out of gpu memory @@ -407,7 +408,10 @@ jobs: pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py cd python/test/unit pytest --capture=tee-sys -rfs -n 16 language runtime \ - --ignore=language/test_line_info.py + --ignore=language/test_line_info.py \ + --ignore=test_debug.py + # TODO: uncomment + # pytest --capture=tee-sys -rfs test_debug.py TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \ pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index 648beccf617d..ef19bd762643 100644 --- a/.github/workflows/integration-tests.yml.in +++ 
b/.github/workflows/integration-tests.yml.in @@ -256,7 +256,7 @@ jobs: - name: Install pip dependencies run: | python3 -m pip install --upgrade pip - python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit + python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-forked pytest-xdist lit - name: Install Triton env: @@ -284,8 +284,9 @@ jobs: echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1 fi cd python/test/unit - python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py + python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py python3 -m pytest -s -n 8 language/test_subprocess.py + python3 -m pytest -s -n 8 test_debug.py --forked # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py # Run hopper/test_flashattention.py separately to avoid out of gpu memory @@ -403,7 +404,10 @@ jobs: pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py cd python/test/unit pytest --capture=tee-sys -rfs -n 16 language runtime \ - --ignore=language/test_line_info.py + --ignore=language/test_line_info.py \ + --ignore=test_debug.py + # TODO: uncomment + # pytest --capture=tee-sys -rfs test_debug.py TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \ pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py diff --git a/python/test/unit/language/assert_helper.py b/python/test/unit/language/assert_helper.py deleted file mode 100644 index b80b7d7a0a25..000000000000 --- a/python/test/unit/language/assert_helper.py +++ /dev/null @@ -1,154 +0,0 @@ -import sys - -import torch -from torch.testing import assert_close - -import triton -import triton.language as tl - - -def get_current_target_warp_size(): - return 
triton.runtime.driver.active.get_current_target().warp_size - - -@triton.jit -def kernel_device_assert(X, Y, BLOCK: tl.constexpr): - x = tl.load(X + tl.arange(0, BLOCK)) - tl.device_assert(x == 0, "x != 0") - tl.store(Y + tl.arange(0, BLOCK), x) - - -@triton.jit -def kernel_assert_passes(X, Y, BLOCK: tl.constexpr): - x = tl.load(X + tl.arange(0, BLOCK)) - # Trivial assert, should not be an error. - tl.device_assert(0 == 0, "x != 0") - tl.store(Y + tl.arange(0, BLOCK), x) - - -@triton.jit(debug=False) -def kernel_device_assert_no_debug(X, Y, BLOCK: tl.constexpr): - x = tl.load(X + tl.arange(0, BLOCK)) - tl.device_assert(x == 0, "x != 0") - tl.store(Y + tl.arange(0, BLOCK), x) - - -@triton.jit -def kernel_assert(X, Y, BLOCK: tl.constexpr): - x = tl.load(X + tl.arange(0, BLOCK)) - assert x == 0, "x != 0" - tl.store(Y + tl.arange(0, BLOCK), x) - - -@triton.jit -def kernel_static_assert(X, Y, BLOCK: tl.constexpr): - x = tl.load(X + tl.arange(0, BLOCK)) - tl.static_assert(BLOCK == 128, "BLOCK != 128") - tl.store(Y + tl.arange(0, BLOCK), x) - - -def test_assert(func: str, device: str): - N = 128 # This value should match with test_print in test_subprocess.py. - num_warps = N // get_current_target_warp_size() - - x = torch.arange(0, N, dtype=torch.int32, device=device) - y = torch.zeros((N, ), dtype=x.dtype, device=device) - if func == "device_assert": - kernel_device_assert[(1, )](x, y, num_warps=num_warps, BLOCK=N) - if func == "device_assert_passes": - # Assert passes; no error. 
- kernel_assert_passes[(1, )](x, y, num_warps=num_warps, BLOCK=N) - elif func == "no_debug": - # TRITON_DEBUG=1 can override the debug flag - kernel_device_assert_no_debug[(1, )](x, y, num_warps=num_warps, BLOCK=N) - elif func == "assert": - kernel_assert[(1, )](x, y, num_warps=num_warps, BLOCK=N) - elif func == "static_assert": - kernel_static_assert[(1, )](x, y, num_warps=num_warps, BLOCK=N) - elif func == "double_assert": - # Launching a different kernel after the first one asserted used to - # segfault. What seems to have happened is: - # - The first kernel is enqueued but doesn't run yet. - # - We go to launch the second kernel. Because this is the first time - # we're running it, we have to load the kernel into the GPU. - # - Loading the kernel takes some time, during which the first launch - # completes. - # - Now the GPU is in an error state. We need to detect this inside - # the kernel-launch/loading code and bail out properly. If we don't, - # we segfault. - kernel_device_assert[(1, )](x, y, num_warps=num_warps, BLOCK=N) - kernel_assert_passes[(1, )](x, y, num_warps=num_warps, BLOCK=N) - assert_close(y, x) - # GPU/host synchronization before exiting the test. 
- getattr(torch, device).synchronize() - - -@triton.jit -def jit_device_assert_none(x): - tl.device_assert(x == 0, "x != 0") - - -@triton.jit(debug=True) -def jit_device_assert_true(x): - tl.device_assert(x == 0, "x != 0") - - -@triton.jit(debug=False) -def jit_device_assert_false(x): - tl.device_assert(x == 0, "x != 0") - - -@triton.jit -def kernel_device_assert_nested(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr): - x = tl.load(X + tl.arange(0, BLOCK)) - if jit_debug == "true": - jit_device_assert_true(x) - elif jit_debug == "false": - jit_device_assert_false(x) - else: - jit_device_assert_none(x) - tl.store(Y + tl.arange(0, BLOCK), x) - - -@triton.jit(debug=True) -def kernel_device_assert_nested_true(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr): - x = tl.load(X + tl.arange(0, BLOCK)) - if jit_debug == "true": - jit_device_assert_true(x) - elif jit_debug == "false": - jit_device_assert_false(x) - else: - jit_device_assert_none(x) - tl.store(Y + tl.arange(0, BLOCK), x) - - -@triton.jit(debug=False) -def kernel_device_assert_nested_false(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr): - x = tl.load(X + tl.arange(0, BLOCK)) - if jit_debug == "true": - jit_device_assert_true(x) - elif jit_debug == "false": - jit_device_assert_false(x) - else: - jit_device_assert_none(x) - tl.store(Y + tl.arange(0, BLOCK), x) - - -def test_assert_nested(caller: str, callee: str, device: str): - N = 128 # This value should match with test_print in test_subprocess.py. 
- num_warps = N // get_current_target_warp_size() - - x = torch.arange(0, N, dtype=torch.int32, device=device) - y = torch.zeros((N, ), dtype=x.dtype, device=device) - if caller == "none": - kernel_device_assert_nested[(1, )](x, y, num_warps=num_warps, BLOCK=N, jit_debug=callee) - elif caller == "true": - kernel_device_assert_nested_true[(1, )](x, y, num_warps=num_warps, BLOCK=N, jit_debug=callee) - elif caller == "false": - kernel_device_assert_nested_false[(1, )](x, y, num_warps=num_warps, BLOCK=N, jit_debug=callee) - assert_close(y, x) - - -if __name__ == "__main__": - fn = globals()[sys.argv[1]] - fn(*sys.argv[2:]) diff --git a/python/test/unit/language/test_subprocess.py b/python/test/unit/language/test_subprocess.py index 17118a29bd1f..2ad97e8a6815 100644 --- a/python/test/unit/language/test_subprocess.py +++ b/python/test/unit/language/test_subprocess.py @@ -8,11 +8,6 @@ dir_path = os.path.dirname(os.path.realpath(__file__)) print_path = os.path.join(dir_path, "print_helper.py") -assert_path = os.path.join(dir_path, "assert_helper.py") - -# TODO: bfloat16 after LLVM-15 -assert_types = ["device_assert", "device_assert_passes", "assert", "static_assert", "no_debug", "double_assert"] -nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]] torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"] @@ -120,59 +115,3 @@ def test_print(func_type: str, data_type: str, device: str): continue print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)') assert all(delta == 0 for delta in diff.values()) - - -@pytest.mark.parametrize("func_type", assert_types) -def test_assert(func_type: str, device: str): - # The total number of elements in the 1-D tensor to assert on. 
- N = 128 - - proc = subprocess.run( - [sys.executable, assert_path, "test_assert", func_type, device], - capture_output=True, - env={**os.environ, "TRITON_DEBUG": "1"}, - ) - errs = proc.stderr.splitlines() - num_errs = 0 - for err in errs: - if "x != 0" in err.decode("utf-8", errors="ignore"): - num_errs += 1 - - # Check for segfaults. - assert all("segmentation fault" not in line.decode("utf-8", errors="ignore").lower() for line in errs) - - if func_type == "static_assert" or func_type == "device_assert_passes": - assert num_errs == 0 - else: - assert num_errs == N - 1 - - -@pytest.mark.parametrize("caller_type, callee_type", nested_types) -def test_assert_nested(caller_type, callee_type, device): - # The total number of elements in the 1-D tensor to assert on. - N = 128 - - proc = subprocess.run( - [sys.executable, assert_path, "test_assert_nested", caller_type, callee_type, device], - capture_output=True, - ) - errs = proc.stderr.splitlines() - num_errs = 0 - for err in errs: - if "x != 0" in err.decode("utf-8", errors="ignore"): - num_errs += 1 - if caller_type == "none": - if callee_type == "true": - assert num_errs == N - 1 - else: - assert num_errs == 0 - elif caller_type == "true": - if callee_type == "false": - assert num_errs == 0 - else: - assert num_errs == N - 1 - elif caller_type == "false": - if callee_type == "true": - assert num_errs == N - 1 - else: - assert num_errs == 0 diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 8a3403bc1df1..48e4eeebd33a 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -427,23 +427,18 @@ def kernel_add(a, b, o, N: tl.constexpr): def test_jit_debug(device) -> None: @triton.jit - def kernel_add(a, b, o, N: tl.constexpr): - idx = tl.arange(0, N) - tl.device_assert(idx < 32, "idx < 32") - tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + def kernel(tmp): + tl.device_assert(tl.load(tmp) == 1, "tmp == 1") device = 
getattr(torch, device).current_device() - assert len(kernel_add.cache[device]) == 0 - kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) - assert len(kernel_add.cache[device]) == 1 - kernel_add.debug = False - kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) - assert len(kernel_add.cache[device]) == 2 - kernel_add.debug = True - kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) - assert len(kernel_add.cache[device]) == 3 - bins = list(kernel_add.cache[device].values()) - assert bins[2].asm['ttir'] != bins[1].asm['ttir'] + tmp = torch.tensor([1], dtype=torch.int32, device="cuda") + assert len(kernel.cache[device]) == 0 + kernel[(1, )](tmp, debug=False) + assert len(kernel.cache[device]) == 1 + kernel[(1, )](tmp, debug=True) + assert len(kernel.cache[device]) == 2 + bins = list(kernel.cache[device].values()) + assert bins[0].asm['ttir'] != bins[1].asm['ttir'] @triton.jit diff --git a/python/test/unit/test_debug.py b/python/test/unit/test_debug.py new file mode 100644 index 000000000000..d37062396687 --- /dev/null +++ b/python/test/unit/test_debug.py @@ -0,0 +1,124 @@ +import os +import pytest +import torch +import triton.language as tl +import triton + +@pytest.mark.parametrize('cond, opt_flag, env_var', [ + (cond, opt_flag, env_var) for cond in [True, False] \ + for opt_flag in [True, False] \ + for env_var in [True, False]\ +]) +@pytest.mark.forked +def test_device_assert(cond, opt_flag, env_var, device="cuda"): + os.environ['TRITON_DEBUG'] = str(int(env_var)) + torch.zeros([1], dtype=torch.int32, device=device) + + @triton.jit + def _kernel(COND: tl.constexpr): + tl.device_assert(COND, 'test') + + if not cond and (opt_flag or env_var): + with pytest.raises(RuntimeError): + _kernel[(1, )](cond, debug=opt_flag) + torch.cuda.synchronize() + return + + _kernel[(1, )](cond, debug=opt_flag) + torch.cuda.synchronize() + + +@pytest.mark.parametrize("cond", [False, True]) +def 
test_static_assert(cond): + + @triton.jit + def _kernel(COND: tl.constexpr): + tl.static_assert(COND) + + if not cond: + with pytest.raises(triton.compiler.errors.CompileTimeAssertionFailure): + _kernel[(1, )](cond) + return + + _kernel[(1, )](cond) + + +def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref_func): + device = "cuda" + x = torch.tensor([x], dtype=getattr(torch, x_dtype), device=device) + y = torch.tensor([y], dtype=getattr(torch, y_dtype), device=device) + z = torch.empty_like(x) + if should_overflow and debug: + with pytest.raises(RuntimeError) as exc_info: + tri_func[(1, )](x, y, z, debug=debug) + torch.cuda.synchronize() + assert "device-side assert" in str(exc_info.value) + else: + tri_func[(1, )](x, y, z, debug=debug) + torch.cuda.synchronize() + assert int(z) == int(ref_func(x, y)) + + +# integer overflow sanitization + + +@pytest.mark.parametrize("x, y, x_dtype, y_dtype, debug, should_overflow", [ + (-2**31, -1, 'int32', 'int32', False, False), + (-2**31, -1, 'int32', 'int32', True, True), + (2**31 - 1, 1, 'int32', 'int32', True, True), + (2**31 - 1, 100, 'int32', 'int32', True, True), + (-2**31, 0, 'int32', 'int32', True, False), + (-2**31, 2, 'int32', 'int32', True, False), + (0, -1, 'int32', 'int32', True, False), + (-2**15, -1, 'int16', 'int16', True, True), + (2**15 - 1, 1, 'int16', 'int16', True, True), +]) +@pytest.mark.forked +def test_sanitize_int_add_overflow(x, y, x_dtype, y_dtype, debug, should_overflow): + + @triton.jit + def _kernel_add(X, Y, Z): + tl.store(Z, tl.load(X) + tl.load(Y)) + + _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_add, lambda x, y: x + y) + + +# mul overflow + + +@pytest.mark.parametrize("x, y, x_dtype, y_dtype, debug, should_overflow", [ + (2**30, 4, 'int32', 'int32', False, False), + (2**30, 4, 'int32', 'int32', True, True), + (2**30, 2, 'int32', 'int32', True, True), + (-2**30, -4, 'int32', 'int32', True, True), + (-2**31, 1, 'int32', 'int32', True, 
False), + (-2**30, 2, 'int32', 'int32', True, False), +]) +@pytest.mark.forked +def test_sanitize_int_mul_overflow(x, y, x_dtype, y_dtype, debug, should_overflow): + + @triton.jit + def _kernel_mul(X, Y, Z): + tl.store(Z, tl.load(X) * tl.load(Y)) + + _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_mul, lambda x, y: x * y) + + +# sub overflow + + +@pytest.mark.parametrize("x, y, x_dtype, y_dtype, debug, should_overflow", [ + (-2**31, 1, 'int32', 'int32', False, False), + (-2**31, 1, 'int32', 'int32', True, True), + (2**31 - 1, -1, 'int32', 'int32', True, True), + (2**31 - 1, 1, 'int32', 'int32', True, False), + (-2**31, -1, 'int32', 'int32', True, False), +]) +@pytest.mark.forked +def test_sanitize_int_sub_overflow(x, y, x_dtype, y_dtype, debug, should_overflow): + + @triton.jit + def _kernel_sub(X, Y, Z): + tl.store(Z, tl.load(X) - tl.load(Y)) + + _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_sub, lambda x, y: x - y) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 3cb70b1b7e97..5f13dd9a7b9f 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -192,8 +192,8 @@ def visit_Call(self, node: ast.Call) -> bool: class CodeGenerator(ast.NodeVisitor): def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, - codegen_fns, module_map, debug=None, module=None, is_kernel=False, - function_types: Optional[Dict] = None, noinline=False, file_name: Optional[str] = None, begin_line=0): + codegen_fns, module_map, module=None, is_kernel=False, function_types: Optional[Dict] = None, + noinline=False, file_name: Optional[str] = None, begin_line=0): self.context = context self.builder = ir.builder(context) self.file_name = file_name @@ -229,7 +229,6 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n self.function_name = function_name 
self.is_kernel = is_kernel self.cur_node = None - self.debug = options.debug if debug is None else debug self.noinline = noinline self.scf_stack = [] self.ret_type = None @@ -1048,11 +1047,8 @@ def visit_keyword(self, node) -> Tuple[str, Any]: return node.arg, self.visit(node.value) def visit_Assert(self, node) -> Any: - if not self.debug: - return test = self.visit(node.test) msg = self.visit(node.msg) if node.msg is not None else "" - # Convert assert to triton's device_assert which happens on the device return language.core.device_assert(test, msg, _builder=self.builder) def call_JitFunction(self, fn: JITFunction, args, kwargs): @@ -1074,12 +1070,11 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): gscope = fn.__globals__ # If the callee is not set, we use the same debug setting as the caller file_name, begin_line = get_jit_fn_file_line(fn) - debug = self.debug if fn.debug is None else fn.debug generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, noinline=fn.noinline, file_name=file_name, begin_line=begin_line, options=self.builder.options, codegen_fns=self.builder.codegen_fns, - module_map=self.builder.module_map, debug=debug) + module_map=self.builder.module_map) try: generator.visit(fn.parse()) except Exception as e: @@ -1111,9 +1106,6 @@ def visit_Call(self, node): kws = dict(self.visit(keyword) for keyword in node.keywords) args = [self.visit(arg) for arg in node.args] - # TODO: this should not be so hardcoded - if fn is language.core.device_assert and not self.debug: - return if isinstance(fn, JITFunction): _check_fn_args(node, fn, args) return self.call_JitFunction(fn, args, kws) diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 0a84bd86a5a1..9f100c0a97e0 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -28,6 +28,7 @@ 
TRITON_MAX_TENSOR_NUMEL, _experimental_descriptor_load, _experimental_descriptor_store, + add, advance, arange, associative_scan, @@ -125,6 +126,7 @@ "_experimental_descriptor_load", "_experimental_descriptor_store", "abs", + "add", "advance", "arange", "argmax", diff --git a/python/triton/language/core.py b/python/triton/language/core.py index b229ca183524..b73a2f08bb6b 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -440,6 +441,20 @@ def kind(self): assert self.is_floating() return dtype.KIND.FLOATING + def get_int_max_value(self): + if self.is_int_signed(): + return 2**(self.int_bitwidth - 1) - 1 + if self.is_int_unsigned(): + return 2**self.int_bitwidth - 1 + assert False + + def get_int_min_value(self): + if self.is_int_signed(): + return -2**(self.int_bitwidth - 1) + if self.is_int_unsigned(): + return 0 + assert False + @staticmethod def is_dtype(type_str): return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES @@ -755,31 +770,27 @@ def __str__(self) -> str: @builtin def __add__(self, other, _builder=None): - other = _unwrap_if_constexpr(other) - return semantic.add(self, other, _builder) + return add(self, other, sanitize_overflow=True, _builder=_builder) @builtin def __radd__(self, other, _builder=None): - return self.__add__(other, _builder=_builder) + return add(other, self, sanitize_overflow=True, _builder=_builder) @builtin def __sub__(self, other, _builder=None): - other = _unwrap_if_constexpr(other) - return semantic.sub(self, other, _builder) + return sub(self, other, sanitize_overflow=True, _builder=_builder) @builtin def __rsub__(self, other, _builder=None): - other = 
_unwrap_if_constexpr(other) - return semantic.sub(other, self, _builder) + return sub(other, self, sanitize_overflow=True, _builder=_builder) @builtin def __mul__(self, other, _builder=None): - other = _unwrap_if_constexpr(other) - return semantic.mul(self, other, _builder) + return mul(self, other, sanitize_overflow=True, _builder=_builder) @builtin def __rmul__(self, other, _builder=None): - return self.__mul__(other, _builder=_builder) + return mul(other, self, sanitize_overflow=True, _builder=_builder) @builtin def __truediv__(self, other, _builder=None): @@ -1871,6 +1882,33 @@ def where(condition, x, y, _builder=None): # ----------------------- +@builtin +def add(x, y, sanitize_overflow: constexpr = True, _builder=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + return semantic.add(x, y, sanitize_overflow, _builder) + + +@builtin +def sub(x, y, sanitize_overflow: constexpr = True, _builder=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + return semantic.sub(x, y, sanitize_overflow, _builder) + + +@builtin +def mul(x, y, sanitize_overflow: constexpr = True, _builder=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + return semantic.mul(x, y, sanitize_overflow, _builder) + + @builtin def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): """ diff --git a/python/triton/language/extra/cuda/_experimental_tma.py b/python/triton/language/extra/cuda/_experimental_tma.py index 949071ac9fb1..5677810194d9 100644 --- a/python/triton/language/extra/cuda/_experimental_tma.py +++ b/python/triton/language/extra/cuda/_experimental_tma.py @@ -67,7 +67,7 @@ def experimental_device_tensormap_create2d( element_size = element_ty.primitive_bitwidth // 8 element_size_t = 
core.full([], element_size, core.int64, _builder=_builder) - global_stride = semantic.mul(element_size_t, global_size[-1], _builder) + global_stride = semantic.mul(element_size_t, global_size[-1], True, _builder) # Undocumented, but global_stride seems to be divided by 16 global_stride = semantic.ashr(global_stride, semantic.to_tensor(4, _builder), _builder) diff --git a/python/triton/language/random.py b/python/triton/language/random.py index 430aeb09e2e8..1c001695e189 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -34,11 +34,11 @@ def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAUL _c0, _c2 = c0, c2 c0 = math.umulhi(B, _c2) ^ c1 ^ k0 c2 = math.umulhi(A, _c0) ^ c3 ^ k1 - c1 = B * _c2 - c3 = A * _c0 + c1 = tl.mul(B, _c2, sanitize_overflow=False) + c3 = tl.mul(A, _c0, sanitize_overflow=False) # raise key - k0 = k0 + PHILOX_KEY_A - k1 = k1 + PHILOX_KEY_B + k0 = tl.add(k0, PHILOX_KEY_A, sanitize_overflow=False) + k1 = tl.add(k1, PHILOX_KEY_B, sanitize_overflow=False) return c0, c1, c2, c3 diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 295c0302df3d..44a8dce0d01a 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -196,7 +196,27 @@ def binary_op_type_checking_impl(lhs: tl.tensor | numbers.Number, rhs: tl.tensor return lhs, rhs -def add(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor: +def binary_op_sanitize_overflow_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, binary_op: callable): + if lhs.type.scalar.int_bitwidth >= 64 or not builder.options.sanitize_overflow: + return + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + assert lhs_sca_ty == rhs_sca_ty + assert lhs_sca_ty.is_int() + lhs = cast(lhs, tl.int64, builder) + rhs = cast(rhs, tl.int64, builder) + ret = binary_op(lhs, rhs, False, builder) + max_value = 
lhs_sca_ty.get_int_max_value() + max_value = tl.tensor(builder.get_int64(max_value), tl.int64) + min_value = lhs_sca_ty.get_int_min_value() + min_value = tl.tensor(builder.get_int64(min_value), tl.int64) + cond = and_(less_equal(ret, max_value, builder), greater_equal(ret, min_value, builder), builder) + msg = f"int{lhs_sca_ty.int_bitwidth} overflow detected for operation {binary_op.__name__}" + device_assert(cond, msg, builder) + + +def add(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool, + builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder, True, True) input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar @@ -216,11 +236,14 @@ def add(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, bu return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) # int + int elif input_scalar_ty.is_int(): + if sanitize_overflow: + binary_op_sanitize_overflow_impl(input, other, builder, add) return tl.tensor(builder.create_add(input.handle, other.handle), input.type) raise TypeError(f"unexpected type {input_scalar_ty}") -def sub(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor: +def sub(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool, + builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder, True, False) scalar_ty = input.type.scalar # ptr - offset @@ -231,18 +254,23 @@ def sub(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, bu return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type) # int - int elif scalar_ty.is_int(): + if sanitize_overflow: + binary_op_sanitize_overflow_impl(input, other, builder, sub) return tl.tensor(builder.create_sub(input.handle, other.handle), input.type) raise TypeError(f"unexpected type {scalar_ty}") -def mul(input: 
tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor: +def mul(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool, + builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder) scalar_ty = input.type.scalar # float * float if scalar_ty.is_floating(): return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type) - # * int + # int * int elif scalar_ty.is_int(): + if sanitize_overflow: + binary_op_sanitize_overflow_impl(input, other, builder, mul) return tl.tensor(builder.create_mul(input.handle, other.handle), input.type) raise TypeError(f"unexpected type {scalar_ty}") @@ -306,7 +334,8 @@ def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, bu # float % float if scalar_ty.is_floating(): # input - input.div(other, rounding_mode="floor") * other - ret = sub(input, mul(math.floor(fdiv(input, other, False, builder), _builder=builder), other, builder), builder) + floor = math.floor(fdiv(input, other, False, builder), _builder=builder) + ret = sub(input, mul(floor, other, True, builder), True, builder) return ret # % int elif scalar_ty.is_int(): @@ -460,7 +489,7 @@ def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor: if input_sca_ty.is_ptr(): raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") _0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) - return sub(_0, input, builder) + return sub(_0, input, True, builder) def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor: @@ -1632,6 +1661,8 @@ def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.buil def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor: + if not builder.options.debug: + return cond_ty = cond.type if not cond_ty.is_block(): cond_ty = tl.block_type(cond_ty.scalar, (1, )) diff --git 
a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index 8c8a09b67dff..0aeaff73a4ea 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -75,6 +75,7 @@ def materialize_pointers(self, boundary_check): class InterpreterOptions: extern_libs: dict = None debug: bool = False + sanitize_overflow: bool = True arch: str = None supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15") deprecated_fp8_dtypes: Tuple[str] = () diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 552c1eadf156..0842849ad982 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -561,13 +561,14 @@ def create_binder(self, backend): ] def run(self, *args, grid, warmup, **kwargs): + kwargs["debug"] = kwargs.get("debug", False) or os.environ.get("TRITON_DEBUG", "0") == "1" + # parse options from ..compiler import make_backend device = driver.active.get_current_device() stream = driver.active.get_current_stream(device) target = driver.active.get_current_target() backend = make_backend(target) - kwargs["debug"] = self.debug # Execute pre run hooks with args and kwargs for hook in self.pre_run_hooks: @@ -697,7 +698,6 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o # JITFunction can be instantiated as kernel # when called with a grid using __getitem__ self.kernel = None - self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug self.noinline = noinline # TODO(jlebar): Remove uses of these fields outside this file, then diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 6450d582457e..ea6785c0942f 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -34,6 +34,7 @@ class HIPOptions: extern_libs: dict = None cluster_dims: tuple = (1, 1, 1) debug: bool = False + sanitize_overflow: bool = True arch: str = None 
supported_fp8_dtypes: Tuple[str] = ("fp8e5", ) deprecated_fp8_dtypes: Tuple[str] = () diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index adfde57b01b0..84674c542679 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -102,6 +102,7 @@ class CUDAOptions: extern_libs: dict = None debug: bool = False backend_name: str = 'cuda' + sanitize_overflow: bool = True def __post_init__(self): default_libdir = Path(__file__).parent / 'lib'