diff --git a/python/test/regression/test_functional_regressions.py b/python/test/regression/test_functional_regressions.py
index 82298c41c710..b6143b17871f 100644
--- a/python/test/regression/test_functional_regressions.py
+++ b/python/test/regression/test_functional_regressions.py
@@ -239,3 +239,50 @@ def kernel(in_ptr, out_ptr):
     kernel[(1, )](data, res)
     ref = torch.flip(data[1:513], [0])
     assert (res == ref).all()
+
+
+@triton.jit
+def _triton_cummax_helper_fn(arg0_0, arg0_1, arg1_0, arg1_1):
+    # Combine function for a cummax scan; arg*_0 / arg*_1 follow the tuple order at the call site.
+    # Operand 0 is selected when its first component is strictly greater, or when the first
+    # components are equal and its second component is greater; otherwise operand 1 is selected.
+    value_gt = arg0_0 > arg1_0
+    value_eq = arg0_0 == arg1_0
+    index_gt = arg0_1 > arg1_1
+    tie_broken = value_eq & index_gt
+    take_first = value_gt | tie_broken
+    out_value = tl.where(take_first, arg0_0, arg1_0)
+    out_index = tl.where(take_first, arg0_1, arg1_1)
+    return out_value, out_index
+
+
+def test_inductor_cummax_bool(device):
+    """Compare a Triton associative_scan cummax over a bool tensor with torch.cummax."""
+    numel = 64
+
+    @triton.jit
+    def triton_(in_ptr0, out_ptr0, out_ptr1, XBLOCK: tl.constexpr):
+        # Lane indices come from tl.arange alone (no program_id offset is applied).
+        lanes = tl.arange(0, XBLOCK)
+        loaded = tl.load(in_ptr0 + lanes).to(tl.int1)
+        # The value is already int1 here; the repeated cast is kept so the op sequence is unchanged.
+        flags = loaded.to(tl.int1)
+        positions = lanes.to(tl.int64)
+        # Scan the (flag, position) pairs along axis 0 with the combine function above.
+        scan_values, scan_indices = tl.associative_scan(
+            (flags, positions),
+            0,
+            _triton_cummax_helper_fn,
+        )
+        tl.store(out_ptr0 + lanes, scan_values)
+        tl.store(out_ptr1 + lanes, scan_indices)
+
+    a = torch.randn((numel, ), device=device) > 0
+    ref = torch.cummax(a, dim=0)
+    values = torch.empty((numel, ), dtype=torch.bool, device=device)
+    indices = torch.empty((numel, ), dtype=torch.int64, device=device)
+
+    # Launch a single program instance with XBLOCK equal to the input length.
+    triton_[(1, )](a, values, indices, numel)
+    torch.testing.assert_close(ref.values, values)
+    torch.testing.assert_close(ref.indices, indices)