Commit

fuse rope

yuanlehome committed Jan 26, 2025
1 parent a8f3839 commit 9621f28
Showing 5 changed files with 164 additions and 41 deletions.
141 changes: 141 additions & 0 deletions csrc/gpu/fused_rotary_position_encoding.cu
@@ -0,0 +1,141 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

template <typename T, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding_kernel(
T* __restrict__ arr,
const T* __restrict__ cos_ptr,
const T* __restrict__ sin_ptr,
int rot_offset,
int embed_dim) {
int x_index, y_index;
T cos, sin;
if (IS_NEOX) {
x_index = rot_offset;
y_index = embed_dim + rot_offset;
cos = cos_ptr[x_index];
sin = sin_ptr[x_index];
} else {
x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1;
cos = cos_ptr[x_index / 2];
sin = sin_ptr[x_index / 2];
}

const T x = arr[x_index];
const T y = arr[y_index];
arr[x_index] = x * cos - y * sin;
arr[y_index] = y * cos + x * sin;
}


template <typename T, bool IS_NEOX>
__global__ void apply_rotary_embedding_kernel(
T* __restrict__ query, // [num_tokens, num_heads, head_size]
T* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
const int* __restrict__ position_ids, // [num_tokens]
const T* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int64_t query_stride,
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int pos = position_ids[token_idx];
const T* cache_ptr = cos_sin_cache + pos * rot_dim;

const int embed_dim = rot_dim / 2;
const T* cos_ptr = cache_ptr;
const T* sin_ptr = cache_ptr + embed_dim;

const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding_kernel<T, IS_NEOX>(
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
}

const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding_kernel<T, IS_NEOX>(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
}
}


void FusedRotaryPositionEncoding(
paddle::Tensor& query, // [num_tokens, num_heads, head_size] or
// [num_tokens, num_heads * head_size]
paddle::Tensor& key,
// [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads *
// head_size]
const paddle::Tensor& position_ids, // [num_tokens]
const paddle::Tensor& cos_sin_cache, // [max_position, rot_dim]
int head_size,
bool is_neox) {
int64_t num_tokens = query.dims()[0];
int num_heads = query.numel() / num_tokens / head_size;
int num_kv_heads = key.numel() / num_tokens / head_size;
int rot_dim = cos_sin_cache.dims()[1];
int64_t query_stride = num_heads * head_size;
int64_t key_stride = num_kv_heads * head_size;

dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
PD_DISPATCH_FLOATING_TYPES(
query.dtype(), "apply_rotary_embedding_kernel", [&] {
if (is_neox) {
apply_rotary_embedding_kernel<data_t, true>
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
key.data<data_t>(),
position_ids.data<int>(),
cos_sin_cache.data<data_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} else {
apply_rotary_embedding_kernel<data_t, false>
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
key.data<data_t>(),
position_ids.data<int>(),
cos_sin_cache.data<data_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});
}

PD_BUILD_OP(fused_rotary_position_encoding)
.Inputs({"query", "key", "position_ids", "cos_sin_cache"})
.Outputs({"query_out", "key_out"})
.Attrs({"head_size: int", "is_neox: bool"})
.SetInplaceMap({{"query", "query_out"}, {"key", "key_out"}})
.SetKernelFn(PD_KERNEL(FusedRotaryPositionEncoding));
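
For reference, a minimal plain-Paddle sketch (not part of this commit) of the transform the op above applies, written against the cache layout the kernel assumes: per position, rot_dim / 2 cos values followed by rot_dim / 2 sin values. The commented check at the end is hypothetical, assumes the extension has been built into paddlenlp_ops, and uses shapes chosen only for illustration.

import paddle

def rope_reference(x, position_ids, cos_sin_cache, is_neox=False):
    # x: [num_tokens, num_heads, head_size]; only the first rot_dim features of each head are rotated.
    rot_dim = cos_sin_cache.shape[-1]
    half = rot_dim // 2
    cs = paddle.index_select(cos_sin_cache, position_ids, axis=0)  # [num_tokens, rot_dim]
    cos = cs[:, :half].unsqueeze(1)  # broadcast over the head axis
    sin = cs[:, half:].unsqueeze(1)
    rot, rest = x[..., :rot_dim], x[..., rot_dim:]
    if is_neox:
        # NeoX / rotate-half layout: the first half of the rotary features pairs with the second half.
        x1, x2 = rot[..., :half], rot[..., half:]
        out = paddle.concat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    else:
        # Interleaved layout (is_neox=False above): even indices pair with odd indices.
        x1, x2 = rot[..., 0::2], rot[..., 1::2]
        out = paddle.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1).reshape(rot.shape)
    if rest.shape[-1]:
        return paddle.concat([out, rest], axis=-1)
    return out

# Hypothetical usage check:
# from paddlenlp_ops import fused_rotary_position_encoding
# q = paddle.randn([4, 8, 64]); k = paddle.randn([4, 2, 64]); cache = paddle.randn([1024, 64])
# pos = paddle.to_tensor([0, 1, 2, 3], dtype="int32")
# expected = rope_reference(q.clone(), pos, cache)
# fused_rotary_position_encoding(q, k, pos, cache, 64, False)  # updates q and k in place
# assert paddle.allclose(q, expected, atol=1e-5)
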
1 change: 1 addition & 0 deletions csrc/setup_cuda.py
@@ -104,6 +104,7 @@ def get_gencode_flags():
"./gpu/quant_int8.cu",
"./gpu/dequant_int8.cu",
"./gpu/get_position_ids.cu",
"./gpu/fused_rotary_position_encoding.cu",
"./gpu/flash_attn_bwd.cc",
"./gpu/tune_cublaslt_gemm.cu",
"./gpu/sample_kernels/top_p_sampling_reject.cu",
37 changes: 9 additions & 28 deletions paddlenlp/experimental/transformers/deepseek_v2/modeling.py
@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import annotations

import os
from functools import partial
from typing import Tuple

@@ -60,7 +61,6 @@ class DeepseekScalingRotaryEmbedding(nn.Layer):

def __init__(

self,
# head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
@@ -114,10 +114,9 @@ def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
def _compute_cos_sin_cache(self) -> paddle.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = paddle.arange(self.max_position_embeddings * self.scaling_factor, dtype=paddle.float32)
freqs = paddle.einsum("i,j -> ij", t, inv_freq)
emb = paddle.concat((freqs, freqs), axis=-1)
cos = emb.cos() * self.mscale
sin = emb.sin() * self.mscale
freqs = paddle.einsum("i,j->ij", t, inv_freq)
cos = freqs.cos() * self.mscale
sin = freqs.sin() * self.mscale
cache = paddle.concat((cos, sin), axis=-1)
return cache.cast(self._dtype)
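
The cache built above now stores, per position, rot_dim / 2 cos values followed by rot_dim / 2 sin values, which is the [max_position, 2, rot_dim // 2] layout the new CUDA kernel indexes; the previous code duplicated freqs for the rotate-half path. A standalone sketch of the same layout, using plain RoPE inverse frequencies as a stand-in for this class's YaRN-scaled _compute_inv_freq:

import paddle

def build_cos_sin_cache(rotary_dim, max_position, base=10000.0, mscale=1.0, dtype="bfloat16"):
    # Illustrative only: the real class derives inv_freq from its scaling rules.
    inv_freq = 1.0 / (base ** (paddle.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim))
    t = paddle.arange(max_position, dtype="float32")
    freqs = paddle.einsum("i,j->ij", t, inv_freq)  # [max_position, rotary_dim // 2]
    # cos block first, then sin block, per position.
    cache = paddle.concat([freqs.cos() * mscale, freqs.sin() * mscale], axis=-1)
    return cache.cast(dtype)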

@@ -127,29 +126,17 @@ def forward(
query: paddle.Tensor,
key: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
cos_sin = self.cos_sin_cache[position_ids].unsqueeze(1)
cos, sin = cos_sin.chunk(2, axis=-1)
from paddlenlp_ops import fused_rotary_position_encoding

s, h, d = query.shape
query = query.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])

s, h, d = key.shape
key = key.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])

def rotate_half(x):
"""Rotates half the hidden axiss of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return paddle.concat([-x2, x1], axis=-1) # shape is the same as x

query = (query * cos) + (rotate_half(query) * sin)
key = (key * cos) + (rotate_half(key) * sin)
# In-place operations that update the query and key tensors.
os.environ["stride_in_no_check_dy2st_diff"] = "1"
fused_rotary_position_encoding(query, key, position_ids, self.cos_sin_cache, self.rotary_dim, False)

return query, key


class DeepseekV2RMSNorm(nn.Layer):
def __init__(self, config):
def __init__(self, config: DeepseekV2Config):
super().__init__()
self.eps = config.rms_norm_eps
self.weight = paddle.create_parameter(

@@ -170,9 +157,7 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):

self.config = config

self.append_attn = config.append_attn
self.max_seq_len = config.max_seq_len

self.block_size = config.block_size

self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size
@@ -181,12 +166,9 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
self.num_key_value_heads = config.num_key_value_heads
self.num_layers = config.num_hidden_layers
self.rms_norm_eps = config.rms_norm_eps
self.max_position_embeddings = config.max_position_embeddings
self.quant_type = config.quant_type
self.rope_theta = config.rope_theta

self.use_neox = False

self.use_weight_only = False
if config.quant_type == "weight_only_int8":
self.use_weight_only = True
@@ -499,7 +481,6 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
rope_theta=self.rope_theta,
rotary_emb=self.rotary_emb,
norm_type="rmsnorm",
use_neox_rotary_style=self.use_neox,
rank_id=config.tensor_parallel_rank,
moe_config=moe_config,
mla_config=mla_config,
25 changes: 13 additions & 12 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -974,13 +974,13 @@ def compute_qkv_linear(self, ln_out, i):
query = paddle.matmul(ln_out, self.q_proj_weights[i])

query = query.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
query_nope, query_pe = paddle.split(
query, [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.qk_rope_head_dim], axis=-1
query_nope, query_pe = query.split(
[self.config.mla_config.qk_nope_head_dim, self.config.mla_config.qk_rope_head_dim], axis=-1
)

compressed_kv = paddle.matmul(ln_out, self.kv_a_proj_with_mqa_weights[i])
compressed_kv, key_pe = paddle.split(
compressed_kv, [self.config.mla_config.kv_lora_rank, self.config.mla_config.qk_rope_head_dim], axis=-1
compressed_kv, key_pe = compressed_kv.split(
[self.config.mla_config.kv_lora_rank, self.config.mla_config.qk_rope_head_dim], axis=-1
)
key_pe = key_pe.reshape([-1, 1, self.config.mla_config.qk_rope_head_dim])
compressed_kv = self.norm_func(

@@ -994,8 +994,8 @@ def compute_qkv_linear(self, ln_out, i):
key_value = key_value.reshape(
[-1, self.num_heads, self.config.mla_config.qk_nope_head_dim + self.config.mla_config.v_head_dim]
)
key_nope, value = paddle.split(
key_value, [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.v_head_dim], axis=-1
key_nope, value = key_value.split(
[self.config.mla_config.qk_nope_head_dim, self.config.mla_config.v_head_dim], axis=-1
)

query_pe, key_pe = self.config.rotary_emb(self.position_ids, query_pe, key_pe)

@@ -1305,6 +1305,7 @@ def pre_process(self, **kwargs):

from paddlenlp_ops import get_position_ids

# In-place operations that compute the position_ids.
get_position_ids(seq_lens_encoder, seq_lens_decoder, self.position_ids)

def post_process(self, **kwargs):
@@ -1827,8 +1828,8 @@ def compute_qkv_linear(self, ln_out, i):
)

query = query.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
query_nope, query_pe = paddle.split(
query, [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.qk_rope_head_dim], axis=-1
query_nope, query_pe = query.split(
[self.config.mla_config.qk_nope_head_dim, self.config.mla_config.qk_rope_head_dim], axis=-1
)

compressed_kv = weight_only_linear(

@@ -1837,8 +1838,8 @@ def compute_qkv_linear(self, ln_out, i):
weight_scale=self.kv_a_proj_with_mqa_weights_scale[i],
weight_dtype=self.weight_dtype,
)
compressed_kv, key_pe = paddle.split(
compressed_kv, [self.config.mla_config.kv_lora_rank, self.config.mla_config.qk_rope_head_dim], axis=-1
compressed_kv, key_pe = compressed_kv.split(
[self.config.mla_config.kv_lora_rank, self.config.mla_config.qk_rope_head_dim], axis=-1
)
key_pe = key_pe.reshape([-1, 1, self.config.mla_config.qk_rope_head_dim])
compressed_kv = self.norm_func(

@@ -1857,8 +1858,8 @@ def compute_qkv_linear(self, ln_out, i):
key_value = key_value.reshape(
[-1, self.num_heads, self.config.mla_config.qk_nope_head_dim + self.config.mla_config.v_head_dim]
)
key_nope, value = paddle.split(
key_value, [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.v_head_dim], axis=-1
key_nope, value = key_value.split(
[self.config.mla_config.qk_nope_head_dim, self.config.mla_config.v_head_dim], axis=-1
)

query_pe, key_pe = self.config.rotary_emb(self.position_ids, query_pe, key_pe)

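The split refactor in this file swaps paddle.split(t, sections, axis=-1) for the equivalent Tensor method t.split(sections, axis=-1); a quick sanity check of that equivalence, with hypothetical MLA head dimensions chosen only for illustration:

import paddle

qk_nope_head_dim, qk_rope_head_dim = 128, 64  # hypothetical sizes, not taken from any config
query = paddle.randn([4, 16, qk_nope_head_dim + qk_rope_head_dim])
a1, b1 = paddle.split(query, [qk_nope_head_dim, qk_rope_head_dim], axis=-1)
a2, b2 = query.split([qk_nope_head_dim, qk_rope_head_dim], axis=-1)
assert paddle.allclose(a1, a2) and paddle.allclose(b1, b2)
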
1 change: 0 additions & 1 deletion paddlenlp/experimental/transformers/qwen2_moe/modeling.py
@@ -78,7 +78,6 @@ def __init__(self, config: Qwen2MoeConfig):
self.num_key_value_heads = config.num_key_value_heads
self.num_layers = config.num_hidden_layers
self.rms_norm_eps = config.rms_norm_eps
self.max_position_embeddings = config.max_position_embeddings
self.quant_type = config.quant_type
self.rope_theta = config.rope_theta
