Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions aiter/dist/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ def all_reduce_fake(
return torch.empty_like(tensor)


# There is same name all_reduce in aiter.op, use Alias
@torch_compile_guard(gen_fake=all_reduce_fake)
def all_reduce(
def all_reduce_(
tensor: torch.Tensor, group_name: str, ca_fp8_quant: bool
) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
Expand Down Expand Up @@ -317,7 +318,7 @@ def all_reduce(
if self.world_size == 1:
return input_

return all_reduce(
return all_reduce_(
input_, group_name=self.unique_name, ca_fp8_quant=ca_fp8_quant
)

Expand Down
294 changes: 6 additions & 288 deletions aiter/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import traceback
import types
import typing
from typing import Any, Callable, List, Optional, Union, get_args, get_origin
from typing import Any, Callable, List, Optional

from packaging.version import Version, parse

Expand All @@ -23,7 +23,7 @@
from chip_info import get_gfx
from cpp_extension import _jit_compile, get_hip_version
from file_baton import FileBaton
from torch_guard import is_torch_equal_or_newer, torch_compile_guard # noqa: E402
from torch_guard import torch_compile_guard # noqa: E402

AITER_REBUILD = int(os.environ.get("AITER_REBUILD", "0"))

Expand Down Expand Up @@ -664,160 +664,6 @@ def convert(d_ops: dict):
)


# Ops whose torch.library schema cannot be inferred automatically from the
# Python signature; for names listed here, compile_ops calls
# generate_schema() to build the schema string by hand instead of using
# torch.library.infer_schema.
MANUAL_SCHEMA_OPS = [
    "register_graph_buffers",
    "module_moe_ck2stages",
    "mha_fwd",
    "fmha_v3_fwd",
    "mha_varlen_fwd",
    "mha_bwd",
    "fmha_v3_bwd",
    "mha_varlen_bwd",
    "fmha_v3_varlen_bwd",
    "fmha_v3_varlen_fwd",
    "mha_batch_prefill",
    "hipb_findallsols",
    "rocb_findallsols",
    "_ActivationType",
    "_QuantType",
    "init_custom_ar",
    "greedy_sample",
    "random_sample_outer_exponential",
    "random_sample",
    "mixed_sample",
    "exponential",
]

# Ops that are returned as the plain Python wrapper and NOT registered as
# torch.ops.aiter custom ops (compile_ops returns `wrapper` directly when
# func.__name__ appears here). Commented-out entries were previously exempt
# and are kept for reference.
NONE_WRAPPED_OP = [
    # "hipb_create_extension",
    # "hipb_destroy_extension",
    "getHipblasltKernelName",
    # "rocb_create_extension",
    # "rocb_destroy_extension",
    "get_graph_buffer_ipc_meta",
    "_ActivationType",
    "_QuantType",
    # "dispose",
    # "meta_size",
    # "get_padded_m",
    "compile_mha_fwd",
    "compile_mha_bwd",
    "init_custom_qr",
    "qr_max_size",
    "qr_destroy",
    "qr_open_handles",
    "qr_get_handle",
]

# By default every tensor argument is treated as mutated (in-place).  To
# narrow that assumption for a specific op, map its name here to the list of
# argument names that are actually mutated.
SPECIAL_OPS_MUTATES_ARGS = {}


def generate_schema(func) -> str:
    """Build a torch.library function-schema string from *func*'s signature.

    Python annotations are mapped to schema types (``torch.Tensor`` ->
    ``Tensor``, ``int`` -> ``SymInt``, ``List[torch.Tensor]`` ->
    ``Tensor[]``, ...).  Tensor-typed arguments are marked mutated
    (``Tensor(a{idx}!)``) unless SPECIAL_OPS_MUTATES_ARGS narrows the
    mutated set for this op.

    Returns:
        The schema string, e.g. ``"(Tensor(a0!) x, SymInt n) -> Tensor"``.
    """
    import inspect

    import torch

    sig = inspect.signature(func)
    parameters = []
    mutates_args = SPECIAL_OPS_MUTATES_ARGS.get(func.__name__, [])
    for idx, (name, param) in enumerate(sig.parameters.items()):
        param_type = param.annotation
        supported = True
        # Absent an explicit entry, every argument counts as mutated.
        is_mutates = not (len(mutates_args) > 0 and name not in mutates_args)

        if param_type is torch.Tensor:
            type_str = f"Tensor(a{idx}!)" if is_mutates else "Tensor"
        elif param_type == Optional[torch.Tensor]:
            type_str = f"Tensor(a{idx}!)?" if is_mutates else "Tensor?"
        elif get_origin(param_type) is Union and torch.Tensor in get_args(param_type):
            # Any Union containing Tensor is treated as an optional tensor.
            type_str = f"Tensor(a{idx}!)?" if is_mutates else "Tensor?"
        elif param_type in (torch.SymInt, int):
            type_str = "SymInt"
        elif param_type in (float, bool, str):
            type_str = param_type.__name__
        elif param_type == Optional[torch.Generator]:
            type_str = "Generator?"
        elif (
            get_origin(param_type) in (list, List)
            and get_args(param_type)[0] is torch.Tensor
        ):
            type_str = f"Tensor(a{idx}!)[]" if is_mutates else "Tensor[]"
        elif get_origin(param_type) in (list, List) and get_args(param_type)[0] is int:
            type_str = "int[]"
        elif param_type == Optional[torch.dtype]:
            type_str = "ScalarType?"
        else:
            # Unsupported annotation: emit a bare "*" with no argument name.
            # NOTE(review): this produces "* " in the schema — presumably it
            # relies on unsupported params never preceding supported ones;
            # confirm against the ops that actually hit this branch.
            type_str = "*"
            supported = False

        if supported:
            param_str = f"{type_str} {name}"
            # `Parameter.empty` is a sentinel object: compare by identity.
            if param.default is not inspect.Parameter.empty:
                if param.default is None:
                    param_str += "=None"
                elif isinstance(param.default, str):
                    # String defaults must be quoted in the schema grammar
                    # (e.g. `str mode="fast"`); the previous unquoted form
                    # is rejected by the schema parser.
                    param_str += f'="{param.default}"'
                else:
                    param_str += f"={param.default}"
        else:
            param_str = f"{type_str} "

        parameters.append(param_str)

    return_annotation = sig.return_annotation
    if return_annotation is type(None) or return_annotation is None:
        return_type = "()"
    elif return_annotation is torch.Tensor:
        return_type = "Tensor"
    elif (
        # Accept both builtin `list` and typing.List origins, matching the
        # parameter handling above.
        get_origin(return_annotation) in (list, List)
        and get_args(return_annotation)[0] is int
    ):
        return_type = "int[]"
    elif return_annotation is int:
        return_type = "int"
    elif return_annotation is float:
        return_type = "float"
    elif return_annotation is bool:
        return_type = "bool"
    elif (
        get_origin(return_annotation) in (list, List)
        and get_args(return_annotation)[0] is torch.Tensor
    ):
        return_type = "Tensor[]"
    elif get_origin(return_annotation) is tuple:
        # Tuple returns: map each supported element type; unsupported
        # element types are silently dropped (pre-existing behavior).
        type_strings = []
        for arg in get_args(return_annotation):
            if arg is torch.Tensor:
                type_strings.append("Tensor")
            elif arg is int:
                type_strings.append("int")
            elif arg is float:
                type_strings.append("float")
            elif arg is bool:
                type_strings.append("bool")
        return_type = f"({', '.join(type_strings)})"
    else:
        return_type = "Any"

    return f"({', '.join(parameters)}) -> {return_type}"


def compile_ops(
_md_name: str,
fc_name: Optional[str] = None,
Expand Down Expand Up @@ -986,138 +832,10 @@ def check_args():

return op(*args, **kwargs)

if func.__name__ in NONE_WRAPPED_OP:
return wrapper

def wrapper_register(func):
import inspect

import torch
import torch.library
from torch.library import Library

global aiter_lib
aiter_lib = Library("aiter", "FRAGMENT") if aiter_lib is None else aiter_lib
schema = ""
if func.__name__ in MANUAL_SCHEMA_OPS:
schema = generate_schema(func)
else:
sig = inspect.signature(func)
mutates_args = SPECIAL_OPS_MUTATES_ARGS.get(func.__name__, "unknown")
if hasattr(torch.library, "infer_schema"):
sig = torch.library.infer_schema(func, mutates_args=mutates_args)
else:
# for pytorch 2.4
import torch._custom_op.impl

# torch 2.4 not support mutates "unknown" for inplace all param
if mutates_args == "unknown":
mutates_args = []

for param_name, param in sig.parameters.items():
if param.annotation == torch.Tensor:
mutates_args.append(param_name)

sig = torch._custom_op.impl.infer_schema(func, mutates_args)
schema = f"{sig}"
return schema

schema = wrapper_register(func)

import inspect

import torch

sig = inspect.signature(func)
input_is_tensor = False
parameters = list(sig.parameters.values())

if parameters:
first_param = parameters[0]
if (
first_param.annotation is not inspect.Parameter.empty
and first_param.annotation is torch.Tensor
):
input_is_tensor = True

input_part, output_part = schema.split("->", 1)
if input_is_tensor:
new_input = input_part
else:
if not sig.parameters:
new_input = "(Tensor dummy)"
else:
new_input = "(Tensor dummy, " + input_part[1:]

return_int = False
return_annotation = sig.return_annotation
if return_annotation is int:
output_part = "(Tensor, " + output_part + ")"
return_int = True

schema = f"{new_input} -> {output_part}".strip()

loadName = func.__name__

def abstract_impl(*args, custom_build_args={}, **kwargs):
if return_int:
return torch.empty(1, device="cuda"), 1
if gen_fake is not None:
return gen_fake(*args, **kwargs)
return func(*args, **kwargs)

def outer_wrapper(*args, **kwargs):
return (
wrapper(*args, **kwargs)
if not return_int
else (torch.empty(1, device="cuda"), wrapper(*args, **kwargs))
)

def abstract_impl_dummy(dummy, *args, custom_build_args={}, **kwargs):
if return_int:
return torch.empty(1, device="cuda"), 1
if gen_fake is not None:
return gen_fake(*args, **kwargs)
return func(*args, **kwargs)

def outer_wrapper_dummy(dummy, *args, **kwargs):
return (
wrapper(*args, **kwargs)
if not return_int
else (torch.empty(1, device="cuda"), wrapper(*args, **kwargs))
)

custom_func = outer_wrapper
fake_func = abstract_impl
if not input_is_tensor:
custom_func = outer_wrapper_dummy
fake_func = abstract_impl_dummy

if not hasattr(torch.ops.aiter, f"wrapper_{loadName}"):
if is_torch_equal_or_newer("2.8.0"):
tags = ()
else:
tags = (torch.Tag.needs_fixed_stride_order,)
op_schema = f"aiter::wrapper_{loadName}" + schema
aiter_lib.define(op_schema, tags=tags)
aiter_lib.impl(
f"aiter::wrapper_{loadName}", custom_func, dispatch_key="CUDA"
)
aiter_lib.impl(
f"aiter::wrapper_{loadName}", custom_func, dispatch_key="CPU"
)
aiter_lib._register_fake(f"wrapper_{loadName}", fake_func)

def wrapper_custom(*args, custom_build_args={}, **kwargs):
result = (
getattr(torch.ops.aiter, f"wrapper_{loadName}")(*args, **kwargs)
if input_is_tensor
else getattr(torch.ops.aiter, f"wrapper_{loadName}")(
torch.empty(1, device="cuda"), *args, **kwargs
)
)
return result[1] if return_int else result
@torch_compile_guard(device="cuda", gen_fake=gen_fake, calling_func_=func)
def custom_wrapper(*args, **kwargs):
return wrapper(*args, **kwargs)

return wrapper_custom
return custom_wrapper

return decorator
Loading