diff --git a/python/test/unit/language/assert_helper.py b/python/test/unit/language/assert_helper.py index d461f3a8e2ae..21bcf5652fb7 100644 --- a/python/test/unit/language/assert_helper.py +++ b/python/test/unit/language/assert_helper.py @@ -14,6 +14,14 @@ def kernel_device_assert(X, Y, BLOCK: tl.constexpr): tl.store(Y + tl.arange(0, BLOCK), x) +@triton.jit +def kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # Trivial assert + tl.device_assert(0 == 0, "x != 0") + tl.store(Y + tl.arange(0, BLOCK), x) + + @triton.jit def kernel_assert(X, Y, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) @@ -34,6 +42,7 @@ def test_assert(func: str): y = torch.zeros(shape, dtype=x.dtype, device="cuda") if func == "device_assert": kernel_device_assert[(1,)](x, y, BLOCK=shape[0]) + kernel_device_assert_scalar[(1,)](x, y, BLOCK=shape[0]) elif func == "assert": kernel_assert[(1,)](x, y, BLOCK=shape[0]) elif func == "static_assert": diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 5047a866b0ad..678ee8cc780e 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1379,6 +1379,10 @@ def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl. def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor: + cond_ty = cond.type + if not cond_ty.is_block(): + cond_ty = tl.block_type(cond_ty.scalar, (1,)) + cond = tl.tensor(builder.create_splat(cond.handle, (1,)), cond_ty) return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void)