# Patch summary (three files):
#   python/test/unit/language/test_core.py
#     - test_if gains an "if_and" parametrization; the kernel takes a new
#       BoolVar constexpr and the launch passes True for it.
#     - Every branch now selects the XTrue store when `pid % 2 == 0`, and the
#       test asserts that `ret` equals `x_true` after the launch.
#     - The "if_exp" ternary is switched from `pid % 2` to `pid % 2 == 0` so it
#       agrees with the "if" / "if_and" branches and with the new assertion;
#       the kernel is launched on a `(1,)` grid, so pid is presumably 0 and
#       the truthiness form would pick the XFalse store.
#   python/triton/language/semantic.py
#     - bitcast size-mismatch message: adds the missing space before "to".
#   python/triton/runtime/jit.py
#     - equal_to_1 no longer includes bool arguments (in Python, True is an
#       int instance and compares equal to 1).
# NOTE(review): the post-image blob id on the test_core.py `index` line was
# not recomputed after the "if_exp" change -- regenerate before relying on it.
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index f1275d29065e..293acd4948cc 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2381,26 +2381,32 @@ def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): # ------------- -@pytest.mark.parametrize("if_type", ["if", "if_exp"]) +@pytest.mark.parametrize("if_type", ["if", "if_exp", "if_and"]) def test_if(if_type): @triton.jit - def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr): + def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr): pid = tl.program_id(0) cond = tl.load(Cond) if IfType == "if": - if pid % 2: + if pid % 2 == 0: tl.store(Ret, tl.load(XTrue)) else: tl.store(Ret, tl.load(XFalse)) - else: + elif IfType == "if_exp": - tl.store(Ret, tl.load(XTrue)) if pid % 2 else tl.store(Ret, tl.load(XFalse)) + tl.store(Ret, tl.load(XTrue)) if pid % 2 == 0 else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and": + if BoolVar and pid % 2 == 0: + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) cond = torch.ones(1, dtype=torch.int32, device='cuda') x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda') x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda') ret = torch.empty(1, dtype=torch.float32, device='cuda') - kernel[(1,)](cond, x_true, x_false, ret, if_type) + kernel[(1,)](cond, x_true, x_false, ret, if_type, True) + assert torch.equal(ret, x_true) def test_num_warps_pow2(): diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index b4a55c0fa9a8..968e57f59385 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -656,7 +656,7 @@ def bitcast(input: tl.tensor, src_bits = src_sca_ty.primitive_bitwidth dst_bits = dst_sca_ty.primitive_bitwidth if src_bits != dst_bits: - raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + "to " + raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to " "data-type of size " + 
str(dst_bits)) return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 843f022d67a8..102a6d1ba82d 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -176,7 +176,7 @@ def is_divisible_by_16(x): return True return False divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize} - equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize} + equal_to_1 = {i for i, arg in enumerate(args) if not isinstance(arg, bool) and isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize} return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1)) # return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)