Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions aiter/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,13 @@ def check_args():
doc_str = op.__doc__.split("\n")[0]
doc_str = re.sub(r"<(.*?)\:.*?>", r"\g<1>", doc_str)
doc_str = doc_str.replace("list[", "List[")
doc_str = doc_str.replace("tuple[", "Tuple[")
doc_str = doc_str.replace("collections.abc.Sequence[", "List[")
doc_str = doc_str.replace("typing.SupportsInt", "int")
doc_str = doc_str.replace("typing.SupportsFloat", "float")
# A|None --> Optional[A]
pattern = r"([\w\.]+(?:\[[^\]]+\])?)\s*\|\s*None"
doc_str = re.sub(pattern, r"Optional[\1]", doc_str)
for el in enum_types:
doc_str = re.sub(f" aiter.*{el} ", f" {el} ", doc_str)
namespace = {
Expand All @@ -769,9 +776,7 @@ def check_args():
"torch": torch,
"typing": typing,
}
if sys.version_info < (3, 10):
pattern = r"([\w\.]+(?:\[[^\]]+\])?)\s*\|\s*None"
doc_str = re.sub(pattern, r"Optional[\1]", doc_str)

exec(
f"from aiter import*\ndef {doc_str}: pass",
namespace,
Expand Down
8 changes: 4 additions & 4 deletions aiter/jit/utils/cpp_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -1551,7 +1551,7 @@ def _write_ninja_file_to_build_library(
# But we can't use this now because all aiter op based on torch
# which means pybind11 related build flags must from torch now
common_cflags = []
if torch_exclude and is_python_module:
if is_python_module:
import pybind11

extra_include_paths.append(pybind11.get_include())
Expand All @@ -1575,9 +1575,9 @@ def _write_ninja_file_to_build_library(
common_cflags.append(f"-DTORCH_EXTENSION_NAME={name}")
else:
common_cflags.append(f"-DTORCH_EXTENSION_NAME=aiter_")
common_cflags.append("-DTORCH_API_INCLUDE_EXTENSION_H")
common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()]
common_cflags += [f"{x}" for x in _get_glibcxx_abi_build_flags()]
# common_cflags.append("-DTORCH_API_INCLUDE_EXTENSION_H")
# common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()]
# common_cflags += [f"{x}" for x in _get_glibcxx_abi_build_flags()]

# Windows does not understand `-isystem` and quotes flags later.
common_cflags += [f"-I{shlex.quote(include)}" for include in user_includes]
Expand Down
4 changes: 2 additions & 2 deletions aiter/ops/custom_all_reduce.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

from typing import List, Optional
from typing import List, Optional, Tuple

import torch

Expand Down Expand Up @@ -184,7 +184,7 @@ def register_buffer(


@compile_ops("module_custom_all_reduce")
def get_graph_buffer_ipc_meta(_fa: int) -> tuple[torch.Tensor, torch.Tensor]: ...
def get_graph_buffer_ipc_meta(_fa: int) -> Tuple[torch.Tensor, torch.Tensor]: ...


@compile_ops("module_custom_all_reduce")
Expand Down
10 changes: 6 additions & 4 deletions aiter/ops/mha.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

from torch import Tensor, Generator
from typing import Optional, Tuple, Any
from ..jit.core import compile_ops, CK_DIR, AITER_CSRC_DIR
from typing import Any, Optional, Tuple

import torch
from torch import Generator, Tensor

from ..jit.core import AITER_CSRC_DIR, CK_DIR, compile_ops
from ..jit.utils.chip_info import get_gfx
from ..jit.utils.torch_guard import torch_compile_guard
from ..utility import dtypes
import torch


def cmdGenFunc_mha_fwd(
Expand Down
10 changes: 6 additions & 4 deletions aiter/ops/norm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

from typing import Optional

import torch
from torch import Tensor
from typing import Optional

from ..jit.core import compile_ops

MD_NAME = "module_norm"
Expand Down Expand Up @@ -43,8 +45,8 @@ def layer_norm(
def layernorm2d_fwd(
input: Tensor,
# normalized_shape: List[int],
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
weight: Tensor,
bias: Tensor,
epsilon: float = 1e-5,
x_bias: Optional[Tensor] = None,
) -> Tensor: ...
Expand Down
12 changes: 6 additions & 6 deletions csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"

#include "aiter_enum.h"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
#include "py_itfs_common.h"
#include <hip/hip_runtime.h>
#include <torch/extension.h>
#include <torch/torch.h>

template <ck::index_t... Is>
Expand Down
7 changes: 4 additions & 3 deletions csrc/include/binary_operator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
* limitations under the License.
*/
#pragma once
#include <torch/all.h>
#include "dispatch_utils.h"
#include "hip_compat.h"
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include "hip_compat.h"
#include "dispatch_utils.h"
#include <torch/all.h>
#include <torch/extension.h>
#include <torch/torch.h>

#include <hip/hip_bf16.h>
Expand Down
6 changes: 6 additions & 0 deletions csrc/include/ck_tile/vec_convert.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ CK_TILE_DEVICE fp4x2_t amd_assembly_cvt_scalef32_pk_fp4_f32(fp32_t a, fp32_t b,
// permute high bits and low bits to match the order of the original vector
asm volatile("v_cvt_scalef32_pk_fp4_f32 %0, %1, %2, %3" : "=v"(c) : "v"(b), "v"(a), "v"(scale));
return bit_cast<fp4x2_t>(bit_cast<int8x2_t>(c[0])[0]);
#else
return fp4x2_t{};
#endif
}
CK_TILE_DEVICE fp4x2_t amd_assembly_cvt_scalef32_pk_fp4_f16(fp16x2_v a, fp32_t scale)
Expand All @@ -85,6 +87,8 @@ CK_TILE_DEVICE fp4x2_t amd_assembly_cvt_scalef32_pk_fp4_f16(fp16x2_v a, fp32_t s
// permute high bits and low bits to match the order of the original vector
asm volatile("v_cvt_scalef32_pk_fp4_f16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(scale));
return bit_cast<fp4x2_t>(bit_cast<int8x2_t>(c[0])[0]);
#else
return fp4x2_t{};
#endif
}
CK_TILE_DEVICE fp4x2_t amd_assembly_cvt_scalef32_pk_fp4_bf16(bf16x2_v a, fp32_t scale)
Expand All @@ -94,6 +98,8 @@ CK_TILE_DEVICE fp4x2_t amd_assembly_cvt_scalef32_pk_fp4_bf16(bf16x2_v a, fp32_t
// permute high bits and low bits to match the order of the original vector
asm volatile("v_cvt_scalef32_pk_fp4_bf16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(scale));
return bit_cast<fp4x2_t>(bit_cast<int8x2_t>(c[0])[0]);
#else
return fp4x2_t{};
#endif
}

Expand Down
2 changes: 1 addition & 1 deletion csrc/include/custom_all_reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void register_buffer(fptr_t _fa,
torch::Tensor& t,
const std::vector<torch::Tensor>& handles,
const std::vector<int64_t>& offsets);
std::vector<at::Tensor> get_graph_buffer_ipc_meta(fptr_t _fa);
std::tuple<torch::Tensor, torch::Tensor> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa,
const std::vector<torch::Tensor>& handles,
const std::vector<torch::Tensor>& offsets);
Expand Down
10 changes: 5 additions & 5 deletions csrc/include/quant.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

#include <torch/torch.h>
#include <torch/extension.h>

namespace aiter {

Expand All @@ -17,10 +17,10 @@ void dynamic_per_tensor_quant(torch::Tensor& out, // [..., d]
void dynamic_per_token_scaled_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scales,
std::optional<at::Tensor> const& scale_ub,
bool shuffle_scale = false,
std::optional<at::Tensor> const& num_rows = std::nullopt,
int num_rows_factor = 1);
std::optional<torch::Tensor> scale_ub = std::nullopt,
bool shuffle_scale = false,
std::optional<torch::Tensor> num_rows = std::nullopt,
int num_rows_factor = 1);

void dynamic_per_group_scaled_quant_fp4(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
Expand Down
26 changes: 15 additions & 11 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

#include <pybind11/pybind11.h>
namespace py = pybind11;

#define ACTIVATION_PYBIND \
m.def("silu_and_mul", \
Expand Down Expand Up @@ -932,12 +936,12 @@
py::arg("sorted_weights") = std::nullopt); \
m.def("moe_sum", &aiter::moe_sum, "moe_sum(Tensor! input, Tensor output) -> ()");

#define MOE_TOPK_PYBIND \
m.def("topk_sigmoid", \
&aiter::topk_sigmoid, \
py::arg("topk_weights"), \
py::arg("topk_indices"), \
py::arg("gating_output"), \
#define MOE_TOPK_PYBIND \
m.def("topk_sigmoid", \
&aiter::topk_sigmoid, \
py::arg("topk_weights"), \
py::arg("topk_indices"), \
py::arg("gating_output"), \
"Apply topk sigmoid to the gating outputs.");

#define MOE_SORTING_PYBIND \
Expand Down Expand Up @@ -1241,11 +1245,11 @@
"hipb_findallsols", \
py::arg("mat1"), \
py::arg("mat2"), \
py::arg("bias") = std::nullopt, \
py::arg("out_dtype") = std::nullopt, \
py::arg("scaleA") = std::nullopt, \
py::arg("scaleB") = std::nullopt, \
py::arg("scaleC") = std::nullopt, \
py::arg("bias") = std::nullopt, \
py::arg("out_dtype") = std::nullopt, \
py::arg("scaleA") = std::nullopt, \
py::arg("scaleB") = std::nullopt, \
py::arg("scaleC") = std::nullopt, \
py::arg("bpreshuffle") = false); \
m.def("getHipblasltKernelName", &getHipblasltKernelName);

Expand Down
2 changes: 1 addition & 1 deletion csrc/include/sample.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

#include <torch/torch.h>
#include <torch/extension.h>

namespace aiter {

Expand Down
2 changes: 1 addition & 1 deletion csrc/include/torch/mha_varlen_fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

namespace aiter {
namespace torch_itfs {
std::vector<at::Tensor>
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
mha_varlen_fwd(at::Tensor& q, // [total_q, hq, d]
const at::Tensor& k, // [total_k, hk, d]
const at::Tensor& v, // [total_k, hk, d]
Expand Down
2 changes: 1 addition & 1 deletion csrc/kernels/custom_all_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ void register_buffer(fptr_t _fa,
fa->register_buffer(handles, offsets, t.data_ptr());
}

std::vector<at::Tensor> get_graph_buffer_ipc_meta(fptr_t _fa)
std::tuple<torch::Tensor, torch::Tensor> get_graph_buffer_ipc_meta(fptr_t _fa)
{
auto fa = reinterpret_cast<aiter::CustomAllreduce*>(_fa);
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
Expand Down
16 changes: 8 additions & 8 deletions csrc/kernels/quant_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -460,15 +460,15 @@ smooth_data_to_per_row_scale(const DTYPE_I* __restrict__ input,
: (1. / ck_tile::type_convert<float>(ck_tile::numeric<DTYPE_O>::max()));

const int32_t smscale_map_idx = smooth_scale_map == nullptr ? 0 : smooth_scale_map[blockIdx.x];
const int64_t row_offset = token_idx * cols;
auto const* ptr_i = reinterpret_cast<DTYPE_I const*>(input + row_offset);
auto const* input_vecs = reinterpret_cast<vec_i const*>(ptr_i);
const int64_t row_offset = token_idx * cols;
auto const* ptr_i = reinterpret_cast<DTYPE_I const*>(input + row_offset);
auto const* input_vecs = reinterpret_cast<vec_i const*>(ptr_i);
static constexpr int32_t ooba_i = 4 / sizeof(DTYPE_I);
const int32_t oob_i = (cols + ooba_i - 1) / ooba_i * ooba_i;
auto buffer_i = ck_tile::make_buffer_view<ck_tile::address_space_enum::global>(ptr_i, oob_i);
buffer_i.init_raw();

auto const* ptr_smscale = reinterpret_cast<float const*>(smooth_scale + smscale_map_idx * cols);
auto const* ptr_smscale = reinterpret_cast<float const*>(smooth_scale + smscale_map_idx * cols);
auto const* smscale_vecs = reinterpret_cast<vec_s const*>(ptr_smscale);
auto buffer_s =
ck_tile::make_buffer_view<ck_tile::address_space_enum::global>(ptr_smscale, cols);
Expand Down Expand Up @@ -673,10 +673,10 @@ void dynamic_per_tensor_quant(torch::Tensor& out, // [..., d]
void dynamic_per_token_scaled_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scales,
std::optional<at::Tensor> const& scale_ub,
bool shuffle_scale = false,
std::optional<at::Tensor> const& num_rows = std::nullopt,
int num_rows_factor = 1)
std::optional<torch::Tensor> scale_ub = std::nullopt,
bool shuffle_scale = false,
std::optional<torch::Tensor> num_rows = std::nullopt,
int num_rows_factor = 1)
{
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
Expand Down
52 changes: 26 additions & 26 deletions csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -326,32 +326,32 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
return args;
}


std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d]
const at::Tensor &k, // [total_k, hk, d]
const at::Tensor &v, // [total_k, hk, d]
const at::Tensor &cu_seqlens_q, // [b+1]
std::optional<const at::Tensor> &cu_seqlens_k, // [b+1]
int max_seqlen_q,
int max_seqlen_k,
int min_seqlen_q,
float p_dropout,
float softmax_scale,
float logits_soft_cap,
bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
bool return_softmax_lse,
bool return_dropout_randval,
std::optional<at::Tensor> out_, // [total_q, hq, d]
std::optional<const at::Tensor> block_table_, // [hq] or [b, hq]
std::optional<const at::Tensor> bias_, // [total_q, max_seqlen_k]
std::optional<const at::Tensor> alibi_slopes_, // [hq] or [b, hq]
std::optional<at::Generator> gen_,
std::optional<const at::Tensor> cu_seqlens_q_padded_, // [b+1] physical starts with PAD
std::optional<const at::Tensor> cu_seqlens_k_padded_) // [b+1]
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
mha_varlen_fwd(
at::Tensor& q, // [total_q, hq, d]
const at::Tensor& k, // [total_k, hk, d]
const at::Tensor& v, // [total_k, hk, d]
const at::Tensor& cu_seqlens_q, // [b+1]
std::optional<const at::Tensor>& cu_seqlens_k, // [b+1]
int max_seqlen_q,
int max_seqlen_k,
int min_seqlen_q,
float p_dropout,
float softmax_scale,
float logits_soft_cap,
bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
bool return_softmax_lse,
bool return_dropout_randval,
std::optional<at::Tensor> out_, // [total_q, hq, d]
std::optional<const at::Tensor> block_table_, // [hq] or [b, hq]
std::optional<const at::Tensor> bias_, // [total_q, max_seqlen_k]
std::optional<const at::Tensor> alibi_slopes_, // [hq] or [b, hq]
std::optional<at::Generator> gen_,
std::optional<const at::Tensor> cu_seqlens_q_padded_, // [b+1] physical starts with PAD
std::optional<const at::Tensor> cu_seqlens_k_padded_) // [b+1]
{
auto q_dtype = q.scalar_type();
bool isQKVFp8 = q_dtype == at::ScalarType::Float8_e4m3fn || q_dtype == at::ScalarType::Float8_e4m3fnuz;
Expand Down
6 changes: 1 addition & 5 deletions csrc/pybind/moe_op_pybind.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,4 @@
#include "moe_op.h"
#include "rocm_ops.hpp"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
AITER_ENUM_PYBIND;
MOE_OP_PYBIND;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { MOE_OP_PYBIND; }