diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index b107fe5a7a94..d0f56d8c12aa 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2322,6 +2322,33 @@ def kernel(ExitEarly, Out): assert to_numpy(out)[0] == 1 +@triton.jit +def add_fn(x): + return x + 1 + + +@pytest.mark.parametrize("call_type", ["attribute", "jit_function"]) +def test_if_call(call_type): + @triton.jit + def kernel(Out, call_type: tl.constexpr): + pid = tl.program_id(0) + o = tl.load(Out) + if pid == 0: + if call_type == "attribute": + a = o + 1 + a = a.to(tl.int32) + o = a + else: + a = o + a = add_fn(a) + o = a + tl.store(Out, o) + + out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda') + kernel[(1,)](out, call_type) + assert to_numpy(out)[0] == 1 + + @pytest.mark.parametrize("_cond1", [True, False]) @pytest.mark.parametrize("_cond2", [True, False]) @pytest.mark.parametrize("_cond3", [True, False]) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index ad6ecba6198d..7d0a28cd1f47 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -168,6 +168,8 @@ def contains_return_op(self, node): pred = lambda s: self.contains_return_op(s) return any(pred(s) for s in node.body) elif isinstance(node, ast.Call): + if isinstance(node.func, ast.Attribute): + return False fn = self.visit(node.func) if isinstance(fn, JITFunction): old_gscope = self.gscope