From 2f5489783b8f4c927c436c858412a37d37488717 Mon Sep 17 00:00:00 2001
From: amd-ruitang3
Date: Mon, 13 Apr 2026 22:26:31 -0500
Subject: [PATCH] [aiter] type hints mismatch

---
 aiter/jit/core.py              | 51 +++++++++++++++++++++++++++++++++-
 aiter/ops/custom_all_reduce.py | 40 +++++++++++++-------------
 2 files changed, 71 insertions(+), 20 deletions(-)

diff --git a/aiter/jit/core.py b/aiter/jit/core.py
index 098af6ce6a..d3da4b538e 100644
--- a/aiter/jit/core.py
+++ b/aiter/jit/core.py
@@ -1543,7 +1543,56 @@ def check_args():
         func_hints = typing.get_type_hints(func)
         if ann["return"] is None:
             func_hints["return"] = None
-        if ann != func_hints:
+
+        tensor_like_types = {torch.Tensor}
+        if aiter_tensor_t is not object:
+            tensor_like_types.add(aiter_tensor_t)
+
+        def canonicalize_hint(hint):
+            if hint in tensor_like_types:
+                return ("tensor",)
+
+            origin = typing.get_origin(hint)
+            if origin in (list, List):
+                return (
+                    "list",
+                    tuple(
+                        canonicalize_hint(arg)
+                        for arg in typing.get_args(hint)
+                    ),
+                )
+            if origin is tuple:
+                return (
+                    "tuple",
+                    tuple(
+                        canonicalize_hint(arg)
+                        for arg in typing.get_args(hint)
+                    ),
+                )
+            if origin in (typing.Union, types.UnionType):
+                return (
+                    "union",
+                    tuple(
+                        sorted(
+                            (
+                                canonicalize_hint(arg)
+                                for arg in typing.get_args(hint)
+                            ),
+                            key=repr,
+                        )
+                    ),
+                )
+            return hint
+
+        canonical_ann = {
+            key: canonicalize_hint(value) for key, value in ann.items()
+        }
+        canonical_func_hints = {
+            key: canonicalize_hint(value)
+            for key, value in func_hints.items()
+        }
+
+        if canonical_ann != canonical_func_hints:
             logger.warning(
                 f"type hints mismatch, override to --> {doc_str}"
             )
diff --git a/aiter/ops/custom_all_reduce.py b/aiter/ops/custom_all_reduce.py
index a7f74475be..d4cb4ddea9 100644
--- a/aiter/ops/custom_all_reduce.py
+++ b/aiter/ops/custom_all_reduce.py
@@ -3,6 +3,8 @@
 
 from typing import List
 
+import torch
+
 from ..jit.core import compile_ops
 
 MD_NAME = "module_custom_all_reduce"
@@ -23,8 +25,8 @@ def init_custom_ar(
 @compile_ops("module_custom_all_reduce", develop=True)
 def all_reduce(
     _fa: int,
-    inp,
-    out,
+    inp: torch.Tensor,
+    out: torch.Tensor,
     use_new: bool,
     open_fp8_quant: bool,
     reg_inp_ptr: int,
@@ -35,8 +37,8 @@ def all_reduce(
 @compile_ops("module_custom_all_reduce", develop=True)
 def reduce_scatter(
     _fa: int,
-    inp,
-    out,
+    inp: torch.Tensor,
+    out: torch.Tensor,
     reg_ptr: int,
     reg_bytes: int,
 ) -> None: ...
@@ -45,8 +47,8 @@ def reduce_scatter(
 @compile_ops("module_custom_all_reduce", develop=True)
 def all_gather_reg(
     _fa: int,
-    inp,
-    out,
+    inp: torch.Tensor,
+    out: torch.Tensor,
     dim: int,
 ) -> None: ...
 
@@ -54,9 +56,9 @@ def all_gather_reg(
 @compile_ops("module_custom_all_reduce", develop=True)
 def all_gather_unreg(
     _fa: int,
-    inp,
+    inp: torch.Tensor,
     reg_buffer: int,
-    out,
+    out: torch.Tensor,
     reg_bytes: int,
     dim: int,
 ) -> None: ...
@@ -65,11 +67,11 @@ def all_gather_unreg(
 @compile_ops("module_custom_all_reduce", develop=True)
 def fused_allreduce_rmsnorm(
     _fa: int,
-    inp,
-    res_inp,
-    res_out,
-    out,
-    w,
+    inp: torch.Tensor,
+    res_inp: torch.Tensor,
+    res_out: torch.Tensor,
+    out: torch.Tensor,
+    w: torch.Tensor,
     eps: float,
     reg_ptr: int,
     reg_bytes: int,
@@ -80,12 +82,12 @@ def fused_allreduce_rmsnorm(
 @compile_ops("module_custom_all_reduce", develop=True)
 def fused_allreduce_rmsnorm_quant(
     _fa: int,
-    inp,
-    res_inp,
-    res_out,
-    out,
-    scale_out,
-    w,
+    inp: torch.Tensor,
+    res_inp: torch.Tensor,
+    res_out: torch.Tensor,
+    out: torch.Tensor,
+    scale_out: torch.Tensor,
+    w: torch.Tensor,
     eps: float,
     reg_ptr: int,
     reg_bytes: int,
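
Note on the core.py hunk: plain `ann != func_hints` flags semantically identical signatures as mismatched whenever the same type is spelled differently (torch.Tensor vs. the aiter tensor alias, or union members in a different order). The standalone sketch below re-derives `canonicalize_hint` outside aiter to show the effect; `AiterTensor` is a hypothetical stand-in for `aiter_tensor_t`, which core.py resolves from its own imports.

    import types
    import typing
    from typing import List, Union

    import torch

    class AiterTensor(torch.Tensor):  # hypothetical stand-in for aiter_tensor_t
        pass

    TENSOR_LIKE = {torch.Tensor, AiterTensor}

    def canonicalize_hint(hint):
        # Collapse every tensor-like spelling into one token.
        if hint in TENSOR_LIKE:
            return ("tensor",)
        origin = typing.get_origin(hint)
        if origin is list:
            return ("list", tuple(canonicalize_hint(a) for a in typing.get_args(hint)))
        if origin is tuple:
            return ("tuple", tuple(canonicalize_hint(a) for a in typing.get_args(hint)))
        if origin in (typing.Union, types.UnionType):
            # Sort members so Union[A, B] and Union[B, A] compare equal.
            return ("union", tuple(sorted(
                (canonicalize_hint(a) for a in typing.get_args(hint)), key=repr)))
        return hint

    # Plain equality sees a mismatch; the canonical forms agree.
    assert List[torch.Tensor] != List[AiterTensor]
    assert canonicalize_hint(List[torch.Tensor]) == canonicalize_hint(List[AiterTensor])
    assert canonicalize_hint(Union[int, torch.Tensor]) == canonicalize_hint(Union[torch.Tensor, int])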
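
Note on the custom_all_reduce.py hunks: `typing.get_type_hints` silently omits parameters that carry no annotation, so the bare `inp`/`out` stubs produce a `func_hints` dict that is missing keys present in `ann`, tripping the mismatch warning even though the arguments are tensors at runtime. A minimal illustration with reduced, hypothetical stubs (not the aiter signatures verbatim):

    import typing

    import torch

    def all_reduce_before(_fa: int, inp, out, use_new: bool) -> None: ...

    def all_reduce_after(
        _fa: int, inp: torch.Tensor, out: torch.Tensor, use_new: bool
    ) -> None: ...

    # Unannotated parameters are simply absent from the hints dict.
    print(typing.get_type_hints(all_reduce_before))
    # {'_fa': <class 'int'>, 'use_new': <class 'bool'>, 'return': <class 'NoneType'>}

    print(typing.get_type_hints(all_reduce_after))
    # {'_fa': <class 'int'>, 'inp': <class 'torch.Tensor'>,
    #  'out': <class 'torch.Tensor'>, 'use_new': <class 'bool'>,
    #  'return': <class 'NoneType'>}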