51 changes: 50 additions & 1 deletion aiter/jit/core.py
@@ -1543,7 +1543,56 @@ def check_args():
     func_hints = typing.get_type_hints(func)
     if ann["return"] is None:
         func_hints["return"] = None
-    if ann != func_hints:
+
+    tensor_like_types = {torch.Tensor}
+    if aiter_tensor_t is not object:
+        tensor_like_types.add(aiter_tensor_t)
+
+    def canonicalize_hint(hint):
+        if hint in tensor_like_types:
+            return ("tensor",)
+
+        origin = typing.get_origin(hint)
+        if origin in (list, List):
+            return (
+                "list",
+                tuple(
+                    canonicalize_hint(arg)
+                    for arg in typing.get_args(hint)
+                ),
+            )
+        if origin is tuple:
+            return (
+                "tuple",
+                tuple(
+                    canonicalize_hint(arg)
+                    for arg in typing.get_args(hint)
+                ),
+            )
+        if origin in (typing.Union, types.UnionType):
+            return (
+                "union",
+                tuple(
+                    sorted(
+                        (
+                            canonicalize_hint(arg)
+                            for arg in typing.get_args(hint)
+                        ),
+                        key=repr,
+                    )
+                ),
+            )
+        return hint
+
+    canonical_ann = {
+        key: canonicalize_hint(value) for key, value in ann.items()
+    }
+    canonical_func_hints = {
+        key: canonicalize_hint(value)
+        for key, value in func_hints.items()
+    }
+
+    if canonical_ann != canonical_func_hints:
         logger.warning(
             f"type hints mismatch, override to --> {doc_str}"
         )
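Before this change, the check compared ann and func_hints with plain dict equality, so an annotation written with one tensor class and a hint that resolves to another (torch.Tensor versus whatever class aiter_tensor_t is bound to) was reported as a mismatch even though the signatures are logically the same. The new canonicalize_hint collapses every tensor-like class to a single token and recurses into list, tuple, and union hints, sorting union members so their order cannot matter. The following is a minimal, self-contained sketch of that idea, not the aiter implementation; FakeTensorA and FakeTensorB are hypothetical stand-ins for two interchangeable tensor classes, and the helper is simplified relative to the version in the diff.

# Standalone sketch of the canonicalization idea (assumptions: no torch needed,
# FakeTensorA / FakeTensorB are invented stand-ins for tensor-like classes).
import typing
from typing import List, Optional


class FakeTensorA: ...


class FakeTensorB: ...


TENSOR_LIKE = {FakeTensorA, FakeTensorB}


def canonicalize(hint):
    # Collapse every tensor-like class onto one token so they compare equal.
    if hint in TENSOR_LIKE:
        return ("tensor",)
    origin = typing.get_origin(hint)
    if origin is not None:
        # Recurse into container/union hints and canonicalize their arguments;
        # sort union members so the canonical tuples are order-insensitive.
        args = tuple(canonicalize(a) for a in typing.get_args(hint))
        if origin is typing.Union:
            args = tuple(sorted(args, key=repr))
        return (repr(origin), args)
    return hint


declared = {"inp": FakeTensorA, "out": Optional[FakeTensorA], "xs": List[FakeTensorA]}
resolved = {"inp": FakeTensorB, "out": Optional[FakeTensorB], "xs": List[FakeTensorB]}

print(declared == resolved)  # False: the raw dicts use different tensor classes
print(
    {k: canonicalize(v) for k, v in declared.items()}
    == {k: canonicalize(v) for k, v in resolved.items()}
)  # True once tensor-like classes are collapsed to the same token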
40 changes: 21 additions & 19 deletions aiter/ops/custom_all_reduce.py
@@ -3,6 +3,8 @@

 from typing import List
 
+import torch
+
 from ..jit.core import compile_ops
 
 MD_NAME = "module_custom_all_reduce"
@@ -23,8 +25,8 @@ def init_custom_ar(
 @compile_ops("module_custom_all_reduce", develop=True)
 def all_reduce(
     _fa: int,
-    inp,
-    out,
+    inp: torch.Tensor,
+    out: torch.Tensor,
     use_new: bool,
     open_fp8_quant: bool,
     reg_inp_ptr: int,
@@ -35,8 +37,8 @@ def all_reduce(
 @compile_ops("module_custom_all_reduce", develop=True)
 def reduce_scatter(
     _fa: int,
-    inp,
-    out,
+    inp: torch.Tensor,
+    out: torch.Tensor,
     reg_ptr: int,
     reg_bytes: int,
 ) -> None: ...
@@ -45,18 +47,18 @@ def reduce_scatter(
 @compile_ops("module_custom_all_reduce", develop=True)
 def all_gather_reg(
     _fa: int,
-    inp,
-    out,
+    inp: torch.Tensor,
+    out: torch.Tensor,
     dim: int,
 ) -> None: ...
 
 
 @compile_ops("module_custom_all_reduce", develop=True)
 def all_gather_unreg(
     _fa: int,
-    inp,
+    inp: torch.Tensor,
     reg_buffer: int,
-    out,
+    out: torch.Tensor,
     reg_bytes: int,
     dim: int,
 ) -> None: ...
@@ -65,11 +67,11 @@ def all_gather_unreg(
 @compile_ops("module_custom_all_reduce", develop=True)
 def fused_allreduce_rmsnorm(
     _fa: int,
-    inp,
-    res_inp,
-    res_out,
-    out,
-    w,
+    inp: torch.Tensor,
+    res_inp: torch.Tensor,
+    res_out: torch.Tensor,
+    out: torch.Tensor,
+    w: torch.Tensor,
     eps: float,
     reg_ptr: int,
     reg_bytes: int,
@@ -80,12 +82,12 @@ def fused_allreduce_rmsnorm(
 @compile_ops("module_custom_all_reduce", develop=True)
 def fused_allreduce_rmsnorm_quant(
     _fa: int,
-    inp,
-    res_inp,
-    res_out,
-    out,
-    scale_out,
-    w,
+    inp: torch.Tensor,
+    res_inp: torch.Tensor,
+    res_out: torch.Tensor,
+    out: torch.Tensor,
+    scale_out: torch.Tensor,
+    w: torch.Tensor,
     eps: float,
     reg_ptr: int,
     reg_bytes: int,
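With the stubs in aiter/ops/custom_all_reduce.py annotated, typing.get_type_hints on each op returns concrete classes that the type-hint check in aiter/jit/core.py can compare against the declared annotations. A small, hedged illustration of what that lookup yields follows; it mirrors the annotated reduce_scatter signature from the diff but leaves out the compile_ops decorator and any compiled backend, and assumes only that torch is importable.

# Illustration only: a bare stub with the same annotations as the diff above,
# without the compile_ops decorator or the compiled implementation.
import typing

import torch


def reduce_scatter(
    _fa: int,
    inp: torch.Tensor,
    out: torch.Tensor,
    reg_ptr: int,
    reg_bytes: int,
) -> None: ...


hints = typing.get_type_hints(reduce_scatter)
print(hints["inp"] is torch.Tensor)  # True: the annotation resolves to a real class
print(hints["return"])  # <class 'NoneType'>: get_type_hints maps a None return
                        # annotation to type(None), consistent with core.py setting
                        # func_hints["return"] = None before comparing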