diff --git a/aiter/jit/utils/torch_guard.py b/aiter/jit/utils/torch_guard.py
index cb969c9bce..21eef959ea 100644
--- a/aiter/jit/utils/torch_guard.py
+++ b/aiter/jit/utils/torch_guard.py
@@ -78,23 +78,19 @@ def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
     "qr_get_handle",
 ]
 
-# We default all args are inplace, you can define inplace args for specific op
-SPECIAL_OPS_MUTATES_ARGS = {}
-
 
-def generate_schema(func) -> str:
+def generate_schema(func, mutates_args: Union[list[str], str] = "unknown") -> str:
     import inspect
 
     import torch
 
     sig = inspect.signature(func)
     parameters = []
-    mutates_args = SPECIAL_OPS_MUTATES_ARGS.get(func.__name__, [])
     for idx, (name, param) in enumerate(sig.parameters.items()):
         param_type = param.annotation
         flag = True
         is_mutates = True
-        if len(mutates_args) > 0 and name not in mutates_args:
+        if mutates_args != "unknown" and name not in mutates_args:
             is_mutates = False
 
         if param_type is torch.Tensor:
@@ -188,7 +184,7 @@ def generate_schema(func) -> str:
 
 
 def torch_compile_guard(
-    mutates_args: list[str] = [],
+    mutates_args: Union[list[str], str] = "unknown",
     device: str = "cpu",
     calling_func_: Optional[Callable[..., Any]] = None,
     gen_fake: Optional[Callable[..., Any]] = None,
@@ -224,11 +220,8 @@ def wrapper_register(calling_func):
             schema = generate_schema(calling_func)
         else:
             sig = inspect.signature(calling_func)
-            mutates_args = SPECIAL_OPS_MUTATES_ARGS.get(
-                calling_func.__name__, "unknown"
-            )
             if hasattr(torch.library, "infer_schema"):
-                sig = torch.library.infer_schema(
+                schema = torch.library.infer_schema(
                     calling_func, mutates_args=mutates_args
                 )
             else:
@@ -237,14 +230,15 @@
 
 
             # torch 2.4 not support mutates "unknown" for inplace all param
             if mutates_args == "unknown":
-                mutates_args = []
+                mutates_args_custom = []
                 for param_name, param in sig.parameters.items():
                     if param.annotation == torch.Tensor:
-                        mutates_args.append(param_name)
+                        mutates_args_custom.append(param_name)
 
-            sig = torch._custom_op.impl.infer_schema(calling_func, mutates_args)
-            schema = f"{sig}"
+            schema = torch._custom_op.impl.infer_schema(
+                calling_func, mutates_args if mutates_args != "unknown" else mutates_args_custom
+            )
         return schema
 
     schema = wrapper_register(calling_func)
@@ -280,11 +274,27 @@ def wrapper_register(calling_func):
 
     loadName = calling_func.__name__
 
-    def abstract_impl(*args, custom_build_args={}, **kwargs):
-        if return_non_tensor:
-            return torch.empty(1, device=device), 1
+    def wrapper_custom(*args, **kwargs):
+        result = (
+            getattr(torch.ops.aiter, f"{loadName}")(*args, **kwargs)
+            if input_is_tensor
+            else getattr(torch.ops.aiter, f"{loadName}")(
+                torch.empty(1, device=device), *args, **kwargs
+            )
+        )
+        return result[1] if return_non_tensor else result
+
+    if hasattr(torch.ops.aiter, loadName):
+        return wrapper_custom
+
+    def abstract_impl(*args, **kwargs):
         if gen_fake is not None:
-            return gen_fake(*args, **kwargs)
+            if return_non_tensor:
+                return torch.empty(1, device=device), gen_fake(*args, **kwargs)
+            else:
+                return gen_fake(*args, **kwargs)
+        if return_non_tensor:
+            return torch.empty(1, device=device), calling_func(*args, **kwargs)
         return calling_func(*args, **kwargs)
 
     def outer_wrapper(*args, **kwargs):
@@ -294,11 +304,14 @@ def outer_wrapper(*args, **kwargs):
             else (torch.empty(1, device=device), wrapper(*args, **kwargs))
         )
 
-    def abstract_impl_dummy(dummy, *args, custom_build_args={}, **kwargs):
-        if return_non_tensor:
-            return torch.empty(1, device=device), 1
+    def abstract_impl_dummy(dummy, *args, **kwargs):
         if gen_fake is not None:
-            return gen_fake(*args, **kwargs)
+            if return_non_tensor:
+                return torch.empty(1, device=device), gen_fake(*args, **kwargs)
+            else:
+                return gen_fake(*args, **kwargs)
+        if return_non_tensor:
+            return torch.empty(1, device=device), calling_func(*args, **kwargs)
         return calling_func(*args, **kwargs)
 
     def outer_wrapper_dummy(dummy, *args, **kwargs):
@@ -325,16 +338,6 @@ def outer_wrapper_dummy(dummy, *args, **kwargs):
             aiter_lib.impl(f"aiter::{loadName}", custom_func, dispatch_key="CPU")
             aiter_lib._register_fake(f"{loadName}", fake_func)
 
-    def wrapper_custom(*args, custom_build_args={}, **kwargs):
-        result = (
-            getattr(torch.ops.aiter, f"{loadName}")(*args, **kwargs)
-            if input_is_tensor
-            else getattr(torch.ops.aiter, f"{loadName}")(
-                torch.empty(1, device=device), *args, **kwargs
-            )
-        )
-        return result[1] if return_non_tensor else result
-
     return wrapper_custom
 
     return decorator
diff --git a/aiter/ops/gemm_op_a4w4.py b/aiter/ops/gemm_op_a4w4.py
index 21832861cf..8ec03cea1e 100644
--- a/aiter/ops/gemm_op_a4w4.py
+++ b/aiter/ops/gemm_op_a4w4.py
@@ -2,9 +2,9 @@
 # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 import functools
-import os
 from typing import Optional
 
+from aiter.jit.utils.torch_guard import torch_compile_guard
 import pandas as pd
 import torch
 from torch import Tensor
@@ -14,7 +14,6 @@
 from ..jit.core import (
     AITER_CONFIG_GEMM_A4W4_FILE,
     AITER_LOG_TUNED_CONFIG,
-    AITER_ROOT_DIR,
     compile_ops,
 )
 from ..jit.utils.chip_info import get_cu_num, get_gfx
@@ -60,6 +59,21 @@ def get_GEMM_config(M: int, N: int, K: int):
     return config
 
 
+def gemm_a4w4_fake(
+    A: Tensor,  # A:[M, K/2] f4x2
+    B: Tensor,  # B:[N, K/2] f4x2
+    A_scale: Tensor,  # A_scale:[M, K/32] e8m0 paded
+    B_scale: Tensor,  # B_scale:[N, K/32] e8m0 paded
+    out: Tensor,  # Out:[M, N] bf16
+    bias: Optional[Tensor] = None,  # bias:[1, N] f32
+    alpha: Optional[float] = 1.0,
+    beta: Optional[float] = 0.0,
+    bpreshuffle: Optional[bool] = True,
+) -> torch.Tensor:
+    return out
+
+
+@torch_compile_guard(gen_fake=gemm_a4w4_fake)
 def gemm_a4w4(
     A: Tensor,  # A:[M, K/2] f4x2
     B: Tensor,  # B:[N, K/2] f4x2
diff --git a/aiter/ops/gemm_op_a8w8.py b/aiter/ops/gemm_op_a8w8.py
index 16dc8faca7..db81c45f38 100644
--- a/aiter/ops/gemm_op_a8w8.py
+++ b/aiter/ops/gemm_op_a8w8.py
@@ -207,8 +207,8 @@ def compute_gemm_SplitK(M: int, N: int, K: int, tile_m: int, tile_n: int, tile_k
 _CKGEMM_CONFIG_CACHE = None
 
 
-@torch_compile_guard()
-def get_CKGEMM_config_(tuned_file: str = None) -> None:
+@functools.lru_cache(maxsize=1024)
+def get_CKGEMM_config(M: int, N: int, K: int, tuned_file="a8w8_tuned_gemm.csv"):
     if tuned_file is None:
         tuned_file = "a8w8_tuned_gemm.csv"
     global _CKGEMM_CONFIG_CACHE
@@ -221,13 +221,6 @@ def get_CKGEMM_config_(tuned_file: str = None) -> None:
         ["cu_num", "M", "N", "K"]
     ).to_dict("index")
 
-    return None
-
-
-@functools.lru_cache(maxsize=1024)
-def get_CKGEMM_config(M: int, N: int, K: int, tuned_file="a8w8_tuned_gemm.csv"):
-    get_CKGEMM_config_(tuned_file)
-
     cu_num = get_cu_num()
 
     padded_M = M
@@ -277,15 +270,28 @@ def get_bpreshuffle_GEMM_config(
     return config
 
 
+def gemm_a8w8_fake(
+    XQ: Tensor,
+    WQ: Tensor,
+    x_scale: Tensor,
+    w_scale: Tensor,
+    bias: Optional[Tensor] = None,
+    dtype: torch.dtype = dtypes.bf16,
+    splitK: Optional[int] = None,
+) -> Tensor:
+    return torch.empty(XQ.shape[0], WQ.shape[0], dtype=dtype, device=XQ.device)
+
+
+@torch_compile_guard(gen_fake=gemm_a8w8_fake)
 def gemm_a8w8(
     XQ: Tensor,
     WQ: Tensor,
     x_scale: Tensor,
     w_scale: Tensor,
     bias: Optional[Tensor] = None,
-    dtype=dtypes.bf16,
+    dtype: torch.dtype = dtypes.bf16,
     splitK: Optional[int] = None,
-):
+) -> Tensor:
     # assert dtype in [
     #     dtypes.bf16,
     #     dtypes.fp16,
@@ -350,9 +356,9 @@ def gemm_a8w8_CK(
     x_scale: Tensor,
     w_scale: Tensor,
     bias: Optional[Tensor] = None,
-    dtype=dtypes.bf16,
+    dtype: torch.dtype = dtypes.bf16,
     splitK: Optional[int] = None,
-):
+) -> Tensor:
     # assert dtype in [
     #     dtypes.bf16,
     #     dtypes.fp16,
@@ -370,15 +376,28 @@ def gemm_a8w8_CK(
     return gemm_a8w8_ck(XQ, WQ, x_scale, w_scale, Y, bias, splitK)
 
 
+def gemm_a8w8_bpreshuffle_fake(
+    XQ: Tensor,
+    WQ: Tensor,
+    x_scale: Tensor,
+    w_scale: Tensor,
+    bias: Optional[Tensor] = None,
+    dtype: torch.dtype = dtypes.bf16,
+    check: bool = False,
+) -> Tensor:
+    return torch.empty(XQ.shape[0], WQ.shape[0], dtype=dtype, device=XQ.device)
+
+
+@torch_compile_guard(gen_fake=gemm_a8w8_bpreshuffle_fake)
 def gemm_a8w8_bpreshuffle(
     XQ: Tensor,
     WQ: Tensor,
     x_scale: Tensor,
     w_scale: Tensor,
     bias: Optional[Tensor] = None,
-    dtype=torch.float16,
-    check=False,
-):
+    dtype: torch.dtype = dtypes.bf16,
+    check: bool = False,
+) -> Tensor:
     assert dtype in [
         torch.bfloat16,
         torch.float16,
@@ -410,7 +429,7 @@ def gemm_a8w8_blockscale_fake(
     WQ: Tensor,
     x_scale: Tensor,
     w_scale: Tensor,
-    dtype=dtypes.bf16,
+    dtype: torch.dtype = dtypes.bf16,
     isBpreshuffled=False,
 ) -> torch.Tensor:
     m = XQ.shape[0]
@@ -465,9 +484,24 @@ def flatmm_a8w8_blockscale_ASM(
     return flatmm_a8w8_blockscale_asm(XQ, WQ, x_scale, w_scale, Y)
 
 
+def gemm_a8w8_blockscale_bpreshuffle_fake(
+    XQ: Tensor,
+    WQ: Tensor,
+    x_scale: Tensor,
+    w_scale: Tensor,
+    dtype: torch.dtype = dtypes.bf16,
+) -> Tensor:
+    return torch.empty(XQ.shape[0], WQ.shape[0], dtype=dtype, device=XQ.device)
+
+
+@torch_compile_guard(gen_fake=gemm_a8w8_blockscale_bpreshuffle_fake)
 def gemm_a8w8_blockscale_bpreshuffle(
-    XQ: Tensor, WQ: Tensor, x_scale: Tensor, w_scale: Tensor, dtype=dtypes.bf16
-):
+    XQ: Tensor,
+    WQ: Tensor,
+    x_scale: Tensor,
+    w_scale: Tensor,
+    dtype: torch.dtype = dtypes.bf16,
+) -> Tensor:
     assert dtype in [
         dtypes.bf16,
         dtypes.fp16,
diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py
index ebf7f6f28e..4f3943d725 100644
--- a/aiter/ops/mha.py
+++ b/aiter/ops/mha.py
@@ -1257,7 +1257,7 @@ def _validate_cu(name: str, x: Optional[torch.Tensor]):
     return out, softmax_lse, S_dmask, rng_state
 
 
-@torch_compile_guard()
+# @torch_compile_guard(mutates_args=[])
 def can_impl_fmha_v3_bwd(
     dout: torch.Tensor,
     q: torch.Tensor,
@@ -1436,6 +1436,42 @@ def psskddv():
     return ret
 
 
+def _flash_attn_backward_fake(
+    dout: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    dq: Optional[torch.Tensor],
+    dk: Optional[torch.Tensor],
+    dv: Optional[torch.Tensor],
+    dbias: Optional[torch.Tensor],
+    dropout_p: float,
+    softmax_scale: float,
+    causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    bias: Optional[torch.Tensor],
+    alibi_slopes: Optional[torch.Tensor],
+    deterministic: bool,
+    rng_state: Optional[torch.Tensor] = None,
+    is_v3_atomic_fp32: Optional[bool] = True,
+    how_v3_bf16_cvt: Optional[int] = 1,
+) -> torch.Tensor:
+    batch_size = q.size(0)
+    seqlen_q = q.size(1)
+    num_heads = q.size(2)
+
+    softmax_d = torch.empty(
+        (batch_size, num_heads, seqlen_q),  # {batch_size, num_heads, seqlen_q}
+        dtype=torch.float32,
+        device=q.device,
+    )
+    return softmax_d
+
+
+@torch_compile_guard(gen_fake=_flash_attn_backward_fake)
 def _flash_attn_backward(
     dout: torch.Tensor,
     q: torch.Tensor,