From ffada3f03cfa032fa4aae95640ee3d9df03d5003 Mon Sep 17 00:00:00 2001 From: ZhangLirong-amd Date: Wed, 5 Nov 2025 01:17:25 +0000 Subject: [PATCH] fix graph_breaks by return tensor for bool op --- aiter/jit/utils/torch_guard.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/aiter/jit/utils/torch_guard.py b/aiter/jit/utils/torch_guard.py index 99024692f3..cb969c9bce 100644 --- a/aiter/jit/utils/torch_guard.py +++ b/aiter/jit/utils/torch_guard.py @@ -270,18 +270,18 @@ def wrapper_register(calling_func): else: new_input = "(Tensor dummy, " + input_part[1:] - return_int = False + return_non_tensor = False return_annotation = sig.return_annotation - if return_annotation is int: + if return_annotation in [int, bool, float]: output_part = "(Tensor, " + output_part + ")" - return_int = True + return_non_tensor = True schema = f"{new_input} -> {output_part}".strip() loadName = calling_func.__name__ def abstract_impl(*args, custom_build_args={}, **kwargs): - if return_int: + if return_non_tensor: return torch.empty(1, device=device), 1 if gen_fake is not None: return gen_fake(*args, **kwargs) @@ -290,12 +290,12 @@ def abstract_impl(*args, custom_build_args={}, **kwargs): def outer_wrapper(*args, **kwargs): return ( wrapper(*args, **kwargs) - if not return_int + if not return_non_tensor else (torch.empty(1, device=device), wrapper(*args, **kwargs)) ) def abstract_impl_dummy(dummy, *args, custom_build_args={}, **kwargs): - if return_int: + if return_non_tensor: return torch.empty(1, device=device), 1 if gen_fake is not None: return gen_fake(*args, **kwargs) @@ -304,7 +304,7 @@ def abstract_impl_dummy(dummy, *args, custom_build_args={}, **kwargs): def outer_wrapper_dummy(dummy, *args, **kwargs): return ( wrapper(*args, **kwargs) - if not return_int + if not return_non_tensor else (torch.empty(1, device=device), wrapper(*args, **kwargs)) ) @@ -333,7 +333,7 @@ def wrapper_custom(*args, custom_build_args={}, **kwargs): torch.empty(1, device=device), *args, **kwargs ) ) - return result[1] if return_int else result + return result[1] if return_non_tensor else result return wrapper_custom