diff --git a/aiter/dist/parallel_state.py b/aiter/dist/parallel_state.py index 8b123ea11c..0f336ba1f1 100644 --- a/aiter/dist/parallel_state.py +++ b/aiter/dist/parallel_state.py @@ -107,8 +107,9 @@ def all_reduce_fake( return torch.empty_like(tensor) +# An op with the same name (all_reduce) exists in aiter.op, so use an alias here @torch_compile_guard(gen_fake=all_reduce_fake) -def all_reduce( +def all_reduce_( tensor: torch.Tensor, group_name: str, ca_fp8_quant: bool ) -> torch.Tensor: assert group_name in _groups, f"Group {group_name} is not found." @@ -317,7 +318,7 @@ def all_reduce( if self.world_size == 1: return input_ - return all_reduce( + return all_reduce_( input_, group_name=self.unique_name, ca_fp8_quant=ca_fp8_quant ) diff --git a/aiter/jit/core.py b/aiter/jit/core.py index 3aa0cd74e5..ef9776e680 100644 --- a/aiter/jit/core.py +++ b/aiter/jit/core.py @@ -14,7 +14,7 @@ import traceback import types import typing -from typing import Any, Callable, List, Optional, Union, get_args, get_origin +from typing import Any, Callable, List, Optional from packaging.version import Version, parse @@ -23,7 +23,7 @@ from chip_info import get_gfx from cpp_extension import _jit_compile, get_hip_version from file_baton import FileBaton -from torch_guard import is_torch_equal_or_newer, torch_compile_guard # noqa: E402 +from torch_guard import torch_compile_guard # noqa: E402 AITER_REBUILD = int(os.environ.get("AITER_REBUILD", "0")) @@ -664,160 +664,6 @@ def convert(d_ops: dict): ) -MANUAL_SCHEMA_OPS = [ - "register_graph_buffers", - "module_moe_ck2stages", - "mha_fwd", - "fmha_v3_fwd", - "mha_varlen_fwd", - "mha_bwd", - "fmha_v3_bwd", - "mha_varlen_bwd", - "fmha_v3_varlen_bwd", - "fmha_v3_varlen_fwd", - "mha_batch_prefill", - "hipb_findallsols", - "rocb_findallsols", - "_ActivationType", - "_QuantType", - "init_custom_ar", - "greedy_sample", - "random_sample_outer_exponential", - "random_sample", - "mixed_sample", - "exponential", -] - -NONE_WRAPPED_OP = [ - # "hipb_create_extension", - # 
"hipb_destroy_extension", - "getHipblasltKernelName", - # "rocb_create_extension", - # "rocb_destroy_extension", - "get_graph_buffer_ipc_meta", - "_ActivationType", - "_QuantType", - # "dispose", - # "meta_size", - # "get_padded_m", - "compile_mha_fwd", - "compile_mha_bwd", - "init_custom_qr", - "qr_max_size", - "qr_destroy", - "qr_open_handles", - "qr_get_handle", -] - -# We default all args are inplace, you can define inplace args for specific op -SPECIAL_OPS_MUTATES_ARGS = {} - - -def generate_schema(func) -> str: - import inspect - - import torch - - sig = inspect.signature(func) - parameters = [] - mutates_args = SPECIAL_OPS_MUTATES_ARGS.get(func.__name__, []) - for idx, (name, param) in enumerate(sig.parameters.items()): - param_type = param.annotation - flag = True - is_mutates = True - if len(mutates_args) > 0 and name not in mutates_args: - is_mutates = False - - if param_type is torch.Tensor: - if is_mutates: - type_str = f"Tensor(a{idx}!)" - else: - type_str = "Tensor" - elif param_type == Optional[torch.Tensor]: - if is_mutates: - type_str = f"Tensor(a{idx}!)?" - else: - type_str = "Tensor?" - elif get_origin(param_type) is Union and torch.Tensor in get_args(param_type): - if is_mutates: - type_str = f"Tensor(a{idx}!)?" - else: - type_str = "Tensor?" - elif param_type in (torch.SymInt, int): - type_str = "SymInt" - elif param_type in (float, bool, str): - type_str = param_type.__name__ - elif param_type == Optional[torch.Generator]: - type_str = "Generator?" - elif ( - get_origin(param_type) in (list, List) - and get_args(param_type)[0] is torch.Tensor - ): - if is_mutates: - type_str = f"Tensor(a{idx}!)[]" - else: - type_str = "Tensor[]" - elif get_origin(param_type) in (list, List) and get_args(param_type)[0] is int: - type_str = "int[]" - elif param_type == Optional[torch.dtype]: - type_str = "ScalarType?" 
- else: - type_str = "*" - flag = False - if flag: - param_str = f"{type_str} {name}" - - if param.default != inspect.Parameter.empty: - if param.default is None: - param_str += "=None" - else: - param_str += f"={param.default}" - else: - param_str = f"{type_str} " - - parameters.append(param_str) - return_annotation = sig.return_annotation - return_type = "" - if return_annotation is type(None) or return_annotation is None: - return_type = "()" - elif return_annotation is torch.Tensor: - return_type = "Tensor" - elif ( - get_origin(return_annotation) is list and get_args(return_annotation)[0] is int - ): - return_type = "int[]" - elif return_annotation is int: - return_type = "int" - elif return_annotation is float: - return_type = "float" - elif return_annotation is bool: - return_type = "bool" - elif ( - get_origin(return_annotation) is list - and get_args(return_annotation)[0] is torch.Tensor - ): - return_type = "Tensor[]" - elif get_origin(return_annotation) is tuple: - args = get_args(return_annotation) - type_strings = [] - for arg in args: - if arg is torch.Tensor: - type_strings.append("Tensor") - elif arg is int: - type_strings.append("int") - elif arg is float: - type_strings.append("float") - elif arg is bool: - type_strings.append("bool") - return_type = f"({', '.join(type_strings)})" - else: - return_type = "Any" - - schema = f"({', '.join(parameters)}) -> {return_type}" - - return schema - - def compile_ops( _md_name: str, fc_name: Optional[str] = None, @@ -986,138 +832,10 @@ def check_args(): return op(*args, **kwargs) - if func.__name__ in NONE_WRAPPED_OP: - return wrapper - - def wrapper_register(func): - import inspect - - import torch - import torch.library - from torch.library import Library - - global aiter_lib - aiter_lib = Library("aiter", "FRAGMENT") if aiter_lib is None else aiter_lib - schema = "" - if func.__name__ in MANUAL_SCHEMA_OPS: - schema = generate_schema(func) - else: - sig = inspect.signature(func) - mutates_args = 
SPECIAL_OPS_MUTATES_ARGS.get(func.__name__, "unknown") - if hasattr(torch.library, "infer_schema"): - sig = torch.library.infer_schema(func, mutates_args=mutates_args) - else: - # for pytorch 2.4 - import torch._custom_op.impl - - # torch 2.4 not support mutates "unknown" for inplace all param - if mutates_args == "unknown": - mutates_args = [] - - for param_name, param in sig.parameters.items(): - if param.annotation == torch.Tensor: - mutates_args.append(param_name) - - sig = torch._custom_op.impl.infer_schema(func, mutates_args) - schema = f"{sig}" - return schema - - schema = wrapper_register(func) - - import inspect - - import torch - - sig = inspect.signature(func) - input_is_tensor = False - parameters = list(sig.parameters.values()) - - if parameters: - first_param = parameters[0] - if ( - first_param.annotation is not inspect.Parameter.empty - and first_param.annotation is torch.Tensor - ): - input_is_tensor = True - - input_part, output_part = schema.split("->", 1) - if input_is_tensor: - new_input = input_part - else: - if not sig.parameters: - new_input = "(Tensor dummy)" - else: - new_input = "(Tensor dummy, " + input_part[1:] - - return_int = False - return_annotation = sig.return_annotation - if return_annotation is int: - output_part = "(Tensor, " + output_part + ")" - return_int = True - - schema = f"{new_input} -> {output_part}".strip() - - loadName = func.__name__ - - def abstract_impl(*args, custom_build_args={}, **kwargs): - if return_int: - return torch.empty(1, device="cuda"), 1 - if gen_fake is not None: - return gen_fake(*args, **kwargs) - return func(*args, **kwargs) - - def outer_wrapper(*args, **kwargs): - return ( - wrapper(*args, **kwargs) - if not return_int - else (torch.empty(1, device="cuda"), wrapper(*args, **kwargs)) - ) - - def abstract_impl_dummy(dummy, *args, custom_build_args={}, **kwargs): - if return_int: - return torch.empty(1, device="cuda"), 1 - if gen_fake is not None: - return gen_fake(*args, **kwargs) - return 
func(*args, **kwargs) - - def outer_wrapper_dummy(dummy, *args, **kwargs): - return ( - wrapper(*args, **kwargs) - if not return_int - else (torch.empty(1, device="cuda"), wrapper(*args, **kwargs)) - ) - - custom_func = outer_wrapper - fake_func = abstract_impl - if not input_is_tensor: - custom_func = outer_wrapper_dummy - fake_func = abstract_impl_dummy - - if not hasattr(torch.ops.aiter, f"wrapper_{loadName}"): - if is_torch_equal_or_newer("2.8.0"): - tags = () - else: - tags = (torch.Tag.needs_fixed_stride_order,) - op_schema = f"aiter::wrapper_{loadName}" + schema - aiter_lib.define(op_schema, tags=tags) - aiter_lib.impl( - f"aiter::wrapper_{loadName}", custom_func, dispatch_key="CUDA" - ) - aiter_lib.impl( - f"aiter::wrapper_{loadName}", custom_func, dispatch_key="CPU" - ) - aiter_lib._register_fake(f"wrapper_{loadName}", fake_func) - - def wrapper_custom(*args, custom_build_args={}, **kwargs): - result = ( - getattr(torch.ops.aiter, f"wrapper_{loadName}")(*args, **kwargs) - if input_is_tensor - else getattr(torch.ops.aiter, f"wrapper_{loadName}")( - torch.empty(1, device="cuda"), *args, **kwargs - ) - ) - return result[1] if return_int else result + @torch_compile_guard(device="cuda", gen_fake=gen_fake, calling_func_=func) + def custom_wrapper(*args, **kwargs): + return wrapper(*args, **kwargs) - return wrapper_custom + return custom_wrapper return decorator diff --git a/aiter/jit/utils/torch_guard.py b/aiter/jit/utils/torch_guard.py index b1a12fb0b3..99024692f3 100644 --- a/aiter/jit/utils/torch_guard.py +++ b/aiter/jit/utils/torch_guard.py @@ -3,7 +3,7 @@ from packaging import version from packaging.version import Version import importlib -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union, List, get_args, get_origin aiter_lib = None @@ -33,95 +33,308 @@ def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool: return torch_version >= version.parse(target) +MANUAL_SCHEMA_OPS = [ + 
"register_graph_buffers", + "module_moe_ck2stages", + "mha_fwd", + "fmha_v3_fwd", + "mha_varlen_fwd", + "mha_bwd", + "fmha_v3_bwd", + "mha_varlen_bwd", + "fmha_v3_varlen_bwd", + "fmha_v3_varlen_fwd", + "mha_batch_prefill", + "hipb_findallsols", + "rocb_findallsols", + "_ActivationType", + "_QuantType", + "init_custom_ar", + "greedy_sample", + "random_sample", + "mixed_sample", + "exponential", +] + + +NONE_WRAPPED_OP = [ + # "hipb_create_extension", + # "hipb_destroy_extension", + "getHipblasltKernelName", + # "rocb_create_extension", + # "rocb_destroy_extension", + "get_graph_buffer_ipc_meta", + "_ActivationType", + "_QuantType", + # "dispose", + # "meta_size", + # "get_padded_m", + "compile_mha_fwd", + "compile_mha_bwd", + "init_custom_qr", + "qr_max_size", + "qr_destroy", + "qr_open_handles", + "qr_get_handle", +] + +# We default all args are inplace, you can define inplace args for specific op +SPECIAL_OPS_MUTATES_ARGS = {} + + +def generate_schema(func) -> str: + import inspect + + import torch + + sig = inspect.signature(func) + parameters = [] + mutates_args = SPECIAL_OPS_MUTATES_ARGS.get(func.__name__, []) + for idx, (name, param) in enumerate(sig.parameters.items()): + param_type = param.annotation + flag = True + is_mutates = True + if len(mutates_args) > 0 and name not in mutates_args: + is_mutates = False + + if param_type is torch.Tensor: + if is_mutates: + type_str = f"Tensor(a{idx}!)" + else: + type_str = "Tensor" + elif param_type == Optional[torch.Tensor]: + if is_mutates: + type_str = f"Tensor(a{idx}!)?" + else: + type_str = "Tensor?" + elif get_origin(param_type) is Union and torch.Tensor in get_args(param_type): + if is_mutates: + type_str = f"Tensor(a{idx}!)?" + else: + type_str = "Tensor?" + elif param_type in (torch.SymInt, int): + type_str = "SymInt" + elif param_type in (float, bool, str): + type_str = param_type.__name__ + elif param_type == Optional[torch.Generator]: + type_str = "Generator?" 
+ elif ( + get_origin(param_type) in (list, List) + and get_args(param_type)[0] is torch.Tensor + ): + if is_mutates: + type_str = f"Tensor(a{idx}!)[]" + else: + type_str = "Tensor[]" + elif get_origin(param_type) in (list, List) and get_args(param_type)[0] is int: + type_str = "int[]" + elif param_type == Optional[torch.dtype]: + type_str = "ScalarType?" + else: + type_str = "*" + flag = False + if flag: + param_str = f"{type_str} {name}" + + if param.default != inspect.Parameter.empty: + if param.default is None: + param_str += "=None" + else: + param_str += f"={param.default}" + else: + param_str = f"{type_str} " + + parameters.append(param_str) + return_annotation = sig.return_annotation + return_type = "" + if return_annotation is type(None) or return_annotation is None: + return_type = "()" + elif return_annotation is torch.Tensor: + return_type = "Tensor" + elif ( + get_origin(return_annotation) is list and get_args(return_annotation)[0] is int + ): + return_type = "int[]" + elif return_annotation is int: + return_type = "int" + elif return_annotation is float: + return_type = "float" + elif return_annotation is bool: + return_type = "bool" + elif ( + get_origin(return_annotation) is list + and get_args(return_annotation)[0] is torch.Tensor + ): + return_type = "Tensor[]" + elif get_origin(return_annotation) is tuple: + args = get_args(return_annotation) + type_strings = [] + for arg in args: + if arg is torch.Tensor: + type_strings.append("Tensor") + elif arg is int: + type_strings.append("int") + elif arg is float: + type_strings.append("float") + elif arg is bool: + type_strings.append("bool") + return_type = f"({', '.join(type_strings)})" + else: + return_type = "Any" + + schema = f"({', '.join(parameters)}) -> {return_type}" + + return schema + + def torch_compile_guard( mutates_args: list[str] = [], device: str = "cpu", + calling_func_: Optional[Callable[..., Any]] = None, gen_fake: Optional[Callable[..., Any]] = None, ): def decorator(func): + # In 
core.py we call wrapper, but we actually need to use the aiter.op function +    calling_func = calling_func_ if calling_func_ is not None else func +
+    def wrapper(*args, **kwargs):
+        return func(*args, **kwargs)
+
         try:
             import torch
             from torch.library import Library
             import inspect
         except ImportError:
+            return wrapper
-        def wrapper(*args, **kwargs):
-            return func(*args, **kwargs)
-
+        if calling_func.__name__ in NONE_WRAPPED_OP:
             return wrapper
-        global aiter_lib
-        aiter_lib = Library("aiter", "FRAGMENT") if aiter_lib is None else aiter_lib
-        op_name = func.__name__
-        sig = inspect.signature(func)
-        return_annotation = sig.return_annotation
-        return_non_tensor = False
-        # Only return int/bool/float will cause graph breaks
-        if return_annotation in [int, bool, float]:
-            return_non_tensor = True
+        def wrapper_register(calling_func):
+            import inspect
-        def outer_wrapper(*args, **kwargs):
-            dummy = torch.empty(1, device=device)
-            if return_non_tensor:
-                result = getattr(torch.ops.aiter, op_name)(dummy, *args, **kwargs)
-                _, int_value = result
-                return int_value
-            return getattr(torch.ops.aiter, op_name)(dummy, *args, **kwargs)
-
-        if hasattr(torch.ops.aiter, func.__name__):
-            return outer_wrapper
-        if hasattr(torch.library, "infer_schema"):
-            schema_str = torch.library.infer_schema(func, mutates_args=mutates_args)
-        else:
-            # for pytorch 2.4
-            import torch._custom_op.impl
+            import torch
+            import torch.library
+            from torch.library import Library
-            schema_str = torch._custom_op.impl.infer_schema(
-                func, mutates_args=mutates_args
-            )
+            global aiter_lib
+            aiter_lib = Library("aiter", "FRAGMENT") if aiter_lib is None else aiter_lib
+            schema = ""
+            if calling_func.__name__ in MANUAL_SCHEMA_OPS:
+                schema = generate_schema(calling_func)
+            else:
+                sig = inspect.signature(calling_func)
+                mutates_args = SPECIAL_OPS_MUTATES_ARGS.get(
+                    calling_func.__name__, "unknown"
+                )
+                if hasattr(torch.library, "infer_schema"):
+                    sig = torch.library.infer_schema(
+                        calling_func, mutates_args=mutates_args
+ ) + else: + # for pytorch 2.4 + import torch._custom_op.impl - input_part, output_part = schema_str.split("->", 1) - if not sig.parameters: - new_input = "(Tensor dummy)" - else: - new_input = "(Tensor dummy, " + input_part[1:] + # torch 2.4 not support mutates "unknown" for inplace all param + if mutates_args == "unknown": + mutates_args = [] + + for param_name, param in sig.parameters.items(): + if param.annotation == torch.Tensor: + mutates_args.append(param_name) - output_part = output_part.strip() - if not return_non_tensor: - new_output = output_part + sig = torch._custom_op.impl.infer_schema(calling_func, mutates_args) + schema = f"{sig}" + return schema + + schema = wrapper_register(calling_func) + + sig = inspect.signature(calling_func) + input_is_tensor = False + parameters = list(sig.parameters.values()) + + if parameters: + first_param = parameters[0] + if ( + first_param.annotation is not inspect.Parameter.empty + and first_param.annotation is torch.Tensor + ): + input_is_tensor = True + + input_part, output_part = schema.split("->", 1) + if input_is_tensor: + new_input = input_part else: - # return only int will cause graph breaks and we add dummy_out - new_output = "(Tensor, " + output_part + ")" - schema_str = f"{new_input} -> {new_output}".strip() - - def custom_impl(dummy_tensor, *args, **kwargs): - out = torch.empty(1, device=device) - if not return_non_tensor: - return func(*args, **kwargs) - return out, func(*args, **kwargs) - - def fake_impl(dummy_tensor, *args, **kwargs): - out = torch.empty(1, device=device) - if not return_non_tensor: - if gen_fake is not None: - return gen_fake(*args, **kwargs) - return func(*args, **kwargs) + if not sig.parameters: + new_input = "(Tensor dummy)" + else: + new_input = "(Tensor dummy, " + input_part[1:] + + return_int = False + return_annotation = sig.return_annotation + if return_annotation is int: + output_part = "(Tensor, " + output_part + ")" + return_int = True + schema = f"{new_input} -> 
{output_part}".strip() + + loadName = calling_func.__name__ + + def abstract_impl(*args, custom_build_args={}, **kwargs): + if return_int: + return torch.empty(1, device=device), 1 if gen_fake is not None: - return out, gen_fake(*args, **kwargs) - return out, func(*args, **kwargs) + return gen_fake(*args, **kwargs) + return calling_func(*args, **kwargs) - if is_torch_equal_or_newer("2.8.0"): - tags = () - else: - tags = (torch.Tag.needs_fixed_stride_order,) + def outer_wrapper(*args, **kwargs): + return ( + wrapper(*args, **kwargs) + if not return_int + else (torch.empty(1, device=device), wrapper(*args, **kwargs)) + ) + + def abstract_impl_dummy(dummy, *args, custom_build_args={}, **kwargs): + if return_int: + return torch.empty(1, device=device), 1 + if gen_fake is not None: + return gen_fake(*args, **kwargs) + return calling_func(*args, **kwargs) + + def outer_wrapper_dummy(dummy, *args, **kwargs): + return ( + wrapper(*args, **kwargs) + if not return_int + else (torch.empty(1, device=device), wrapper(*args, **kwargs)) + ) - my_lib = aiter_lib - my_lib.define(op_name + schema_str, tags=tags) - my_lib.impl(op_name, custom_impl, dispatch_key="CUDA") - my_lib.impl(op_name, custom_impl, dispatch_key="CPU") - my_lib._register_fake(op_name, fake_impl) + custom_func = outer_wrapper + fake_func = abstract_impl + if not input_is_tensor: + custom_func = outer_wrapper_dummy + fake_func = abstract_impl_dummy + + if not hasattr(torch.ops.aiter, calling_func.__name__): + if is_torch_equal_or_newer("2.8.0"): + tags = () + else: + tags = (torch.Tag.needs_fixed_stride_order,) + op_schema = f"aiter::{loadName}" + schema + aiter_lib.define(op_schema, tags=tags) + aiter_lib.impl(f"aiter::{loadName}", custom_func, dispatch_key="CUDA") + aiter_lib.impl(f"aiter::{loadName}", custom_func, dispatch_key="CPU") + aiter_lib._register_fake(f"{loadName}", fake_func) + + def wrapper_custom(*args, custom_build_args={}, **kwargs): + result = ( + getattr(torch.ops.aiter, f"{loadName}")(*args, 
**kwargs) + if input_is_tensor + else getattr(torch.ops.aiter, f"{loadName}")( + torch.empty(1, device=device), *args, **kwargs + ) + ) + return result[1] if return_int else result - return outer_wrapper + return wrapper_custom return decorator