Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 39 additions & 34 deletions csrc/pos_encoding_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@

namespace vllm {

template <typename scalar_t, bool IS_NEOX>
template <typename scalar_t, typename cache_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(
scalar_t* __restrict__ arr, const float* __restrict__ cos_ptr,
const float* __restrict__ sin_ptr, int rot_offset, int embed_dim,
scalar_t* __restrict__ arr, const cache_t* __restrict__ cos_ptr,
const cache_t* __restrict__ sin_ptr, int rot_offset, int embed_dim,
const bool inverse) {
int x_index, y_index;
float cos_f, sin_f;
if (IS_NEOX) {
x_index = rot_offset;
y_index = embed_dim + rot_offset;
cos_f = VLLM_LDG(cos_ptr + x_index);
sin_f = VLLM_LDG(sin_ptr + x_index);
cos_f = static_cast<float>(VLLM_LDG(cos_ptr + x_index));
sin_f = static_cast<float>(VLLM_LDG(sin_ptr + x_index));
} else {
x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1;
cos_f = VLLM_LDG(cos_ptr + x_index / 2);
sin_f = VLLM_LDG(sin_ptr + x_index / 2);
cos_f = static_cast<float>(VLLM_LDG(cos_ptr + x_index / 2));
sin_f = static_cast<float>(VLLM_LDG(sin_ptr + x_index / 2));
}
if (inverse) {
sin_f = -sin_f;
Expand All @@ -34,7 +34,7 @@ inline __device__ void apply_token_rotary_embedding(
arr[y_index] = static_cast<scalar_t>(y_f * cos_f + x_f * sin_f);
}

template <typename scalar_t, bool IS_NEOX>
template <typename scalar_t, typename cache_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
Expand All @@ -43,22 +43,22 @@ inline __device__ void apply_rotary_embedding(
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const float* cache_ptr, const int head_size, const int num_heads,
const cache_t* cache_ptr, const int head_size, const int num_heads,
const int num_kv_heads, const int rot_dim, const int token_idx,
const int64_t query_stride, const int64_t key_stride,
const int64_t head_stride, const int64_t rope_dim_offset,
const bool inverse) {
const int embed_dim = rot_dim / 2;
const float* cos_ptr = cache_ptr;
const float* sin_ptr = cache_ptr + embed_dim;
const cache_t* cos_ptr = cache_ptr;
const cache_t* sin_ptr = cache_ptr + embed_dim;

const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head =
token_idx * query_stride + head_idx * head_stride + rope_dim_offset;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
apply_token_rotary_embedding<scalar_t, cache_t, IS_NEOX>(
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim, inverse);
}

Expand All @@ -69,13 +69,13 @@ inline __device__ void apply_rotary_embedding(
const int64_t token_head =
token_idx * key_stride + head_idx * head_stride + rope_dim_offset;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
apply_token_rotary_embedding<scalar_t, cache_t, IS_NEOX>(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim, inverse);
}
}
}

template <typename scalar_t, bool IS_NEOX>
template <typename scalar_t, typename cache_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
// [num_tokens]
Expand All @@ -86,15 +86,15 @@ __global__ void rotary_embedding_kernel(
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const float* __restrict__ cos_sin_cache, // [max_position, rot_dim] fp32
const cache_t* __restrict__ cos_sin_cache, // [max_position, rot_dim]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int64_t head_stride, const int num_heads, const int num_kv_heads,
const int head_size, const int64_t rope_dim_offset, const bool inverse) {
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
const float* cache_ptr = cos_sin_cache + pos * rot_dim;
const cache_t* cache_ptr = cos_sin_cache + pos * rot_dim;

apply_rotary_embedding<scalar_t, IS_NEOX>(
apply_rotary_embedding<scalar_t, cache_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride, head_stride, rope_dim_offset,
inverse);
Expand Down Expand Up @@ -168,23 +168,28 @@ void rotary_embedding(
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto cache_f32 = cos_sin_cache.to(torch::kFloat32);
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
if (is_neox) {
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cache_f32.data_ptr<float>(), rot_dim, query_stride, key_stride,
head_stride, num_heads, num_kv_heads, head_size, rope_dim_offset,
inverse);
} else {
vllm::rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cache_f32.data_ptr<float>(), rot_dim, query_stride, key_stride,
head_stride, num_heads, num_kv_heads, head_size, rope_dim_offset,
inverse);
}
using query_t = scalar_t;
VLLM_DISPATCH_FLOATING_TYPES(
cos_sin_cache.scalar_type(), "rotary_embedding_cache", [&] {
using cache_t = scalar_t;
if (is_neox) {
vllm::rotary_embedding_kernel<query_t, cache_t, true>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<query_t>(),
key.has_value() ? key->data_ptr<query_t>() : nullptr,
cos_sin_cache.data_ptr<cache_t>(), rot_dim, query_stride,
key_stride, head_stride, num_heads, num_kv_heads, head_size,
rope_dim_offset, inverse);
} else {
vllm::rotary_embedding_kernel<query_t, cache_t, false>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<query_t>(),
key.has_value() ? key->data_ptr<query_t>() : nullptr,
cos_sin_cache.data_ptr<cache_t>(), rot_dim, query_stride,
key_stride, head_stride, num_heads, num_kv_heads, head_size,
rope_dim_offset, inverse);
}
});
});
Comment on lines 171 to 194

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The nested dispatch of floating-point types for both the query and the cache significantly increases the number of kernel instantiations (3x3x2 = 18 combinations: three query types, three cache types, and the two IS_NEOX variants). While this provides flexibility, it can lead to increased compilation times and binary size. Given that the RoPE cache is almost always either float32 or the same type as the model weights, consider whether a more restricted dispatch (e.g., only allowing the cache to be float32 or matching the query type) would be sufficient to avoid the Cartesian product of types.

}
8 changes: 6 additions & 2 deletions tests/kernels/core/test_rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def rotary_embedding_opcheck(
@pytest.mark.parametrize("seq_len", [11, 1024])
@pytest.mark.parametrize("use_key", [True, False])
@pytest.mark.parametrize("head_stride_is_contiguous", [True, False])
@pytest.mark.parametrize(
"dtype", [torch.float32, torch.bfloat16]
)
def test_rotary_embedding_opcheck(
default_vllm_config,
dist_init,
Expand All @@ -46,19 +49,20 @@ def test_rotary_embedding_opcheck(
seq_len,
use_key,
head_stride_is_contiguous,
dtype,
):
batch_size = 1
base = 10000
num_heads = 7
rot = RotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, torch.float32
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)

positions = torch.randint(0, max_position, (batch_size, seq_len), device=device)
head_stride = head_size + (64 if head_stride_is_contiguous else 0)

query = torch.randn(
batch_size, seq_len, num_heads, head_stride, dtype=torch.float32, device=device
batch_size, seq_len, num_heads, head_stride, dtype=dtype, device=device
)
key = torch.randn_like(query) if use_key else None
query = query[..., :head_size]
Expand Down
Loading