diff --git a/aiter/jit/core.py b/aiter/jit/core.py index ef9776e680..068a17d687 100644 --- a/aiter/jit/core.py +++ b/aiter/jit/core.py @@ -761,6 +761,13 @@ def check_args(): doc_str = op.__doc__.split("\n")[0] doc_str = re.sub(r"<(.*?)\:.*?>", r"\g<1>", doc_str) doc_str = doc_str.replace("list[", "List[") + doc_str = doc_str.replace("tuple[", "Tuple[") + doc_str = doc_str.replace("collections.abc.Sequence[", "List[") + doc_str = doc_str.replace("typing.SupportsInt", "int") + doc_str = doc_str.replace("typing.SupportsFloat", "float") + # A|None --> Optional[A] + pattern = r"([\w\.]+(?:\[[^\]]+\])?)\s*\|\s*None" + doc_str = re.sub(pattern, r"Optional[\1]", doc_str) for el in enum_types: doc_str = re.sub(f" aiter.*{el} ", f" {el} ", doc_str) namespace = { @@ -769,9 +776,7 @@ def check_args(): "torch": torch, "typing": typing, } - if sys.version_info < (3, 10): - pattern = r"([\w\.]+(?:\[[^\]]+\])?)\s*\|\s*None" - doc_str = re.sub(pattern, r"Optional[\1]", doc_str) + exec( f"from aiter import*\ndef {doc_str}: pass", namespace, diff --git a/aiter/jit/utils/cpp_extension.py b/aiter/jit/utils/cpp_extension.py index 5f66a477a9..21bc3383b8 100644 --- a/aiter/jit/utils/cpp_extension.py +++ b/aiter/jit/utils/cpp_extension.py @@ -1551,7 +1551,7 @@ def _write_ninja_file_to_build_library( # But we can't use this now because all aiter op based on torch # which means pybind11 related build flags must from torch now common_cflags = [] - if torch_exclude and is_python_module: + if is_python_module: import pybind11 extra_include_paths.append(pybind11.get_include()) @@ -1575,9 +1575,9 @@ def _write_ninja_file_to_build_library( common_cflags.append(f"-DTORCH_EXTENSION_NAME={name}") else: common_cflags.append(f"-DTORCH_EXTENSION_NAME=aiter_") - common_cflags.append("-DTORCH_API_INCLUDE_EXTENSION_H") - common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()] - common_cflags += [f"{x}" for x in _get_glibcxx_abi_build_flags()] + # common_cflags.append("-DTORCH_API_INCLUDE_EXTENSION_H") + # common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()] + # common_cflags += [f"{x}" for x in _get_glibcxx_abi_build_flags()] # Windows does not understand `-isystem` and quotes flags later. common_cflags += [f"-I{shlex.quote(include)}" for include in user_includes] diff --git a/aiter/ops/custom_all_reduce.py b/aiter/ops/custom_all_reduce.py index be62dae8d3..e11bbb23b2 100644 --- a/aiter/ops/custom_all_reduce.py +++ b/aiter/ops/custom_all_reduce.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -from typing import List, Optional +from typing import List, Optional, Tuple import torch @@ -184,7 +184,7 @@ def register_buffer( @compile_ops("module_custom_all_reduce") -def get_graph_buffer_ipc_meta(_fa: int) -> tuple[torch.Tensor, torch.Tensor]: ... +def get_graph_buffer_ipc_meta(_fa: int) -> Tuple[torch.Tensor, torch.Tensor]: ... @compile_ops("module_custom_all_reduce") diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py index bc401538c5..e72bef277e 100644 --- a/aiter/ops/mha.py +++ b/aiter/ops/mha.py @@ -1,13 +1,15 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -from torch import Tensor, Generator -from typing import Optional, Tuple, Any -from ..jit.core import compile_ops, CK_DIR, AITER_CSRC_DIR +from typing import Any, Optional, Tuple + +import torch +from torch import Generator, Tensor + +from ..jit.core import AITER_CSRC_DIR, CK_DIR, compile_ops from ..jit.utils.chip_info import get_gfx from ..jit.utils.torch_guard import torch_compile_guard from ..utility import dtypes -import torch def cmdGenFunc_mha_fwd( diff --git a/aiter/ops/norm.py b/aiter/ops/norm.py index 2a47f35146..73d2f40b11 100644 --- a/aiter/ops/norm.py +++ b/aiter/ops/norm.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +from typing import Optional import torch from torch import Tensor -from typing import Optional + from ..jit.core import compile_ops MD_NAME = "module_norm" @@ -43,8 +45,8 @@ def layer_norm( def layernorm2d_fwd( input: Tensor, # normalized_shape: List[int], - weight: Optional[Tensor] = None, - bias: Optional[Tensor] = None, + weight: Tensor, + bias: Tensor, epsilon: float = 1e-5, x_bias: Optional[Tensor] = None, ) -> Tensor: ... diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.h b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.h index 25eec9e0de..7c22ab857b 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.h +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.h @@ -2,21 +2,21 @@ // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" - #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" #include "aiter_enum.h" -#include "ck/utility/blkgemmpipe_scheduler.hpp" #include "py_itfs_common.h" #include +#include #include template diff --git a/csrc/include/binary_operator.cuh b/csrc/include/binary_operator.cuh index c1955f8975..0a85eee6fe 100644 --- a/csrc/include/binary_operator.cuh +++ b/csrc/include/binary_operator.cuh @@ -15,11 +15,12 @@ * limitations under the License. */ #pragma once -#include +#include "dispatch_utils.h" +#include "hip_compat.h" #include #include -#include "hip_compat.h" -#include "dispatch_utils.h" +#include +#include #include #include diff --git a/csrc/include/ck_tile/vec_convert.h b/csrc/include/ck_tile/vec_convert.h index 09b4c9edd9..e112846da7 100644 --- a/csrc/include/ck_tile/vec_convert.h +++ b/csrc/include/ck_tile/vec_convert.h @@ -76,6 +76,8 @@ CK_TILE_DEVICE fp4x2_t amd_assembly_cvt_scalef32_pk_fp4_f32(fp32_t a, fp32_t b, // permute high bits and low bits to match the order of the original vector asm volatile("v_cvt_scalef32_pk_fp4_f32 %0, %1, %2, %3" : "=v"(c) : "v"(b), "v"(a), "v"(scale)); return bit_cast(bit_cast(c[0])[0]); +#else + return fp4x2_t{}; #endif } CK_TILE_DEVICE fp4x2_t amd_assembly_cvt_scalef32_pk_fp4_f16(fp16x2_v a, fp32_t scale) @@ -85,6 +87,8 @@ CK_TILE_DEVICE fp4x2_t amd_assembly_cvt_scalef32_pk_fp4_f16(fp16x2_v a, fp32_t s // permute high bits and low bits to match the order of the original vector asm volatile("v_cvt_scalef32_pk_fp4_f16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(scale)); return bit_cast(bit_cast(c[0])[0]); +#else + return fp4x2_t{}; #endif } CK_TILE_DEVICE fp4x2_t amd_assembly_cvt_scalef32_pk_fp4_bf16(bf16x2_v a, fp32_t scale) @@ -94,6 +98,8 @@ CK_TILE_DEVICE fp4x2_t amd_assembly_cvt_scalef32_pk_fp4_bf16(bf16x2_v a, fp32_t // permute high bits and low bits to match the order of the original vector asm volatile("v_cvt_scalef32_pk_fp4_bf16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(scale)); return bit_cast(bit_cast(c[0])[0]); +#else + return fp4x2_t{}; #endif } diff --git a/csrc/include/custom_all_reduce.h b/csrc/include/custom_all_reduce.h index d8a370e514..aa2ba561bb 100644 --- a/csrc/include/custom_all_reduce.h +++ b/csrc/include/custom_all_reduce.h @@ -51,7 +51,7 @@ void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector& handles, const std::vector& offsets); -std::vector get_graph_buffer_ipc_meta(fptr_t _fa); +std::tuple get_graph_buffer_ipc_meta(fptr_t _fa); void register_graph_buffers(fptr_t _fa, const std::vector& handles, const std::vector& offsets); diff --git a/csrc/include/quant.h b/csrc/include/quant.h index bc783fa0b0..280cc3863a 100644 --- a/csrc/include/quant.h +++ b/csrc/include/quant.h @@ -2,7 +2,7 @@ // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include +#include namespace aiter { @@ -17,10 +17,10 @@ void dynamic_per_tensor_quant(torch::Tensor& out, // [..., d] void dynamic_per_token_scaled_quant(torch::Tensor& out, // [..., d] torch::Tensor const& input, // [..., d] torch::Tensor& scales, - std::optional const& scale_ub, - bool shuffle_scale = false, - std::optional const& num_rows = std::nullopt, - int num_rows_factor = 1); + std::optional scale_ub = std::nullopt, + bool shuffle_scale = false, + std::optional num_rows = std::nullopt, + int num_rows_factor = 1); void dynamic_per_group_scaled_quant_fp4(torch::Tensor& out, // [..., d] torch::Tensor const& input, // [..., d] diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index eb3f7c9449..62f14ab1c2 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -1,5 +1,9 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include +namespace py = pybind11; #define ACTIVATION_PYBIND \ m.def("silu_and_mul", \ @@ -932,12 +936,12 @@ py::arg("sorted_weights") = std::nullopt); \ m.def("moe_sum", &aiter::moe_sum, "moe_sum(Tensor! input, Tensor output) -> ()"); -#define MOE_TOPK_PYBIND \ - m.def("topk_sigmoid", \ - &aiter::topk_sigmoid, \ - py::arg("topk_weights"), \ - py::arg("topk_indices"), \ - py::arg("gating_output"), \ +#define MOE_TOPK_PYBIND \ + m.def("topk_sigmoid", \ + &aiter::topk_sigmoid, \ + py::arg("topk_weights"), \ + py::arg("topk_indices"), \ + py::arg("gating_output"), \ "Apply topk sigmoid to the gating outputs."); #define MOE_SORTING_PYBIND \ @@ -1241,11 +1245,11 @@ "hipb_findallsols", \ py::arg("mat1"), \ py::arg("mat2"), \ - py::arg("bias") = std::nullopt, \ - py::arg("out_dtype") = std::nullopt, \ - py::arg("scaleA") = std::nullopt, \ - py::arg("scaleB") = std::nullopt, \ - py::arg("scaleC") = std::nullopt, \ + py::arg("bias") = std::nullopt, \ + py::arg("out_dtype") = std::nullopt, \ + py::arg("scaleA") = std::nullopt, \ + py::arg("scaleB") = std::nullopt, \ + py::arg("scaleC") = std::nullopt, \ py::arg("bpreshuffle") = false); \ m.def("getHipblasltKernelName", &getHipblasltKernelName); diff --git a/csrc/include/sample.h b/csrc/include/sample.h index 68de91fc53..74ed2d5816 100644 --- a/csrc/include/sample.h +++ b/csrc/include/sample.h @@ -2,7 +2,7 @@ // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include +#include namespace aiter { diff --git a/csrc/include/torch/mha_varlen_fwd.h b/csrc/include/torch/mha_varlen_fwd.h index e7e062c237..b9c0483102 100644 --- a/csrc/include/torch/mha_varlen_fwd.h +++ b/csrc/include/torch/mha_varlen_fwd.h @@ -5,7 +5,7 @@ namespace aiter { namespace torch_itfs { -std::vector +std::tuple mha_varlen_fwd(at::Tensor& q, // [total_q, hq, d] const at::Tensor& k, // [total_k, hk, d] const at::Tensor& v, // [total_k, hk, d] diff --git a/csrc/kernels/custom_all_reduce.cu b/csrc/kernels/custom_all_reduce.cu index 7845fae086..7fddb6ed45 100644 --- a/csrc/kernels/custom_all_reduce.cu +++ b/csrc/kernels/custom_all_reduce.cu @@ -303,7 +303,7 @@ void register_buffer(fptr_t _fa, fa->register_buffer(handles, offsets, t.data_ptr()); } -std::vector get_graph_buffer_ipc_meta(fptr_t _fa) +std::tuple get_graph_buffer_ipc_meta(fptr_t _fa) { auto fa = reinterpret_cast(_fa); auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta(); diff --git a/csrc/kernels/quant_kernels.cu b/csrc/kernels/quant_kernels.cu index 685b1825c1..7b57572243 100644 --- a/csrc/kernels/quant_kernels.cu +++ b/csrc/kernels/quant_kernels.cu @@ -460,15 +460,15 @@ smooth_data_to_per_row_scale(const DTYPE_I* __restrict__ input, : (1. / ck_tile::type_convert(ck_tile::numeric::max())); const int32_t smscale_map_idx = smooth_scale_map == nullptr ? 0 : smooth_scale_map[blockIdx.x]; - const int64_t row_offset = token_idx * cols; - auto const* ptr_i = reinterpret_cast(input + row_offset); - auto const* input_vecs = reinterpret_cast(ptr_i); + const int64_t row_offset = token_idx * cols; + auto const* ptr_i = reinterpret_cast(input + row_offset); + auto const* input_vecs = reinterpret_cast(ptr_i); static constexpr int32_t ooba_i = 4 / sizeof(DTYPE_I); const int32_t oob_i = (cols + ooba_i - 1) / ooba_i * ooba_i; auto buffer_i = ck_tile::make_buffer_view(ptr_i, oob_i); buffer_i.init_raw(); - auto const* ptr_smscale = reinterpret_cast(smooth_scale + smscale_map_idx * cols); + auto const* ptr_smscale = reinterpret_cast(smooth_scale + smscale_map_idx * cols); auto const* smscale_vecs = reinterpret_cast(ptr_smscale); auto buffer_s = ck_tile::make_buffer_view(ptr_smscale, cols); @@ -673,10 +673,10 @@ void dynamic_per_tensor_quant(torch::Tensor& out, // [..., d] void dynamic_per_token_scaled_quant(torch::Tensor& out, // [..., d] torch::Tensor const& input, // [..., d] torch::Tensor& scales, - std::optional const& scale_ub, - bool shuffle_scale = false, - std::optional const& num_rows = std::nullopt, - int num_rows_factor = 1) + std::optional scale_ub = std::nullopt, + bool shuffle_scale = false, + std::optional num_rows = std::nullopt, + int num_rows_factor = 1) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); diff --git a/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu b/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu index 712f2e7791..7d386cdbf2 100644 --- a/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu @@ -326,32 +326,32 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse, return args; } - -std::vector -mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d] - const at::Tensor &k, // [total_k, hk, d] - const at::Tensor &v, // [total_k, hk, d] - const at::Tensor &cu_seqlens_q, // [b+1] - std::optional &cu_seqlens_k, // [b+1] - int max_seqlen_q, - int max_seqlen_k, - int min_seqlen_q, - float p_dropout, - float softmax_scale, - float logits_soft_cap, - bool zero_tensors, - bool is_causal, - int window_size_left, - int window_size_right, - bool return_softmax_lse, - bool return_dropout_randval, - std::optional out_, // [total_q, hq, d] - std::optional block_table_, // [hq] or [b, hq] - std::optional bias_, // [total_q, max_seqlen_k] - std::optional alibi_slopes_, // [hq] or [b, hq] - std::optional gen_, - std::optional cu_seqlens_q_padded_, // [b+1] physical starts with PAD - std::optional cu_seqlens_k_padded_) // [b+1] +std::tuple +mha_varlen_fwd( + at::Tensor& q, // [total_q, hq, d] + const at::Tensor& k, // [total_k, hk, d] + const at::Tensor& v, // [total_k, hk, d] + const at::Tensor& cu_seqlens_q, // [b+1] + std::optional& cu_seqlens_k, // [b+1] + int max_seqlen_q, + int max_seqlen_k, + int min_seqlen_q, + float p_dropout, + float softmax_scale, + float logits_soft_cap, + bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + bool return_softmax_lse, + bool return_dropout_randval, + std::optional out_, // [total_q, hq, d] + std::optional block_table_, // [hq] or [b, hq] + std::optional bias_, // [total_q, max_seqlen_k] + std::optional alibi_slopes_, // [hq] or [b, hq] + std::optional gen_, + std::optional cu_seqlens_q_padded_, // [b+1] physical starts with PAD + std::optional cu_seqlens_k_padded_) // [b+1] { auto q_dtype = q.scalar_type(); bool isQKVFp8 = q_dtype == at::ScalarType::Float8_e4m3fn || q_dtype == at::ScalarType::Float8_e4m3fnuz; diff --git a/csrc/pybind/moe_op_pybind.cu b/csrc/pybind/moe_op_pybind.cu index 4c62f61484..dfd2c62436 100644 --- a/csrc/pybind/moe_op_pybind.cu +++ b/csrc/pybind/moe_op_pybind.cu @@ -4,8 +4,4 @@ #include "moe_op.h" #include "rocm_ops.hpp" -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - AITER_ENUM_PYBIND; - MOE_OP_PYBIND; -} \ No newline at end of file +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { MOE_OP_PYBIND; } \ No newline at end of file