Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aiter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,5 @@ def getLogger():
from .ops.gradlib import *
from .ops.trans_ragged_layout import *
from .ops.sample import *
from .ops.fused_mrope_rms import *
from . import mla
13 changes: 13 additions & 0 deletions aiter/jit/optCompilerConfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,19 @@
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_fused_mrope_rms": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/fused_mrope_rms_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/rope/rope_common.h'",
"f'{AITER_CSRC_DIR}/kernels/fused_mrope_rms.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_fmha_v3_fwd": {
"srcs": [
"f'{AITER_CSRC_DIR}/kernels/mha_common.cu'",
Expand Down
25 changes: 25 additions & 0 deletions aiter/ops/fused_mrope_rms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

from torch import Tensor
from ..jit.core import compile_ops
from typing import List


@compile_ops("module_fused_mrope_rms")
def fused_mrope_3d_rms(
qkv: Tensor,
qw: Tensor,
kw: Tensor,
cos_sin: Tensor,
positions: Tensor,
num_tokens: int,
num_heads_q: int,
num_heads_k: int,
num_heads_v: int,
head_size: int,
is_neox_style: bool,
mrope_section_: List[int],
is_interleaved: bool,
eps: float,
) -> None: ...
Loading