diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index c45ebd34729b..d03c6a5cf0dd 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -7,23 +7,23 @@ namespace vllm { -template +template inline __device__ void apply_token_rotary_embedding( - scalar_t* __restrict__ arr, const float* __restrict__ cos_ptr, - const float* __restrict__ sin_ptr, int rot_offset, int embed_dim, + scalar_t* __restrict__ arr, const cache_t* __restrict__ cos_ptr, + const cache_t* __restrict__ sin_ptr, int rot_offset, int embed_dim, const bool inverse) { int x_index, y_index; float cos_f, sin_f; if (IS_NEOX) { x_index = rot_offset; y_index = embed_dim + rot_offset; - cos_f = VLLM_LDG(cos_ptr + x_index); - sin_f = VLLM_LDG(sin_ptr + x_index); + cos_f = static_cast(VLLM_LDG(cos_ptr + x_index)); + sin_f = static_cast(VLLM_LDG(sin_ptr + x_index)); } else { x_index = 2 * rot_offset; y_index = 2 * rot_offset + 1; - cos_f = VLLM_LDG(cos_ptr + x_index / 2); - sin_f = VLLM_LDG(sin_ptr + x_index / 2); + cos_f = static_cast(VLLM_LDG(cos_ptr + x_index / 2)); + sin_f = static_cast(VLLM_LDG(sin_ptr + x_index / 2)); } if (inverse) { sin_f = -sin_f; @@ -34,7 +34,7 @@ inline __device__ void apply_token_rotary_embedding( arr[y_index] = static_cast(y_f * cos_f + x_f * sin_f); } -template +template inline __device__ void apply_rotary_embedding( scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, // head_size] or [num_tokens, num_heads, @@ -43,14 +43,14 @@ inline __device__ void apply_rotary_embedding( // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] - const float* cache_ptr, const int head_size, const int num_heads, + const cache_t* cache_ptr, const int head_size, const int num_heads, const int num_kv_heads, const int rot_dim, const int token_idx, const int64_t query_stride, const int64_t key_stride, const int64_t head_stride, const int64_t rope_dim_offset, const bool inverse) { const int embed_dim = rot_dim / 2; - const float* cos_ptr = cache_ptr; - const float* sin_ptr = cache_ptr + embed_dim; + const cache_t* cos_ptr = cache_ptr; + const cache_t* sin_ptr = cache_ptr + embed_dim; const int nq = num_heads * embed_dim; for (int i = threadIdx.x; i < nq; i += blockDim.x) { @@ -58,7 +58,7 @@ inline __device__ void apply_rotary_embedding( const int64_t token_head = token_idx * query_stride + head_idx * head_stride + rope_dim_offset; const int rot_offset = i % embed_dim; - apply_token_rotary_embedding( + apply_token_rotary_embedding( query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim, inverse); } @@ -69,13 +69,13 @@ inline __device__ void apply_rotary_embedding( const int64_t token_head = token_idx * key_stride + head_idx * head_stride + rope_dim_offset; const int rot_offset = i % embed_dim; - apply_token_rotary_embedding( + apply_token_rotary_embedding( key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim, inverse); } } } -template +template __global__ void rotary_embedding_kernel( const int64_t* __restrict__ positions, // [batch_size, seq_len] or // [num_tokens] @@ -86,15 +86,15 @@ __global__ void rotary_embedding_kernel( // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] - const float* __restrict__ cos_sin_cache, // [max_position, rot_dim] fp32 + const cache_t* __restrict__ cos_sin_cache, // [max_position, rot_dim] const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int64_t head_stride, const int num_heads, const int num_kv_heads, const int head_size, const int64_t rope_dim_offset, const bool inverse) { const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; - const float* cache_ptr = cos_sin_cache + pos * rot_dim; + const cache_t* cache_ptr = cos_sin_cache + pos * rot_dim; - apply_rotary_embedding( + apply_rotary_embedding( query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride, head_stride, rope_dim_offset, inverse); @@ -168,23 +168,28 @@ void rotary_embedding( dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - auto cache_f32 = cos_sin_cache.to(torch::kFloat32); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { - if (is_neox) { - vllm::rotary_embedding_kernel<<>>( - positions.data_ptr(), query.data_ptr(), - key.has_value() ? key->data_ptr() : nullptr, - cache_f32.data_ptr(), rot_dim, query_stride, key_stride, - head_stride, num_heads, num_kv_heads, head_size, rope_dim_offset, - inverse); - } else { - vllm::rotary_embedding_kernel - <<>>( - positions.data_ptr(), query.data_ptr(), - key.has_value() ? key->data_ptr() : nullptr, - cache_f32.data_ptr(), rot_dim, query_stride, key_stride, - head_stride, num_heads, num_kv_heads, head_size, rope_dim_offset, - inverse); - } + using query_t = scalar_t; + VLLM_DISPATCH_FLOATING_TYPES( + cos_sin_cache.scalar_type(), "rotary_embedding_cache", [&] { + using cache_t = scalar_t; + if (is_neox) { + vllm::rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), rot_dim, query_stride, + key_stride, head_stride, num_heads, num_kv_heads, head_size, + rope_dim_offset, inverse); + } else { + vllm::rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), rot_dim, query_stride, + key_stride, head_stride, num_heads, num_kv_heads, head_size, + rope_dim_offset, inverse); + } + }); }); } diff --git a/tests/kernels/core/test_rotary_embedding.py b/tests/kernels/core/test_rotary_embedding.py index 6cdd94fdc865..8410d1f1bcc6 100644 --- a/tests/kernels/core/test_rotary_embedding.py +++ b/tests/kernels/core/test_rotary_embedding.py @@ -35,6 +35,9 @@ def rotary_embedding_opcheck( @pytest.mark.parametrize("seq_len", [11, 1024]) @pytest.mark.parametrize("use_key", [True, False]) @pytest.mark.parametrize("head_stride_is_contiguous", [True, False]) +@pytest.mark.parametrize( + "dtype", [torch.float32, torch.bfloat16] +) def test_rotary_embedding_opcheck( default_vllm_config, dist_init, @@ -46,19 +49,20 @@ def test_rotary_embedding_opcheck( seq_len, use_key, head_stride_is_contiguous, + dtype, ): batch_size = 1 base = 10000 num_heads = 7 rot = RotaryEmbedding( - head_size, rotary_dim, max_position, base, is_neox_style, torch.float32 + head_size, rotary_dim, max_position, base, is_neox_style, dtype ) positions = torch.randint(0, max_position, (batch_size, seq_len), device=device) head_stride = head_size + (64 if head_stride_is_contiguous else 0) query = torch.randn( - batch_size, seq_len, num_heads, head_stride, dtype=torch.float32, device=device + batch_size, seq_len, num_heads, head_stride, dtype=dtype, device=device ) key = torch.randn_like(query) if use_key else None query = query[..., :head_size]