diff --git a/aiter/__init__.py b/aiter/__init__.py index e033e382d8..edfbaf06e8 100644 --- a/aiter/__init__.py +++ b/aiter/__init__.py @@ -76,4 +76,5 @@ def getLogger(): from .ops.gradlib import * from .ops.trans_ragged_layout import * from .ops.sample import * +from .ops.fused_mrope_rms import * from . import mla diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json index 51b271f00f..f5b35208e5 100755 --- a/aiter/jit/optCompilerConfig.json +++ b/aiter/jit/optCompilerConfig.json @@ -716,6 +716,19 @@ "verbose": "False", "blob_gen_cmd": "''" }, + "module_fused_mrope_rms": { + "srcs": [ + "f'{AITER_CSRC_DIR}/pybind/fused_mrope_rms_pybind.cu'", + "f'{AITER_CSRC_DIR}/kernels/rope/rope_common.h'", + "f'{AITER_CSRC_DIR}/kernels/fused_mrope_rms.cu'" + ], + "flags_extra_cc": [], + "flags_extra_hip": [], + "extra_ldflags": "None", + "extra_include": [], + "verbose": "False", + "blob_gen_cmd": "''" + }, "module_fmha_v3_fwd": { "srcs": [ "f'{AITER_CSRC_DIR}/kernels/mha_common.cu'", diff --git a/aiter/ops/fused_mrope_rms.py b/aiter/ops/fused_mrope_rms.py new file mode 100644 index 0000000000..a337d032ed --- /dev/null +++ b/aiter/ops/fused_mrope_rms.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +from torch import Tensor +from ..jit.core import compile_ops +from typing import List + + +@compile_ops("module_fused_mrope_rms") +def fused_mrope_3d_rms( + qkv: Tensor, + qw: Tensor, + kw: Tensor, + cos_sin: Tensor, + positions: Tensor, + num_tokens: int, + num_heads_q: int, + num_heads_k: int, + num_heads_v: int, + head_size: int, + is_neox_style: bool, + mrope_section_: List[int], + is_interleaved: bool, + eps: float, +) -> None: ... 
diff --git a/aiter/rotary_embedding.py b/aiter/rotary_embedding.py index 1b328b1c3c..ae4c67467f 100644 --- a/aiter/rotary_embedding.py +++ b/aiter/rotary_embedding.py @@ -27,7 +27,7 @@ import torch import torch.nn as nn -from aiter import dtypes +from aiter import dtypes, fused_mrope_3d_rms # from custom_op import CustomOp @@ -1144,6 +1144,334 @@ def get_next_input_positions( ] +class MRotaryEmbeddingQKNormFused(nn.Module): + """Rotary Embedding with Multimodal Sections fused with QKNorm""" + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=dtypes.fp32) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=dtypes.fp32) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + return cos, sin + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + mrope_section: Optional[List[int]] = None, + mrope_interleaved: bool = False, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + assert self.head_size == self.rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + self.mrope_interleaved = mrope_interleaved + + cos, sin = self._compute_cos_sin_cache() + cos = cos.to(dtype) + sin = sin.to(dtype) + cache = 
torch.cat((cos, sin), dim=-1) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + self.mrope_section = mrope_section + if self.mrope_section: + expected_sum = rotary_dim // 2 + actual_sum = sum(self.mrope_section) + if actual_sum != expected_sum: + print( + f"MRoPE section sum mismatch: expected {expected_sum}, got {actual_sum}. " + f"Adjusting mrope_section to match rotary_dim // 2 = {expected_sum}" + ) + # Auto-correct by scaling the mrope_section proportionally + if actual_sum > 0: + scale_factor = expected_sum / actual_sum + self.mrope_section = [ + max(1, int(section * scale_factor)) + for section in self.mrope_section + ] + # Ensure the sum exactly matches by adjusting the last element + current_sum = sum(self.mrope_section) + if current_sum != expected_sum: + self.mrope_section[-1] += expected_sum - current_sum + else: + # If all sections are 0, create a default distribution + self.mrope_section = [ + expected_sum // len(self.mrope_section) + ] * len(self.mrope_section) + # Handle remainder + remainder = expected_sum % len(self.mrope_section) + for i in range(remainder): + self.mrope_section[i] += 1 + + print( + f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})" + ) + + def forward( + self, + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + positions: torch.Tensor, + num_heads: int, + num_kv_heads: int, + eps: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert positions.ndim == 1 or positions.ndim == 2 + num_tokens = positions.shape[-1] + num_heads_q = num_heads + num_heads_k = num_kv_heads + num_heads_v = num_kv_heads + is_interleaved = ( + True if positions.ndim == 2 and self.mrope_section is not None else False + ) + assert is_interleaved == self.mrope_interleaved + fused_mrope_3d_rms( + qkv, + q_weight, + k_weight, + self.cos_sin_cache, + positions, + num_tokens, + num_heads_q, + num_heads_k, + num_heads_v, + self.head_size, + 
self.is_neox_style, + self.mrope_section, + is_interleaved, + eps, + ) + q_size = num_heads_q * self.head_size + k_size = num_heads_k * self.head_size + v_size = num_heads_v * self.head_size + + qkv = qkv.view(num_tokens, q_size + k_size + v_size) + q, k, v = qkv.split([q_size, k_size, v_size], dim=-1) + return q, k, v + + +class DualChunkRotaryEmbedding(nn.Module): + """Rotary positional embedding for Dual Chunk Attention.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + chunk_size: int, + local_size: int, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.chunk_size = chunk_size + self.local_size = local_size + self.dtype = dtype + self.device = torch.device(f"cuda:{torch.cuda.current_device()}") + (q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = ( + self._compute_cos_sin_cache() + ) + + self.register_buffer("cos_sin_q_cache", q_cache, persistent=False) + self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False) + self.register_buffer("cos_sin_k_cache", k_cache, persistent=False) + self.register_buffer( + "cos_sin_qc_no_clamp_cache", qc_no_clamp_cache, persistent=False + ) + self.register_buffer("cos_sin_q_inter_cache", q_inter_cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`. + # However, we use `torch.arange(..., dtype=torch.float)` instead to + # avoid numerical issues with large base values (e.g., 10000000). + # This may cause a slight numerical difference between the HF + # implementation and ours. 
+ # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + chunk_len = self.chunk_size - self.local_size + q_t = torch.arange(chunk_len, dtype=torch.float) + qc_t = (torch.arange(chunk_len, dtype=torch.float) + chunk_len).clamp( + max=self.chunk_size + ) + k_t = torch.arange(self.max_position_embeddings, dtype=torch.float) % chunk_len + + # count from chunk_len, no clamp(self.chunk_size) restriction + qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len + # count from self.chunk_size for q_inter's rope + q_inter_t = torch.arange(chunk_len, dtype=torch.float) + self.chunk_size + + q_freqs = torch.outer(q_t, inv_freq) + qc_freqs = torch.outer(qc_t, inv_freq) + k_freqs = torch.outer(k_t, inv_freq) + qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq) + q_inter_freqs = torch.outer(q_inter_t, inv_freq) + + q_cos = q_freqs.cos() + q_sin = q_freqs.sin() + qc_cos = qc_freqs.cos() + qc_sin = qc_freqs.sin() + k_cos = k_freqs.cos() + k_sin = k_freqs.sin() + + qc_no_clamp_cos = qc_no_clamp_freqs.cos() + qc_no_clamp_sin = qc_no_clamp_freqs.sin() + q_inter_cos = q_inter_freqs.cos() + q_inter_sin = q_inter_freqs.sin() + + q_cache = torch.cat((q_cos, q_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) + qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) + k_cache = torch.cat((k_cos, k_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) + qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), 
dim=-1).to( + dtype=self.dtype, device=self.device + ) + q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) + return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + key_rot = key[..., : self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim :] + key_pass = key[..., self.rotary_dim :] + else: + query_pass = None + key_pass = None + + positions_with_offsets = ( + torch.add(positions, offsets) if offsets is not None else positions + ) + key = self._apply_rotary_embedding( + self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass + ) + chunk_len = self.chunk_size - self.local_size + query = self._apply_rotary_embedding( + self.cos_sin_q_cache[positions_with_offsets % chunk_len], + query_rot, + query_pass, + ) + query_succ = self._apply_rotary_embedding( + self.cos_sin_qc_cache[positions_with_offsets % chunk_len], + query_rot, + query_pass, + ) + query_inter = self._apply_rotary_embedding( + self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1), + query_rot, + query_pass, + ) + query_succ_critical = self._apply_rotary_embedding( + self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len], + query_rot, + query_pass, + ) + query_inter_critical = self._apply_rotary_embedding( + self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len], + query_rot, + query_pass, + ) + + # merge query into one tensor to simplify the interfaces + query = torch.cat( + ( + query, + query_succ, + query_inter, + query_succ_critical, + query_inter_critical, + ), + dim=-1, + ) + return query, key + + def 
_apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass): + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. + cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin + + if self.rotary_dim < self.head_size: + hidden = torch.cat((hidden_rot, hidden_pass), dim=-1) + else: + hidden = hidden_rot + return hidden.flatten(-2).squeeze(0) + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + s += f", chunk_size={self.chunk_size}, local_size={self.local_size}" + return s + + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} @@ -1156,6 +1484,7 @@ def get_rope( rope_scaling: Optional[Dict[str, Any]] = None, dtype: Optional[torch.dtype] = None, partial_rotary_factor: float = 1.0, + dual_chunk_attention_config: Optional[Dict[str, Any]] = None, ) -> RotaryEmbedding: if dtype is None: dtype = torch.get_default_dtype() @@ -1167,6 +1496,17 @@ def get_rope( rope_scaling_args = tuple(rope_scaling_tuple.items()) else: rope_scaling_args = None + + if dual_chunk_attention_config is not None: + dual_chunk_attention_tuple = { + k: tuple(v) if isinstance(v, list) else v + for k, v in dual_chunk_attention_config.items() + if k != "sparse_attention_config" + } + dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items()) + else: + dual_chunk_attention_args = None + if partial_rotary_factor < 1.0: rotary_dim = int(rotary_dim * partial_rotary_factor) key = ( @@ -1176,12 +1516,28 @@ def get_rope( base, is_neox_style, 
rope_scaling_args, + dual_chunk_attention_args, dtype, ) if key in _ROPE_DICT: return _ROPE_DICT[key] - if rope_scaling is None: + if dual_chunk_attention_config is not None: + extra_kwargs = { + k: v + for k, v in dual_chunk_attention_config.items() + if k in ("chunk_size", "local_size") + } + rotary_emb = DualChunkRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + **extra_kwargs, + ) + elif rope_scaling is None: rotary_emb = RotaryEmbedding( head_size, rotary_dim, max_position, base, is_neox_style, dtype ) @@ -1211,6 +1567,34 @@ def get_rope( high_freq_factor, original_max_position, ) + elif scaling_type == "default": + if ( + "mrope_section" in rope_scaling + and "aiter_rope_fused_qknorm" in rope_scaling + ): + rotary_emb = MRotaryEmbeddingQKNormFused( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + mrope_interleaved=( + rope_scaling["mrope_interleaved"] + if "mrope_interleaved" in rope_scaling + else False + ), + ) + else: + rotary_emb = RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + ) elif scaling_type == "linear": rotary_emb = LinearScalingRotaryEmbedding( head_size, diff --git a/csrc/include/fused_mrope_rms.h b/csrc/include/fused_mrope_rms.h new file mode 100644 index 0000000000..cca257efc7 --- /dev/null +++ b/csrc/include/fused_mrope_rms.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include <torch/extension.h> + +using namespace at; + +void fused_mrope_3d_rms(Tensor& qkv, + Tensor& qw, + Tensor& kw, + Tensor& cos_sin, + Tensor& positions, + int64_t num_tokens, + int64_t num_heads_q, + int64_t num_heads_k, + int64_t num_heads_v, + int64_t head_size, + bool is_neox_style, + std::vector<int64_t> mrope_section_, + bool is_interleaved, + double eps); diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 8646633588..5ba6be5b99 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -1264,6 +1264,8 @@ namespace py = pybind11; m.def("rope_cached_positions_offsets_fwd_impl", &rope_cached_positions_offsets_fwd_impl); \ m.def("rope_cached_positions_offsets_2c_fwd_impl", &rope_cached_positions_offsets_2c_fwd_impl); +#define FUSED_MROPE_RMS_PYBIND m.def("fused_mrope_3d_rms", &fused_mrope_3d_rms); + #define SMOOTHQUANT_PYBIND \ m.def("smoothquant_fwd", &smoothquant_fwd); \ m.def("moe_smoothquant_fwd", &moe_smoothquant_fwd); diff --git a/csrc/kernels/fused_mrope_rms.cu b/csrc/kernels/fused_mrope_rms.cu new file mode 100644 index 0000000000..080d91a689 --- /dev/null +++ b/csrc/kernels/fused_mrope_rms.cu @@ -0,0 +1,404 @@ +#include "rope/rope_common.h" + +using namespace at; + +namespace rope_rms { + +static constexpr int kBytesPerAccess = 16; + +namespace block_utils { + +template <typename T> +__inline__ __device__ T warp_shfl_xor_sync(T val, int offset) { + return __shfl_xor(val, offset, 32); +} + +template <typename T> +__inline__ __device__ T warp_reduce_sum(T val) { +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) + val += warp_shfl_xor_sync(val, offset); + return val; +} + +template <typename T> +__inline__ __device__ T warp_shfl_sync(T val, int src_id) { + return __shfl_sync(__activemask(), val, src_id, 32); +} + +} // namespace block_utils + +template <typename T, int vec_size> +struct alignas(sizeof(T) * vec_size) vec_t { + T data[vec_size]; + __device__ __forceinline__ T &operator[](int i) { + return data[i]; + } + __device__ __forceinline__ T const &operator[](int i)
const { + return data[i]; + } + __device__ __forceinline__ void load(const T *ptr) { + *this = *reinterpret_cast<vec_t<T, vec_size> *>(const_cast<T *>(ptr)); + } + __device__ __forceinline__ void loop_load(const T *ptr) { +#pragma unroll + for (int i = 0; i < vec_size; ++i) { + data[i] = ptr[i]; + } + } + __device__ __forceinline__ void store(T *ptr) { + *reinterpret_cast<vec_t<T, vec_size> *>(ptr) = *this; + } + __device__ __forceinline__ void loop_store(T *ptr) { +#pragma unroll + for (int i = 0; i < vec_size; ++i) { + ptr[i] = data[i]; + } + } + __device__ __forceinline__ void nontemporal_load(const T *ptr) { + constexpr int ITERS = vec_size * sizeof(T) / sizeof(uint32_t); +#pragma unroll + for (int i = 0; i < ITERS; ++i) { + reinterpret_cast<uint32_t *>(&data)[i] = __builtin_nontemporal_load((uint32_t *)ptr + i); + } + } + __device__ __forceinline__ void nontemporal_store(T *ptr) { + constexpr int ITERS = vec_size * sizeof(T) / sizeof(uint32_t); +#pragma unroll + for (int i = 0; i < ITERS; ++i) { + __builtin_nontemporal_store(reinterpret_cast<uint32_t *>(&data)[i], (uint32_t *)ptr + i); + } + } + __device__ __forceinline__ void fill(T val) { +#pragma unroll + for (int i = 0; i < vec_size; ++i) { + data[i] = val; + } + } +}; + +template <typename T, int vec_size> +__inline__ __device__ vec_t<T, vec_size> warp_shfl_sync_vec(vec_t<T, vec_size> &val, int offset) { + constexpr int ITERS = vec_size * sizeof(T) / sizeof(uint32_t); + vec_t<T, vec_size> out; +#pragma unroll + for (int i = 0; i < ITERS; ++i) { + uint32_t val_ = reinterpret_cast<uint32_t *>(&val)[i]; + reinterpret_cast<uint32_t *>(&out)[i] = block_utils::warp_shfl_sync(val_, offset); + } + return out; +} + +template <typename T, int VEC_SIZE> +__device__ __forceinline__ void warp_rms_norm_( + vec_t<T, VEC_SIZE> &input, + vec_t<T, VEC_SIZE> &gamma, + float rms_dim, + float rms_eps) { + vec_t<T, VEC_SIZE> norm_out; + float acc = 0.f; +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + float v = (float)input[i]; + acc += v * v; + } + int warp_id = threadIdx.x / 32; + int warp_t_id = threadIdx.x % 32; + acc = block_utils::warp_reduce_sum(acc); + acc = block_utils::warp_shfl_sync(acc, 0); + __syncwarp(); + auto s_val =
rsqrtf(acc / rms_dim + rms_eps); +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + input[i] = static_cast((float)input[i] * s_val * (float)gamma[i]); + } +} + +template +__device__ __forceinline__ void mrope_load_cos_sin_vec(vec_t &out, + const T *cos_sin, const int64_t *positions, int64_t ps0, int64_t ps1, + int64_t token_id, int64_t num_tokens, + int access_id_in_head, std::array &mrope_section) { + constexpr int HALF_HEAD_SIZE = HEAD_SIZE / 2; + if constexpr (IS_INTERLEAVED) { + for (int i = 0; i < VEC_SIZE; ++i) { + auto id = access_id_in_head + i; + auto id_ = (access_id_in_head < HALF_HEAD_SIZE) ? id : id - HALF_HEAD_SIZE; + auto mid_ = id_ % M; + if (mid_ >= 1 && id_ < mrope_section[mid_] * M) { + auto p = positions[mid_ * ps0 + token_id * ps1]; + out[i] = cos_sin[p * HEAD_SIZE + id]; + } else { + out[i] = cos_sin[positions[token_id * ps1] * HEAD_SIZE + id]; + } + } + } else { + for (int i = 0; i < VEC_SIZE; ++i) { + auto id = access_id_in_head + i; + auto id_ = (access_id_in_head < HALF_HEAD_SIZE) ? 
id : id - HALF_HEAD_SIZE; + int mid; + int end = 0; + for (mid = 0; mid < M; ++mid) { + end += mrope_section[mid]; + if (id_ < end) + break; + } + auto p = positions[mid * ps0 + token_id * ps1]; + out[i] = cos_sin[p * HEAD_SIZE + id]; + } + } +} + +template +__global__ void fused_mrope_rms_neox_kernel( + T *qkv, const T *q_w, const T *k_w, const T *cos_sin, const int64_t *positions, int64_t ps0, int64_t ps1, + int64_t num_heads_q, int64_t num_heads_k, int64_t num_heads_v, double eps, + std::array mrope_section, int64_t num_tokens, int64_t total_warps) { + constexpr int VEC_SIZE = HEAD_SIZE / 32; + constexpr int HALF_HEAD_SIZE = HEAD_SIZE / 2; + const auto warp_id = threadIdx.x / 32; + const auto num_warps_per_block = blockDim.x / 32; + const auto global_warp_id = blockIdx.x * num_warps_per_block + warp_id; + if (global_warp_id >= total_warps) { + return; + } + auto token_id = global_warp_id / (num_heads_q + num_heads_k); + auto head_id_in_token = global_warp_id % (num_heads_q + num_heads_k); + bool is_q = head_id_in_token < num_heads_q; + auto access_id_in_head = (threadIdx.x % 32) * VEC_SIZE; + auto neighbor_offset = access_id_in_head < HALF_HEAD_SIZE ? 
HALF_HEAD_SIZE / VEC_SIZE : -HALF_HEAD_SIZE / VEC_SIZE; + auto qkv_ = qkv + token_id * (num_heads_q + num_heads_k + num_heads_v) * HEAD_SIZE + head_id_in_token * HEAD_SIZE; + + vec_t w_vec; + + if (is_q) { + w_vec.load(q_w + access_id_in_head); + } else { + w_vec.load(k_w + access_id_in_head); + } + + vec_t x_vec, cos_sin_vec; + x_vec.load(qkv_ + access_id_in_head); + if constexpr (IS_MROPE) { + mrope_load_cos_sin_vec( + cos_sin_vec, cos_sin, positions, ps0, ps1, token_id, num_tokens, access_id_in_head, mrope_section); + } else { + auto position_ = positions[token_id * ps1]; + cos_sin_vec.load(&cos_sin[position_ * HEAD_SIZE + access_id_in_head]); + } + + warp_rms_norm_(x_vec, w_vec, HEAD_SIZE, eps); + auto nb_cos_sin_vec = warp_shfl_sync_vec(cos_sin_vec, threadIdx.x + neighbor_offset); + auto nb_x_vec = warp_shfl_sync_vec(x_vec, threadIdx.x + neighbor_offset); + vec_t out_vec; + if (neighbor_offset > 0) { +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + out_vec[i] = (float)x_vec[i] * (float)cos_sin_vec[i] - (float)nb_x_vec[i] * (float)nb_cos_sin_vec[i]; // x0 * cos - x1 * sin + } + } else { +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + out_vec[i] = (float)x_vec[i] * (float)nb_cos_sin_vec[i] + (float)nb_x_vec[i] * (float)cos_sin_vec[i]; // x1 * cos + x0 * sin + } + } + out_vec.store(qkv_ + access_id_in_head); +} + +template +__global__ void fused_mrope_rms_noneox_kernel( + T *qkv, const T *q_w, const T *k_w, const T *cos_sin, const int64_t *positions, int64_t ps0, int64_t ps1, + int64_t num_heads_q, int64_t num_heads_k, int64_t num_heads_v, double eps, + std::array mrope_section, int64_t num_tokens, int64_t total_warps) { + constexpr int VEC_SIZE = HEAD_SIZE / 32; + constexpr int HALF_HEAD_SIZE = HEAD_SIZE / 2; + const auto warp_id = threadIdx.x / 32; + const auto num_warps_per_block = blockDim.x / 32; + const auto global_warp_id = blockIdx.x * num_warps_per_block + warp_id; + if (global_warp_id >= total_warps) { + return; + } + auto token_id = 
global_warp_id / (num_heads_q + num_heads_k); + auto head_id_in_token = global_warp_id % (num_heads_q + num_heads_k); + bool is_q = head_id_in_token < num_heads_q; + auto access_id_in_head = (threadIdx.x % 32) * VEC_SIZE; + auto qkv_ = qkv + token_id * (num_heads_q + num_heads_k + num_heads_v) * HEAD_SIZE + head_id_in_token * HEAD_SIZE; + + vec_t w_vec; + + if (is_q) { + w_vec.load(q_w + access_id_in_head); + } else { + w_vec.load(k_w + access_id_in_head); + } + + vec_t x_vec, cos_vec, sin_vec; + x_vec.load(qkv_ + access_id_in_head); + if constexpr (IS_MROPE) { + mrope_load_cos_sin_vec( + cos_vec, cos_sin, positions, ps0, ps1, token_id, num_tokens, access_id_in_head / 2, mrope_section); + mrope_load_cos_sin_vec( + sin_vec, cos_sin, positions, ps0, ps1, token_id, num_tokens, access_id_in_head / 2 + HALF_HEAD_SIZE, mrope_section); + } else { + auto position_ = positions[token_id * ps1]; + cos_vec.load(&cos_sin[position_ * HEAD_SIZE + access_id_in_head / 2]); + sin_vec.load(&cos_sin[position_ * HEAD_SIZE + access_id_in_head / 2 + HALF_HEAD_SIZE]); + } + + warp_rms_norm_(x_vec, w_vec, HEAD_SIZE, eps); + + vec_t out_vec; +#pragma unroll + for (int i = 0; i < VEC_SIZE / 2; ++i) { + out_vec[2 * i + 0] = (float)x_vec[2 * i + 0] * (float)cos_vec[i] - (float)x_vec[2 * i + 1] * (float)sin_vec[i]; + out_vec[2 * i + 1] = (float)x_vec[2 * i + 1] * (float)cos_vec[i] + (float)x_vec[2 * i + 0] * (float)sin_vec[i]; + } + + out_vec.store(qkv_ + access_id_in_head); +} + +template +void fused_rope_rms( + T *qkv, const T *q_w, const T *k_w, const T *cos_sin, const int64_t *positions, + int64_t num_tokens, int64_t num_heads_q, int64_t num_heads_k, int64_t num_heads_v, int64_t head_size, + bool is_neox_style, double eps, hipStream_t stream) { + TORCH_CHECK(head_size == 64 || head_size == 128 || head_size == 256); + constexpr int block_size = 256; + auto total_warps = num_tokens * (num_heads_q + num_heads_k); + auto num_warps_per_block = block_size / 32; + dim3 threadsPerBlock(block_size); 
+ dim3 numBlocks((total_warps + num_warps_per_block - 1) / num_warps_per_block); + std::array mrope_section = {0}; + +#define DISPATCH_NEOX(HEAD_SIZE) \ + if (is_neox_style) { \ + fused_mrope_rms_neox_kernel<<>>( \ + qkv, q_w, k_w, cos_sin, positions, num_heads_q, num_heads_k, num_heads_v, eps, mrope_section, num_tokens, total_warps); \ + } else { \ + fused_mrope_rms_noneox_kernel<<>>( \ + qkv, q_w, k_w, cos_sin, positions, num_heads_q, num_heads_k, num_heads_v, eps, mrope_section, num_tokens, total_warps); \ + } + + switch (head_size) { + case 64: + DISPATCH_NEOX(64) + break; + case 128: + DISPATCH_NEOX(128) + break; + case 256: + DISPATCH_NEOX(256) + break; + } + +#undef DISPATCH_NEOX +} + +template +void fused_mrope_rms( + T *qkv, const T *q_w, const T *k_w, const T *cos_sin, const int64_t *positions, int64_t ps0, int64_t ps1, + int64_t num_tokens, int64_t num_heads_q, int64_t num_heads_k, int64_t num_heads_v, int64_t head_size, + bool is_neox_style, double eps, std::array mrope_section, bool is_interleaved, hipStream_t stream) { + TORCH_CHECK(head_size == 64 || head_size == 128 || head_size == 256); + auto dim = std::accumulate(mrope_section.begin(), mrope_section.end(), 0); + TORCH_CHECK(dim == head_size / 2); + constexpr int block_size = 256; + auto total_warps = num_tokens * (num_heads_q + num_heads_k); + auto num_warps_per_block = block_size / 32; + dim3 threadsPerBlock(block_size); + dim3 numBlocks((total_warps + num_warps_per_block - 1) / num_warps_per_block); + +#define DISPATCH_NEOX(HEAD_SIZE, IS_INTERLEAVED) \ + if (is_neox_style) { \ + fused_mrope_rms_neox_kernel<<>>( \ + qkv, q_w, k_w, cos_sin, positions, ps0, ps1, num_heads_q, num_heads_k, num_heads_v, eps, mrope_section, num_tokens, total_warps); \ + } else { \ + fused_mrope_rms_noneox_kernel<<>>( \ + qkv, q_w, k_w, cos_sin, positions, ps0, ps1, num_heads_q, num_heads_k, num_heads_v, eps, mrope_section, num_tokens, total_warps); \ + } + + if (is_interleaved) { + switch (head_size) { + case 64: + 
DISPATCH_NEOX(64, true) + break; + case 128: + DISPATCH_NEOX(128, true) + break; + case 256: + DISPATCH_NEOX(256, true) + break; + } + } else { + switch (head_size) { + case 64: + DISPATCH_NEOX(64, false) + break; + case 128: + DISPATCH_NEOX(128, false) + break; + case 256: + DISPATCH_NEOX(256, false) + break; + } + } + +#undef DISPATCH_NEOX +} + +} // namespace rope_rms + +template <typename T> +struct KernelElementType { + using type = T; +}; + +template <> +struct KernelElementType<at::Half> { + using type = __half; +}; + +template <> +struct KernelElementType<at::BFloat16> { + using type = hip_bfloat16; +}; + +void fused_mrope_3d_rms(Tensor &qkv, Tensor &qw, Tensor &kw, Tensor &cos_sin, Tensor &positions, + int64_t num_tokens, int64_t num_heads_q, int64_t num_heads_k, int64_t num_heads_v, int64_t head_size, + bool is_neox_style, std::vector<int64_t> mrope_section_, bool is_interleaved, double eps) { + TORCH_CHECK(mrope_section_.size() == 3); + TORCH_CHECK(qkv.is_contiguous() && qw.is_contiguous() && kw.is_contiguous() && cos_sin.is_contiguous()); + std::array<int64_t, 3> mrope_section; + mrope_section[0] = mrope_section_[0]; + mrope_section[1] = mrope_section_[1]; + mrope_section[2] = mrope_section_[2]; + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(qkv)); + auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + auto pos_strides = positions.strides(); + TORCH_CHECK(pos_strides.size() == 2); + AT_DISPATCH_FLOATING_TYPES_AND2( + kBFloat16, + kHalf, + qkv.scalar_type(), + "fused_mrope_3d_rms", [&] { + using T = typename KernelElementType<scalar_t>::type; + rope_rms::fused_mrope_rms( + (T*)qkv.data_ptr(), + (T*)qw.data_ptr(), + (T*)kw.data_ptr(), + (T*)cos_sin.data_ptr(), + positions.data_ptr<int64_t>(), + pos_strides[0], + pos_strides[1], + num_tokens, + num_heads_q, + num_heads_k, + num_heads_v, + head_size, + is_neox_style, + eps, + mrope_section, + is_interleaved, + stream); + }); +} diff --git a/csrc/pybind/fused_mrope_rms_pybind.cu b/csrc/pybind/fused_mrope_rms_pybind.cu new file mode 100644
index 0000000000..d32094f7b4
--- /dev/null
+++ b/csrc/pybind/fused_mrope_rms_pybind.cu
@@ -0,0 +1,9 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+#include "rocm_ops.hpp"
+#include "fused_mrope_rms.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    FUSED_MROPE_RMS_PYBIND;
+}
diff --git a/csrc/rocm_ops.cpp b/csrc/rocm_ops.cpp
index 9b21b5b2cd..71395bd5ee 100644
--- a/csrc/rocm_ops.cpp
+++ b/csrc/rocm_ops.cpp
@@ -19,6 +19,7 @@
 #include "custom.h"
 #include "custom_all_reduce.h"
 #include "deepgemm.h"
+#include "fused_mrope_rms.h"
 #include "gemm_a4w4_blockscale.h"
 #include "gemm_a8w8.h"
 #include "gemm_a8w8_blockscale.h"
@@ -95,6 +96,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     ROPE_GENERAL_FWD_PYBIND;
     ROPE_GENERAL_BWD_PYBIND;
     ROPE_POS_FWD_PYBIND;
+    FUSED_MROPE_RMS_PYBIND;
     // GEMM_A8W8_BLOCKSCALE_TUNE_PYBIND;
     GEMM_A4W4_BLOCKSCALE_PYBIND;
     GEMM_A8W8_BLOCKSCALE_PYBIND;
diff --git a/op_tests/test_fused_mrope_rms.py b/op_tests/test_fused_mrope_rms.py
new file mode 100644
index 0000000000..4e212b864d
--- /dev/null
+++ b/op_tests/test_fused_mrope_rms.py
@@ -0,0 +1,266 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+import torch
+from torch import Tensor
+import aiter
+from aiter.test_common import checkAllclose, perftest, benchmark
+from typing import List
+
+
+def rms_norm_forward(x: Tensor, weight: Tensor, eps: float):
+    input_dtype = x.dtype
+    variance = x.float().pow(2).mean(-1, keepdim=True)
+    x = x * torch.rsqrt(variance + eps)
+    x = x.to(input_dtype)
+    return weight * x
+
+
+def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
+    """Apply interleaved MRoPE to 3D rotary embeddings.
+    Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+    interleaved [THTHWHTHW...TT], preserving frequency continuity.
+    """
+    x_t = x[0].clone()
+    x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
+    x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
+    return x_t
+
+
+def apply_rotary_emb_torch(
+    x: Tensor,
+    cos: Tensor,
+    sin: Tensor,
+    is_neox_style: bool,
+) -> Tensor:
+    cos = cos.unsqueeze(-2).to(x.dtype)
+    sin = sin.unsqueeze(-2).to(x.dtype)
+    if is_neox_style:
+        x1, x2 = torch.chunk(x, 2, dim=-1)
+    else:
+        x1 = x[..., ::2]
+        x2 = x[..., 1::2]
+    o1 = x1 * cos - x2 * sin
+    o2 = x2 * cos + x1 * sin
+    if is_neox_style:
+        return torch.cat((o1, o2), dim=-1)
+    else:
+        return torch.stack((o1, o2), dim=-1).flatten(-2)
+
+
+def apply_rotary_emb_dispatch(
+    x: Tensor, cos: Tensor, sin: Tensor, is_neox_style: bool
+) -> Tensor:
+    """
+    Args:
+        x: [num_tokens, num_heads, head_size]
+        cos: [num_tokens, head_size // 2]
+        sin: [num_tokens, head_size // 2]
+        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
+            positional embeddings.
+    """
+    return apply_rotary_emb_torch(x, cos, sin, is_neox_style)
+
+
+@perftest()
+def run_torch_mrope_3d_rms(
+    qkv: Tensor,  # contiguous (num_tokens * (num_heads_q + num_heads_k + num_heads_v) * head_size)
+    qw: Tensor,  # contiguous (head_size)
+    kw: Tensor,  # contiguous (head_size)
+    cos_sin: Tensor,  # contiguous (max_positions * head_size)
+    positions: Tensor,  # contiguous (3 * num_tokens)
+    num_tokens: int,
+    num_heads_q: int,
+    num_heads_k: int,
+    num_heads_v: int,
+    head_size: int,
+    is_neox_style: bool,
+    mrope_section: List[int],
+    is_interleaved: bool,
+    eps: float,
+):
+    q_size = num_heads_q * head_size
+    k_size = num_heads_k * head_size
+    v_size = num_heads_v * head_size
+    qkv = qkv.view(num_tokens, q_size + k_size + v_size)
+    q, k, v = qkv.split([q_size, k_size, v_size], dim=-1)
+
+    q_by_head = q.view(num_tokens, num_heads_q, head_size)
+    q_by_head = rms_norm_forward(q_by_head, qw, eps)
+    q = q_by_head.view(q.shape)
+
+    k_by_head = k.view(num_tokens, num_heads_k, head_size)
+    k_by_head = rms_norm_forward(k_by_head, kw, eps)
+    k = k_by_head.view(k.shape)
+
+    cos_sin = cos_sin.view(-1, head_size)  # infer position count; `max_positions` is only bound under __main__ and would NameError on import+call
+    positions = positions.view(3, num_tokens)
+    cos_sin = cos_sin[positions]
+    cos, sin = cos_sin.chunk(2, dim=-1)
+
+    if is_interleaved:
+        cos = apply_interleaved_rope(cos, mrope_section)
+        sin = apply_interleaved_rope(sin, mrope_section)
+    else:
+        cos = torch.cat(
+            [m[i] for i, m in enumerate(cos.split(mrope_section, dim=-1))],
+            dim=-1,
+        )
+        sin = torch.cat(
+            [m[i] for i, m in enumerate(sin.split(mrope_section, dim=-1))],
+            dim=-1,
+        )
+
+    q_shape = q.shape
+    q = q.view(num_tokens, -1, head_size)
+    q = apply_rotary_emb_dispatch(q, cos, sin, is_neox_style)
+    q = q.reshape(q_shape)
+
+    k_shape = k.shape
+    k = k.view(num_tokens, -1, head_size)
+    k = apply_rotary_emb_dispatch(k, cos, sin, is_neox_style)
+    k = k.reshape(k_shape)
+
+    return q, k, v
+
+
+@perftest()
+def run_aiter_mrope_3d_rms(
+    qkv: Tensor,  # contiguous (num_tokens * (num_heads_q + num_heads_k + num_heads_v) * head_size)
+    qw: Tensor,  # contiguous (head_size)
+    kw: Tensor,  # contiguous (head_size)
+    cos_sin: Tensor,  # contiguous (max_positions * head_size)
+    positions: Tensor,  # contiguous (3 * num_tokens)
+    num_tokens: int,
+    num_heads_q: int,
+    num_heads_k: int,
+    num_heads_v: int,
+    head_size: int,
+    is_neox_style: bool,
+    mrope_section: List[int],
+    is_interleaved: bool,
+    eps: float,
+):
+    qkv = qkv.clone()  # inplace op
+    aiter.fused_mrope_3d_rms(
+        qkv,
+        qw,
+        kw,
+        cos_sin,
+        positions,
+        num_tokens,
+        num_heads_q,
+        num_heads_k,
+        num_heads_v,
+        head_size,
+        is_neox_style,
+        mrope_section,
+        is_interleaved,
+        eps,
+    )
+
+    q_size = num_heads_q * head_size
+    k_size = num_heads_k * head_size
+    v_size = num_heads_v * head_size
+
+    qkv = qkv.view(num_tokens, q_size + k_size + v_size)
+    q, k, v = qkv.split([q_size, k_size, v_size], dim=-1)
+    return q, k, v
+
+
+@benchmark()
+def test_mrope_3d_rms(
+    dtype,
+    num_tokens,
+    num_heads_q,
+    num_heads_k,
+    num_heads_v,
+    head_size,
+    is_neox_style,
+    mrope_section,
+    is_interleaved,
+    eps=1e-6,
+):
+    qkv = torch.randn(
+        (num_tokens, num_heads_q + num_heads_k + num_heads_v, head_size),
+        dtype=dtype,
+        device="cuda",
+    )
+    qw = torch.randn(head_size, dtype=dtype, device="cuda")
+    kw = torch.randn(head_size, dtype=dtype, device="cuda")
+    cos_sin = torch.randn((max_positions, head_size), dtype=dtype, device="cuda")
+    positions = torch.randint(
+        0, max_positions, (3, num_tokens), dtype=torch.int64, device="cuda"
+    )
+
+    (q_ref, k_ref, v_ref), avg_torch = run_torch_mrope_3d_rms(
+        qkv,
+        qw,
+        kw,
+        cos_sin,
+        positions,
+        num_tokens,
+        num_heads_q,
+        num_heads_k,
+        num_heads_v,
+        head_size,
+        is_neox_style,
+        mrope_section,
+        is_interleaved,
+        eps,
+    )
+    (q, k, v), avg_cu = run_aiter_mrope_3d_rms(
+        qkv,
+        qw,
+        kw,
+        cos_sin,
+        positions,
+        num_tokens,
+        num_heads_q,
+        num_heads_k,
+        num_heads_v,
+        head_size,
+        is_neox_style,
+        mrope_section,
+        is_interleaved,
+        eps,
+    )
+
+    info = f"dtype:{dtype}, num_tokens:{num_tokens}, num_heads_q:{num_heads_q}, num_heads_k:{num_heads_k}, num_heads_v:{num_heads_v}, head_size:{head_size}, is_neox_style:{is_neox_style}"
+    info += (
+        f", mrope_section:{mrope_section}, is_interleaved:{is_interleaved}, eps:{eps}"
+    )
+    msg = f"[perf] === {info} === torch avg: {avg_torch:<8.2f} us, cu avg: {avg_cu:<8.2f} us, uplift: {avg_torch/avg_cu-1:<5.1%}"
+    checkAllclose(q_ref, q, msg="q", rtol=1e-2, atol=0.05)
+    checkAllclose(k_ref, k, msg="k", rtol=1e-2, atol=0.05)
+    checkAllclose(v_ref, v, msg=msg, rtol=1e-2, atol=0.05)
+
+
+if __name__ == "__main__":
+    is_neox_styles = [True, False]
+    num_tokens = [513, 1257, 127, 778, 10024, 3]
+    num_heads = [32, 64]
+    head_sizes = [64, 128, 256]
+    mrope_sections = [[12, 10, 10], [24, 20, 20], [48, 40, 40]]
+    is_interleaveds = [True, False]
+    max_positions = 10000
+    dtype = torch.bfloat16
+    for is_neox_style in is_neox_styles:
+        for num_token in num_tokens:
+            for num_head in num_heads:
+                for i, head_size in enumerate(head_sizes):
+                    ms = mrope_sections[i]
+                    for is_interleaved in is_interleaveds:
+                        test_mrope_3d_rms(
+                            dtype,
+                            num_token,
+                            num_head,
+                            num_head,
+                            num_head,
+                            head_size,
+                            is_neox_style,
+                            ms,
+                            is_interleaved,
+                            eps=1e-6,
+                        )
+    print("done")