diff --git a/csrc/kernels/pos_encoding_kernels.cpp b/csrc/kernels/pos_encoding_kernels.cpp deleted file mode 100644 index 69a15191dda..00000000000 --- a/csrc/kernels/pos_encoding_kernels.cpp +++ /dev/null @@ -1,372 +0,0 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel_operator.h" -#include -#include "types.h" -#include "utils.h" - - -using vllm_ascend::AccType; -using vllm_ascend::local_mem_copy; -template class RotaryEmbedding { - // NOTE(ganyi): we use 512B as load stride for pipe, need to find another way to - // retrieve this size from runtime for more Soc support - #if (__CCE_AICORE__ >= 220) - static int constexpr loadSize = 512; - #else - static int constexpr loadSize = 1024 * 4; - #endif - using dst_t = scalar_t; - using acc_t = typename AccType::type; - // only half tensor have cast instruct to int8, hardcode acc_dst_t as half - using local_scalar_t = AscendC::LocalTensor; - using local_acc_t = AscendC::LocalTensor; - using local_dst_t = AscendC::LocalTensor; - -public: - __aicore__ inline RotaryEmbedding() - { - } - - // Allocate buffers for input and output queue and the temp buffer used during kernel compute process, - // this init process happens only in the kernel compute on a single vector core. - __aicore__ inline void init(__gm__ int64_t *positions, __gm__ void *queryDst, __gm__ void *keyDst, - __gm__ scalar_t *query, __gm__ scalar_t *key, __gm__ scalar_t *cosSinCache, - const int rotDim, const int64_t dstQueryStride, - const int64_t dstKeyStride, const int64_t queryStride, const int64_t keyStride, - const int numHeads, const int numKvHeads, const int headSize, AscendC::TPipe *pipe) - { - pipe_ = pipe; - rotDim_ = rotDim; - // query stride and key stride is used to handle the strided tensor which is not contiguous on num_tokens dim - queryStride_ = queryStride; - keyStride_ = keyStride; - dstQueryStride_ = dstQueryStride; - dstKeyStride_ = dstKeyStride; - numHeads_ = numHeads; - numKvHeads_ = numKvHeads; - headSize_ = headSize; - embedDim_ = rotDim / 2; - - pipe_->InitBuffer(inQue_, 1 /* buffer_num */, loadSize /* buffer_size */); - pipe_->InitBuffer(inQueSinCos_, 1 /* buffer_num */, rotDim_ * sizeof(scalar_t) /* buffer_size */); - pipe_->InitBuffer(outQue_, 1 /* buffer_num */, loadSize /* buffer_size */); - // 2 temporary calculation buffer - calcTmpBufferOffset_ = 0; - // 1 upcast buffer for bf16 (headSize) - upcastInputBufferOffset_ = calcTmpBufferOffset_ + sizeof(acc_t) * embedDim_ * 2; - // 1 upcast temp buffer for bf16 (2 * embed_dim) - upcastTempBufferOffset_ = upcastInputBufferOffset_ + sizeof(acc_t) * headSize_; - // 2 sin cos upcast buffer for bf16 - cosSinUpcastBufferOffset_ = upcastTempBufferOffset_ + sizeof(acc_t) * 2 * embedDim_; - // 2. bf16 path: needs 2 cos sin upcast buffer size - // 3. fp16 path: needs 2 temporary calculation buffer size - tempBufferSize_ = cosSinUpcastBufferOffset_ + 2 * embedDim_ * sizeof(acc_t); - // need to consider upcast the bf16 to fp32, so we might need 4 buffer just in case - // 2 temporary buffer, 2 input buffer, 1 cos buffer, 1 sin buffer, 2 scale buffer (headSize), 2 zp - // buffer(headSize int8), 1 dst_temp buffer(headSize, int32) - pipe_->InitBuffer(calcBuf_, tempBufferSize_ /* buffer_size */); - if constexpr (!std::is_same_v) { - pipe_->InitBuffer(copyBuf_, loadSize); - } - } - __aicore__ inline void update_mem_offset(__gm__ int64_t *positions, __gm__ void *queryDst, __gm__ void *keyDst, - __gm__ scalar_t *query, __gm__ scalar_t *key, __gm__ scalar_t *cosSinCache, - const int rotDim, const int64_t dstQueryStride, const int64_t dstKeyStride, - const int64_t queryStride, const int64_t keyStride, const int numHeads, - const int numKvHeads, const int headSize, const int64_t idx) - { - int64_t pos = positions[idx]; - cosSin_.SetGlobalBuffer(cosSinCache + pos * rotDim_, rotDim_); - query_.SetGlobalBuffer(query + queryStride * idx, headSize * numHeads_); - key_.SetGlobalBuffer(key + keyStride * idx, headSize * numKvHeads_); - queryDst_.SetGlobalBuffer(reinterpret_cast<__gm__ dst_t *>(queryDst) + dstQueryStride * idx, - headSize * numHeads_); - keyDst_.SetGlobalBuffer(reinterpret_cast<__gm__ dst_t *>(keyDst) + dstKeyStride * idx, headSize * numKvHeads_); - } - - // compute per head for neox on bf16 - template , void>::type * = nullptr> - __aicore__ inline void - neox_compute(local_scalar_t src, local_dst_t dst, AscendC::LocalTensor sin, AscendC::LocalTensor cos, - AscendC::LocalTensor upcastInputBuffer, AscendC::LocalTensor calcTmpBuffer) - { - // slice dst - local_dst_t dstX = dst; - local_dst_t dstY = dst[embedDim_]; - - // slice src - local_scalar_t srcX = src; - local_scalar_t srcY = src[embedDim_]; - - // slice temp buffer - local_acc_t calcTmpBufferX = calcTmpBuffer; - local_acc_t calcTmpBufferY = calcTmpBuffer[embedDim_]; - - // slice upcast input buffer - local_acc_t upcastBufferX = upcastInputBuffer; - local_acc_t upcastBufferY = upcastBufferX[embedDim_]; - - // dst x calc - Cast(upcastInputBuffer, src, AscendC::RoundMode::CAST_NONE, headSize_); - Mul(calcTmpBufferX, upcastBufferX, cos, embedDim_); - Mul(calcTmpBufferY, upcastBufferY, sin, embedDim_); - Sub(calcTmpBufferX, calcTmpBufferX, calcTmpBufferY, embedDim_); - Cast(dstX, calcTmpBufferX, AscendC::RoundMode::CAST_TRUNC, embedDim_); - - // dst y calc - Mul(calcTmpBufferX, upcastBufferX, sin, embedDim_); - Mul(calcTmpBufferY, upcastBufferY, cos, embedDim_); - Add(calcTmpBufferX, calcTmpBufferX, calcTmpBufferY, embedDim_); - Cast(dstY, calcTmpBufferX, AscendC::RoundMode::CAST_TRUNC, embedDim_); - } - - // compute per head output for neox - template , void>::type * = nullptr> - __aicore__ inline void - neox_compute(local_scalar_t src, local_dst_t dst, AscendC::LocalTensor sin, AscendC::LocalTensor cos, - AscendC::LocalTensor upcastInputBuffer, AscendC::LocalTensor calcTmpBuffer) - { - // slice dst buffer - local_dst_t dstX = dst; - local_dst_t dstY = dst[embedDim_]; - // slice src buffer - local_scalar_t srcX = src; - local_scalar_t srcY = src[embedDim_]; - // slice temp buffer - local_acc_t calcTmpBufferX = calcTmpBuffer; - local_acc_t calcTmpBufferY = calcTmpBuffer[embedDim_]; - - // dst x calc - Mul(calcTmpBufferX, srcX, cos, embedDim_); - Mul(calcTmpBufferY, srcY, sin, embedDim_); - Sub(dstX, calcTmpBufferX, calcTmpBufferY, embedDim_); - - // dst y calc - Mul(calcTmpBufferX, srcX, sin, embedDim_); - Mul(calcTmpBufferY, srcY, cos, embedDim_); - Add(dstY, calcTmpBufferX, calcTmpBufferY, embedDim_); - } - - __aicore__ inline void compute_qk(AscendC::GlobalTensor srcG, AscendC::GlobalTensor dstG, - local_acc_t localCos, local_acc_t localSin, local_acc_t upcastInputBuffer, - local_acc_t calcTmpBuffer, int loopCnt, int tailHeads, int loadStride, - int headNumPerLoad) - { - for (int loopNum = 0; loopNum < loopCnt; ++loopNum) { - local_scalar_t src = inQue_.AllocTensor(); - local_dst_t dst = outQue_.AllocTensor(); - AscendC::DataCopy(src, srcG[loopNum * loadStride], loadStride); - inQue_.EnQue(src); - - local_scalar_t srcDeque = inQue_.DeQue(); - if constexpr (!std::is_same_v) { - int elem_num = loadStride / sizeof(scalar_t); - AscendC::LocalTensor upBuffer = copyBuf_.GetWithOffset(elem_num, 0); - Cast(upBuffer, srcDeque, AscendC::RoundMode::CAST_TRUNC, elem_num); - Cast(dst, upBuffer, AscendC::RoundMode::CAST_TRUNC, elem_num); - } else { - local_mem_copy(dst, srcDeque, loadStride); - } - for (int i = 0; i < headNumPerLoad; ++i) { - neox_compute(srcDeque[i * headSize_], dst[i * headSize_], localSin, localCos, upcastInputBuffer, - calcTmpBuffer); - } - outQue_.EnQue(dst); - local_dst_t dstDeque = outQue_.DeQue(); - AscendC::DataCopy(dstG[loopNum * loadStride], dstDeque, loadStride); - outQue_.FreeTensor(dstDeque); - inQue_.FreeTensor(srcDeque); - } - // process tail - { - local_scalar_t src = inQue_.AllocTensor(); - local_dst_t dst = outQue_.AllocTensor(); - - AscendC::DataCopy(src, srcG[loopCnt * loadStride], tailHeads * headSize_); - inQue_.EnQue(src); - local_scalar_t srcDeque = inQue_.DeQue(); - - if constexpr (!std::is_same_v) { - int elem_num = tailHeads * headSize_ / sizeof(scalar_t); - AscendC::LocalTensor upBuffer = copyBuf_.GetWithOffset(elem_num, 0); - Cast(upBuffer, srcDeque, AscendC::RoundMode::CAST_TRUNC, elem_num); - Cast(dst, upBuffer, AscendC::RoundMode::CAST_TRUNC, elem_num); - } else { - local_mem_copy(dst, srcDeque, tailHeads * headSize_); - } - - for (int i = 0; i < tailHeads; ++i) { - neox_compute(srcDeque[i * headSize_], dst[i * headSize_], localSin, localCos, upcastInputBuffer, - calcTmpBuffer); - } - outQue_.EnQue(dst); - local_dst_t dstDeque = outQue_.DeQue(); - AscendC::DataCopy(dstG[loopCnt * loadStride], dstDeque, tailHeads * headSize_); - outQue_.FreeTensor(dstDeque); - inQue_.FreeTensor(srcDeque); - } - } - - __aicore__ inline void compute_function() - { - local_scalar_t cosSinLocal = inQueSinCos_.AllocTensor(); - - AscendC::DataCopy(cosSinLocal, cosSin_, embedDim_ * 2); - - inQueSinCos_.EnQue(cosSinLocal); - local_scalar_t localSinCosDeque = inQueSinCos_.DeQue(); - local_scalar_t localCos = localSinCosDeque; - local_scalar_t localSin = localSinCosDeque[embedDim_]; - - local_acc_t calcTmpBuffer; - local_acc_t upcastInputBuffer; - local_acc_t upcastTempBuffer; - local_acc_t cosSinUpcastBuffer; - local_acc_t scaleBuffer; - local_acc_t offsetBuffer; - calcTmpBuffer = calcBuf_.GetWithOffset(embedDim_ * 2, calcTmpBufferOffset_); - upcastInputBuffer = calcBuf_.GetWithOffset(headSize_, upcastInputBufferOffset_); - upcastTempBuffer = calcBuf_.GetWithOffset(embedDim_ * 2, upcastTempBufferOffset_); - cosSinUpcastBuffer = calcBuf_.GetWithOffset(embedDim_ * 2, cosSinUpcastBufferOffset_); - - local_acc_t cosAccBuffer; - local_acc_t sinAccBuffer; - - if constexpr (!std::is_same_v) { - Cast(cosSinUpcastBuffer, localSinCosDeque, AscendC::RoundMode::CAST_NONE, 2 * embedDim_); - cosAccBuffer = cosSinUpcastBuffer; - sinAccBuffer = cosSinUpcastBuffer[embedDim_]; - } else { - cosAccBuffer = localCos; - sinAccBuffer = localSin; - } - - constexpr const int loadSizeByElem = loadSize / sizeof(scalar_t); - int64_t headNumPerLoad = loadSizeByElem / headSize_; - int64_t loopCnt = numHeads_ / headNumPerLoad; - int64_t tailHeads = numHeads_ - loopCnt * headNumPerLoad; - int64_t loadStride = headNumPerLoad * headSize_; - int64_t loopCntKv = numKvHeads_ / headNumPerLoad; - int64_t tailHeadsKv = numKvHeads_ - loopCntKv * headNumPerLoad; - compute_qk(query_, queryDst_, cosAccBuffer, sinAccBuffer, upcastInputBuffer, - calcTmpBuffer, loopCnt, tailHeads, loadStride, headNumPerLoad); - - compute_qk(key_, keyDst_, cosAccBuffer, sinAccBuffer, upcastInputBuffer, calcTmpBuffer, - loopCntKv, tailHeadsKv, loadStride, headNumPerLoad); - - inQueSinCos_.FreeTensor(localSinCosDeque); - } - -private: - AscendC::TPipe *pipe_; - AscendC::TQue inQue_, inQueSinCos_; - AscendC::TQue outQue_; - AscendC::TBuf calcBuf_; - AscendC::TBuf copyBuf_; - AscendC::GlobalTensor queryDst_; - AscendC::GlobalTensor keyDst_; - AscendC::GlobalTensor query_; - AscendC::GlobalTensor key_; - AscendC::GlobalTensor cosSin_; - int rotDim_; - int embedDim_; - int64_t queryStride_; - int64_t keyStride_; - int64_t dstQueryStride_; - int64_t dstKeyStride_; - int numHeads_; - int numKvHeads_; - int headSize_; - int calcTmpBufferOffset_; - int upcastInputBufferOffset_; - int upcastTempBufferOffset_; - int cosSinUpcastBufferOffset_; - int tempBufferSize_; -}; - -// Note: Need to use macro to instaniate all the target functions here, for the current build system dose not support template call in cpp -// We use C style symbol here for kernel compilation, cpp style kernel entry may lead to compilation failure -#define ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, NEOX) \ - extern "C" __global__ __aicore__ void rope_custom_##NEOX##_##TYPE( \ - __gm__ int64_t* positions, __gm__ void* queryDst, __gm__ void* keyDst, __gm__ TYPE* query, __gm__ TYPE* key, \ - __gm__ TYPE* cosSinCache, const int rotDim, const int64_t queryStride, const int64_t keyStride, \ - const int64_t dstQueryStride, const int64_t dstKeyStride, const int numHeads, const int numKvHeads, \ - const int headSize, const int64_t numTokens, const int loopNum, const int coreNum) \ - { \ - AscendC::TPipe pipe; \ - RotaryEmbedding op{}; \ - op.init(positions, queryDst, keyDst, query, key, cosSinCache, rotDim, dstQueryStride, dstKeyStride, \ - queryStride, keyStride, numHeads, numKvHeads, headSize, &pipe); \ - for (int64_t i = AscendC::GetBlockIdx(); i < numTokens; i += coreNum) { \ - op.update_mem_offset(positions, queryDst, keyDst, query, key, cosSinCache, rotDim, dstQueryStride, dstKeyStride, \ - queryStride, keyStride, numHeads, numKvHeads, headSize, i); \ - op.compute_function(); \ - } \ - } - -#define ROPE_CUSTOM_KERNEL_DECLARE(TYPE) \ - ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, true); \ - ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, false); - -// Declare all the kernel entry here -ROPE_CUSTOM_KERNEL_DECLARE(half) -#if (__CCE_AICORE__ >= 220) - ROPE_CUSTOM_KERNEL_DECLARE(bfloat16_t) -#endif - -namespace vllm_ascend { - -#define ROTARY_EMBEDDING_KERNEL_CALL(TYPE) \ - if (isNeox) \ - rope_custom_true_##TYPE<<>>( \ - positions, queryDst, keyDst, reinterpret_cast(query), reinterpret_cast(key), \ - reinterpret_cast(cosSinCache), rotDim, queryStride, keyStride, dstQueryStride, dstKeyStride, \ - numHeads, numKvHeads, headSize, numTokens, loopCnt, blockDim); \ - else \ - rope_custom_false_##TYPE<<>>( \ - positions, queryDst, keyDst, reinterpret_cast(query), reinterpret_cast(key), \ - reinterpret_cast(cosSinCache), rotDim, queryStride, keyStride, dstQueryStride, dstKeyStride, \ - numHeads, numKvHeads, headSize, numTokens, loopCnt, blockDim); - -// maximum number for runtime to launch a ascendc kernel. -// we use this to constrain the maximum number of block size -static const int64_t maxParallelSize = 65535; - -extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst, - void *keyDst, void *query, void *key, void *cosSinCache, const int rotDim, - const int64_t queryStride, const int64_t keyStride, const int64_t dstQueryStride, - const int64_t dstKeyStride, const int numHeads, const int numKvHeads, - const int headSize, const int64_t numTokens, const uint32_t loopCnt, - uint32_t aivNum) -{ - - int blockDim = maxParallelSize > numTokens ? numTokens : maxParallelSize; - if (type == AscendType::FP16) { - ROTARY_EMBEDDING_KERNEL_CALL(half); - } - #if (__CCE_AICORE__ >= 220) - else if (type == AscendType::BF16) { - ROTARY_EMBEDDING_KERNEL_CALL(bfloat16_t); - } - #endif - else { - return; - } -} - -} // namespace vllm_ascend \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index 4edf0a8fe71..018bfc7b1ba 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -24,12 +24,6 @@ #include "torch_npu/csrc/aten/common/from_blob.h" namespace vllm_ascend { - extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst, - void *keyDst, void *query, void *key, void *cosSinCache, const int rotDim, - const int64_t queryStride, const int64_t keyStride, const int64_t dstQueryStride, - const int64_t dstKeyStride, const int numHeads, const int numKvHeads, - const int headSize, const int64_t numTokens, const uint32_t loopCnt, - uint32_t aivNum); extern void get_masked_input_and_mask_impl( void* stream, diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index f2652b26dec..cfa8f7d33a6 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -105,75 +105,6 @@ AscendType get_dtype_from_torch(at::ScalarType scalarType) } } -std::tuple rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key, - int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox) -{ - int32_t deviceId = 0; - int64_t num_tokens = positions.numel(); - int positions_ndim = positions.dim(); - TORCH_CHECK( - positions_ndim == 1 || positions_ndim == 2, - "positions must have shape [num_tokens] or [batch_size, seq_len]"); - if (positions_ndim == 1) { - TORCH_CHECK( - query.size(0) == positions.size(0) && key.size(0) == positions.size(0), - "query, key and positions must have the same number of tokens"); - } - if (positions_ndim == 2) { - TORCH_CHECK( - query.size(0) == positions.size(0) && - key.size(0) == positions.size(0) && - query.size(1) == positions.size(1) && - key.size(1) == positions.size(1), - "query, key and positions must have the same batch_size and seq_len"); - } - TORCH_CHECK(head_size % 32 == 0, "rotary_embedding: headSize should be divisible by 32"); - int query_hidden_size = query.numel() / num_tokens; - int key_hidden_size = key.numel() / num_tokens; - TORCH_CHECK(query_hidden_size % head_size == 0); - TORCH_CHECK(key_hidden_size % head_size == 0); - TORCH_CHECK(is_neox == true, "rotary_embedding: neox=false is not supported as custom kernel in vllm-ascend"); - - // Make sure query and key have consistent number of heads - int num_heads = query_hidden_size / head_size; - int num_kv_heads = key_hidden_size / head_size; - TORCH_CHECK(num_heads % num_kv_heads == 0); - at::Tensor query_dst = at::empty({num_tokens, num_heads, head_size}, query.options()); - at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.options()); - - int rot_dim = cos_sin_cache.size(1); - int seq_dim_idx = positions_ndim - 1; - int64_t *position_ids_ptr = positions.data_ptr(); - void *query_dst_ptr = query_dst.data_ptr(); - void *key_dst_ptr = key_dst.data_ptr(); - void *query_ptr = query.data_ptr(); - void *key_ptr = key.data_ptr(); - void *cos_sin_cache_ptr = cos_sin_cache.data_ptr(); - int64_t query_stride = query.stride(seq_dim_idx); - int64_t key_stride = key.stride(seq_dim_idx); - int64_t dst_query_stride = query_dst.stride(0); - int64_t dst_key_stride = key_dst.stride(0); - at::ScalarType scalar_type = query.scalar_type(); - aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); - at_npu::native::OpCommand cmd; - cmd.Name("rotary_embedding"); - cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr, - query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, - dst_query_stride, dst_key_stride, num_heads, num_kv_heads, head_size]() -> int { - auto dtype_num = get_dtype_from_torch(scalar_type); - int device_id = 0; - int64_t aiv_num = 0; - TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS); - uint32_t loop_cnt = (num_tokens + aiv_num - 1) / aiv_num; - rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr, query_ptr, - key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, dst_query_stride, - dst_key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aiv_num); - return 0; - }); - cmd.Run(); - return {query_dst, key_dst}; -} - std::tuple mla_preprocess( const at::Tensor &hiddenState, const at::Tensor &wdqkv, const c10::optional &descale0, const at::Tensor &gamma1, const c10::optional &beta1, const at::Tensor &wuq, @@ -1314,14 +1245,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) ops.def("weak_ref_tensor(Tensor input) -> Tensor"); ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor); - // Rotary embedding - // Apply GPT-NeoX style rotary embedding to query and key. - ops.def( - "rotary_embedding(Tensor positions, Tensor! query," - " Tensor! key, int head_size," - " Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)"); - ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding); - ops.def( "get_masked_input_and_mask(Tensor input, " " int org_vocab_start_index, " diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index b19fc64379c..84eeb99f773 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -36,24 +36,6 @@ namespace vllm_ascend { namespace meta { const int64_t INT4_NUMS_IN_INT32 = 8; -std::tuple rotary_embedding_meta( - at::Tensor &positions, - at::Tensor &query, - at::Tensor &key, - int64_t head_size, - at::Tensor &cos_sin_cache, - bool is_neox) { - auto num_tokens = positions.sym_numel(); - auto query_hidden_size = query.sym_numel() / num_tokens; - auto key_hidden_size = key.sym_numel() / num_tokens; - - auto num_heads = query_hidden_size / head_size; - auto num_kv_heads = key_hidden_size / head_size; - at::Tensor query_dst = at::empty_symint({num_tokens, num_heads, head_size}, query.options()); - at::Tensor key_dst = at::empty_symint({num_tokens, num_kv_heads, head_size}, key.options()); - - return {query_dst, key_dst}; -} std::tuple get_masked_input_and_mask_meta( at::Tensor &input, @@ -442,8 +424,6 @@ namespace { // the custom kernel been captured into aclgraph TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { - // Rotary embedding meta implementation - ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta); // Masked input and mask meta implementation ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta); // Bgmv expand diff --git a/tests/e2e/multicard/2-cards/spec_decode/test_spec_decode.py b/tests/e2e/multicard/2-cards/spec_decode/test_spec_decode.py index b2c7f2aba0e..c2e8e64d944 100644 --- a/tests/e2e/multicard/2-cards/spec_decode/test_spec_decode.py +++ b/tests/e2e/multicard/2-cards/spec_decode/test_spec_decode.py @@ -41,7 +41,7 @@ # NOTE: golden may change (eagle_proposer only runs in eager mode currently), # thus please update it if ci fails but you have better acceptance BASELINES_SP = { - "eagle3": [0.68, 0.40, 0.18], + "eagle3": [0.7477477477477478, 0.4294294294294294, 0.21621621621621623], } diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_rotary_embedding.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_rotary_embedding.py deleted file mode 100644 index 27e9b3b9eae..00000000000 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_rotary_embedding.py +++ /dev/null @@ -1,351 +0,0 @@ -# Copyright 2023 The vLLM team. - -# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. -# Adapted from -# https://github.com/vllm-project/vllm/blob/main/vllm/tests/kernels/test_rotary_embedding.py - -import gc -from typing import Optional, Tuple, Union - -import pytest -import torch -import torch.nn as nn - -from vllm_ascend.utils import enable_custom_op - -enable_custom_op() - -# Only Neox style true scenario is supported for now -IS_NEOX_STYLE = [True] -DTYPES = [torch.half] -HEAD_SIZES = [64, 64, 96, 128, 256] -ROTARY_DIMS = [None, 32] # None means rotary dim == head size -NUM_HEADS = [17] # Arbitrary values for testing -BATCH_SIZES = [5] # Arbitrary values for testing -SEQ_LENS = [11, 4096] # Arbitrary values for testing -NUM_TOKENS = [10, 21] -SEEDS = [0] -DEVICES = [f"npu:{0}"] -# Set tolerance to 1 for quant ops -DEFAULT_ATOL = 1e-3 -DEFAULT_RTOL = 1e-3 - - -def _apply_rotary_emb( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool, -) -> torch.Tensor: - """ - Args: - x: [num_tokens, num_heads, head_size] - cos: [num_tokens, head_size // 2] - sin: [num_tokens, head_size // 2] - is_neox_style: Whether to use the Neox-style or GPT-J-style rotary - positional embeddings. - """ - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) - if is_neox_style: - x1, x2 = torch.chunk(x, 2, dim=-1) - else: - x1 = x[..., ::2] - x2 = x[..., 1::2] - o1 = x1 * cos - x2 * sin - o2 = x2 * cos + x1 * sin - if is_neox_style: - return torch.cat((o1, o2), dim=-1) - else: - return torch.stack((o1, o2), dim=-1).flatten(-2) - - -# adapted from https://github.com/vllm-project/vllm/vllm/model_executor/layers/rotary_embedding.py -class RotaryEmbedding(nn.Module): - """Original rotary positional embedding.""" - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: int, - is_neox_style: bool, - dtype: torch.dtype, - ) -> None: - super().__init__() - self.head_size = head_size - self.rotary_dim = rotary_dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.is_neox_style = is_neox_style - self.dtype = dtype - - cache = self._compute_cos_sin_cache() - cache = cache.to(dtype) - self.cos_sin_cache: torch.Tensor - self.register_buffer("cos_sin_cache", cache, persistent=False) - - def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: - """Compute the inverse frequency.""" - # NOTE(woosuk): To exactly match the HF implementation, we need to - # use CPU to compute the cache and then move it to GPU. However, we - # create the cache on GPU for faster initialization. This may cause - # a slight numerical difference between the HF implementation and ours. - inv_freq = 1.0 / (base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) - return inv_freq - - def _compute_cos_sin_cache(self) -> torch.Tensor: - """Compute the cos and sin cache.""" - inv_freq = self._compute_inv_freq(self.base) - t = torch.arange(self.max_position_embeddings, dtype=torch.float) - - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() - sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) - return cache - - def forward_native( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """A PyTorch-native implementation of forward().""" - if offsets is not None: - positions = positions + offsets - positions = positions.flatten() - num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions) - cos, sin = cos_sin.chunk(2, dim=-1) - - query_shape = query.shape - query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] - query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) - query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) - - key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] - key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - return query, key - - -# test with leading dimension and merge seqlen and batch_size as num_tokens -@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("seq_len", SEQ_LENS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", DEVICES) -@torch.inference_mode() -def test_rotary_embedding_quant_with_leading_dim( - is_neox_style: bool, - batch_size: int, - seq_len: int, - num_heads: int, - head_size: int, - rotary_dim: Optional[int], - dtype: torch.dtype, - seed: int, - device: str, - max_position: int = 8192, - base: int = 10000, -) -> None: - if rotary_dim is None: - rotary_dim = head_size - - torch.set_default_device(device) - if rotary_dim is None: - rotary_dim = head_size - rope = RotaryEmbedding(head_size, rotary_dim, max_position, base, - is_neox_style, dtype) - rope = rope.to(dtype=dtype) - num_tokens = batch_size * seq_len - positions = torch.randint(0, max_position, (batch_size * seq_len, )) - qkv_tensor = torch.randn(num_tokens, - num_heads * head_size * 3, - dtype=dtype) - query, key, _ = qkv_tensor.split( - [num_heads * head_size, num_heads * head_size, num_heads * head_size], - dim=-1, - ) - - ref_query, ref_key = rope.forward_native(positions, query, key) - query, key = torch.ops._C_ascend.rotary_embedding( - positions, - query, - key, - rope.head_size, - rope.cos_sin_cache, - rope.is_neox_style, - ) - - # Compare the results. - torch.testing.assert_close(query.view(ref_query.size()), - ref_query, - atol=DEFAULT_ATOL, - rtol=DEFAULT_RTOL) - torch.testing.assert_close(key.view(ref_key.size()), - ref_key, - atol=DEFAULT_ATOL, - rtol=DEFAULT_RTOL) - gc.collect() - torch.npu.empty_cache() - torch.npu.reset_peak_memory_stats() - - -class ModelwithRotaryEmbedding(nn.Module): - - def __init__( - self, - hidden_size: int, - num_heads: int, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: int, - is_neox_style: bool, - dtype: torch.dtype, - ) -> None: - super().__init__() - self.qkv_proj = nn.Linear(hidden_size, num_heads * head_size * 3) - self.rope = RotaryEmbedding( - head_size=head_size, - rotary_dim=rotary_dim, - max_position_embeddings=max_position_embeddings, - base=base, - is_neox_style=is_neox_style, - dtype=dtype, - ) - self.o_proj = nn.Linear(num_heads * head_size, hidden_size) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - # we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph - qkv = self.qkv_proj(hidden_states) - q, k, v = qkv.chunk(3, dim=-1) - query, key = torch.ops._C_ascend.rotary_embedding( - positions, - q, - k, - self.rope.head_size, - self.rope.cos_sin_cache, - self.rope.is_neox_style, - ) - query = query.view(q.shape) - key = key.view(k.shape) - o = self.o_proj(query) - return o - - -# The first graph seems will have some accuracy issue when directly run pytest on the ops folder, -# add a warmup graph replay for workaround -ACL_GRPAH_FIRST_RUN = True - - -@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) -@pytest.mark.parametrize("num_tokens", BATCH_SIZES) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", DEVICES) -@torch.inference_mode() -def test_capture_rotary_embedding_in_aclgraph( - is_neox_style: bool, - num_tokens: int, - num_heads: int, - head_size: int, - rotary_dim: int, - dtype: torch.dtype, - seed: int, - device: str, - max_position_embeddings: int = 8192, - base: int = 10000, -): - """Test if the rotary embedding can be captured in aclgraph.""" - torch.manual_seed(seed) - torch.set_default_device(device) - if rotary_dim is None: - rotary_dim = head_size - model = ModelwithRotaryEmbedding( - hidden_size=num_heads * head_size, - num_heads=num_heads, - head_size=head_size, - rotary_dim=rotary_dim, - max_position_embeddings=max_position_embeddings, - base=base, - is_neox_style=is_neox_style, - dtype=dtype, - ) - - def custom_op_checking_backend(gm: torch.fx.GraphModule, example_input): - # Validate if the rotary_embedding custom kernel is indeed inside the graph by - # string match - graph = str(gm.graph) - assert "_C_ascend.rotary_embedding" in graph - return gm - - static_positions = torch.randint(0, max_position_embeddings, - (num_tokens, )) - static_hidden_states = torch.randn(num_tokens, - num_heads * head_size, - dtype=dtype, - device="npu") - compiled_model = torch.compile(model, backend=custom_op_checking_backend) - stream = torch.npu.Stream() - stream.wait_stream(torch.npu.current_stream()) - with torch.npu.stream(stream): - # warmup the fx graph before capture - for i in range(3): - static_output = compiled_model(static_positions, - static_hidden_states, - offsets=None) - stream.wait_stream(torch.npu.current_stream()) - - aclgraph = torch.npu.NPUGraph() - - with torch.npu.graph(aclgraph): - # Capture the model in aclgraph. - static_output = compiled_model(static_positions, static_hidden_states) - # Capture the model in aclgraph. - random_filled_positions = torch.randint(0, - max_position_embeddings, - (num_tokens, ), - device="npu") - random_filled_hidden_states = torch.randn(num_tokens, - num_heads * head_size, - dtype=dtype, - device="npu") - static_positions.copy_(random_filled_positions) - static_hidden_states.copy_(random_filled_hidden_states) - - aclgraph.replay() - global ACL_GRPAH_FIRST_RUN - if ACL_GRPAH_FIRST_RUN: - ACL_GRPAH_FIRST_RUN = False - return - output_reference = model(static_positions, static_hidden_states) - torch.testing.assert_close(static_output, - output_reference, - atol=DEFAULT_ATOL, - rtol=DEFAULT_RTOL) - gc.collect() - torch.npu.empty_cache() - torch.npu.reset_peak_memory_stats() diff --git a/tests/e2e/singlecard/test_aclgraph_accuracy.py b/tests/e2e/singlecard/test_aclgraph_accuracy.py index 36f13f7dde5..f4a6755e49d 100644 --- a/tests/e2e/singlecard/test_aclgraph_accuracy.py +++ b/tests/e2e/singlecard/test_aclgraph_accuracy.py @@ -24,10 +24,10 @@ model="Qwen/Qwen3-0.6B", prompts=PROMPTS_SHORT, golden_answers=[ - " Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I want to know if there are any", + " Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I'm looking for a job in the", ' the same as the president of the United Nations. This is because the president of the United States is the same as the president of the United Nations. The president', ' Paris. The capital of France is also the capital of the Republic of France. The capital of France is also the capital of the European Union. The capital of', - ' not just a technological frontier but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and' + ' not just a technological challenge but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and' ], ) @@ -48,8 +48,8 @@ prompts=PROMPTS_LONG, golden_answers=[ ' \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the', - " \n\nTo solve this problem, we can use the following approach: Let $P$ be the perimeter of the square. Then, the expected value of the area", - ' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations $x^2 +' + " \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle with vertices on a square can be calculated by integrating over", + ' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can' ]) CASE_DS_FULL_DECODE_ONLY = LLMTestCase( diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py deleted file mode 100644 index 567c15d9325..00000000000 --- a/tests/ut/ops/test_rotary_embedding.py +++ /dev/null @@ -1,453 +0,0 @@ -import math -import unittest -from unittest.mock import MagicMock, PropertyMock, patch - -import torch -from transformers.configuration_utils import PretrainedConfig -from vllm.config import ModelConfig, VllmConfig -from vllm.model_executor.layers.rotary_embedding import ( - DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding) -from vllm.platforms import CpuArchEnum - -from tests.ut.base import TestBase -from vllm_ascend.ascend_forward_context import set_ascend_forward_context -from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled -from vllm_ascend.utils import AscendDeviceType - -MODEL = "Qwen3-0.6B" -MODEL_VL = "Qwen/Qwen2.5-VL-3B-Instruct" -MAX_NUM_BATCHED_TOKEND = 10000 - - -class TestCustomRotaryEmbeddingEnabled(unittest.TestCase): - - def setUp(self): - # Common setup for tests - self.positions = torch.tensor([1, 2, 3]) - self.query = torch.randn(3, 4, dtype=torch.float16) - self.key = torch.randn(3, 4, dtype=torch.float16) - self.head_size = 32 - self.cos_sin_cache = torch.randn(3, 4) - - # Mock self object for rope_forward_oot - self.mock_self = MagicMock() - self.mock_self.head_size = self.head_size - self.mock_self.cos_sin_cache = self.cos_sin_cache - self.mock_self.is_neox_style = True - self.mock_self.forward_native.return_value = (self.query, self.key) - - def test_custom_rotary_embedding_enabled(self): - # Test when all conditions are True - with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op', - return_value=True): - result = _custom_rotary_embedding_enabled(self.query, True, - self.head_size) - self.assertTrue(result) - - # Test when dtype is not float16 - with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op', - return_value=True): - query = self.query.to(torch.float32) - result = _custom_rotary_embedding_enabled(query, True, - self.head_size) - self.assertFalse(result) - - # Test when neox_style is False - with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op', - return_value=True): - result = _custom_rotary_embedding_enabled(self.query, False, - self.head_size) - self.assertFalse(result) - - # Test when head_size is not divisible by 32 - with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op', - return_value=True): - result = _custom_rotary_embedding_enabled(self.query, True, - self.head_size + 1) - self.assertFalse(result) - - # Test when custom op is disabled - with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op', - return_value=False): - result = _custom_rotary_embedding_enabled(self.query, True, - self.head_size) - self.assertFalse(result) - - -class TestAscendRotaryEmbedding(unittest.TestCase): - - def setUp(self): - # Common setup for tests - self.positions = torch.tensor([1, 2, 3]) - self.query = torch.randn(3, 1, 32, dtype=torch.float16) - self.key = torch.randn(3, 1, 32, dtype=torch.float16) - self.head_size = 32 - self.rotary_dim = self.head_size - self.max_position = 16 - self.rope_theta = 10000 - self.is_neox_style = True - self.cos_sin_cache = torch.randn(3, 1, 32) - self.layer = RotaryEmbedding(self.head_size, self.rotary_dim, - self.max_position, self.rope_theta, - self.is_neox_style, torch.float16) - - # Mock self object for rope_forward_oot - self.mock_self = MagicMock() - self.mock_self.head_size = self.head_size - self.mock_self.cos_sin_cache = self.cos_sin_cache - self.mock_self.is_neox_style = self.is_neox_style - - @patch('torch.ops._C_ascend') - @patch('vllm_ascend.utils.get_ascend_device_type', - return_value=AscendDeviceType.A3) - @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled', - return_value=True) - @patch('torch.ops._npu_rotary_embedding') - @patch('vllm.config.ModelConfig.__post_init__', MagicMock()) - @patch('vllm.config.VllmConfig.__post_init__', MagicMock()) - @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) - @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) - def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding, - mock_custom_enabled, - mock_soc_version, mock__c): - mock__c.rotary_embedding.return_value = self.query, self.key - vllm_config = VllmConfig() - model_config = ModelConfig(MODEL, - tokenizer=MODEL, - max_model_len=MAX_NUM_BATCHED_TOKEND) - model_config.hf_text_config = PretrainedConfig() - vllm_config.model_config = model_config - with set_ascend_forward_context(None, vllm_config): - result_q, result_k = self.layer.forward(self.positions, self.query, - self.key) - - mock__c.rotary_embedding.assert_called_once() - self.assertEqual(result_q.shape, self.query.shape) - self.assertEqual(result_k.shape, self.key.shape) - - @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled', - return_value=False) - @patch('torch_npu._npu_rotary_embedding') - @patch('vllm.config.ModelConfig.__post_init__', MagicMock()) - @patch('vllm.config.VllmConfig.__post_init__', MagicMock()) - @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) - @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) - def test_rope_forward_oot_contiguous(self, mock_npu_rotary, - mock_custom_enabled): - # Test contiguous path when custom is disabled - non_contig_query = self.query.transpose(0, 1) - non_contig_key = self.key.transpose(0, 1) - vllm_config = VllmConfig() - model_config = ModelConfig(MODEL, - tokenizer=MODEL, - max_model_len=MAX_NUM_BATCHED_TOKEND) - model_config.hf_text_config = PretrainedConfig() - vllm_config.model_config = model_config - with set_ascend_forward_context(None, vllm_config): - result_q, result_k = self.layer.forward(self.positions, - non_contig_query, - non_contig_key) - - mock_npu_rotary.assert_called_once() - self.assertEqual(result_q.shape, non_contig_query.shape) - self.assertEqual(result_k.shape, non_contig_key.shape) - - @patch('vllm.config.ModelConfig.__post_init__', MagicMock()) - @patch('vllm.config.VllmConfig.__post_init__', MagicMock()) - @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) - @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) - def test_rope_forward_oot_with_offsets(self): - # Test that NotImplementedError is raised when offsets is provided - offsets = torch.tensor([1, 2, 3]) - with self.assertRaises(NotImplementedError): - vllm_config = VllmConfig() - model_config = ModelConfig(MODEL, - tokenizer=MODEL, - max_model_len=MAX_NUM_BATCHED_TOKEND) - model_config.hf_text_config = PretrainedConfig() - vllm_config.model_config = model_config - with set_ascend_forward_context(None, vllm_config): - self.layer.forward(self.positions, self.query, self.key, - offsets) - - @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled', - return_value=False) - @patch('torch_npu._npu_rotary_embedding') - @patch('vllm.config.ModelConfig.__post_init__', MagicMock()) - @patch('vllm.config.VllmConfig.__post_init__', MagicMock()) - @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) - @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) - def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary, - mock_custom_enabled): - # Test neox_style override - vllm_config = VllmConfig() - model_config = ModelConfig(MODEL, - tokenizer=MODEL, - max_model_len=MAX_NUM_BATCHED_TOKEND) - model_config.hf_text_config = PretrainedConfig() - vllm_config.model_config = model_config - with set_ascend_forward_context(None, vllm_config): - result_q, result_k = self.layer.forward( - self.positions, - self.query, - self.key, - is_neox_style_override=False) - # Check that neox_style=False was passed to the NPU function - args, kwargs = mock_npu_rotary.call_args - self.assertFalse(args[-1]) - - @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled', - return_value=False) - @patch('torch_npu._npu_rotary_embedding') - @patch('vllm.config.ModelConfig.__post_init__', MagicMock()) - @patch('vllm.config.VllmConfig.__post_init__', MagicMock()) - @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) - @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) - def test_rope_forward_oot_rotary_dim_less_than_head_size( - self, mock_npu_rotary, mock_custom_enabled): - # test case when rotary_dim < head_size - org_rotary_dim = self.layer.rotary_dim - self.layer.rotary_dim = self.layer.head_size // 2 - - vllm_config = VllmConfig() - model_config = ModelConfig(MODEL, - tokenizer=MODEL, - max_model_len=MAX_NUM_BATCHED_TOKEND) - model_config.hf_text_config = PretrainedConfig() - vllm_config.model_config = model_config - with set_ascend_forward_context(None, vllm_config): - result_q, result_k = self.layer.forward(self.positions, self.query, - self.key) - - mock_npu_rotary.assert_called_once() - self.assertEqual(result_q.shape, self.query.shape) - self.assertEqual(result_k.shape, self.key.shape) - - # restore rotary_dim - self.layer.rotary_dim = org_rotary_dim - - -class MockRopeModule: - - def __init__(self, max_seq_len=2048, is_neox_style=True): - self.max_seq_len = max_seq_len - self.is_neox_style = is_neox_style - self.cos_cached = None - self.sin_cached = None - self.rotary_dim = 1 - self.base = 1 - - -class TestAscendDeepseekScalingRotaryEmbedding(TestBase): - - def setUp(self): - # Common setup for tests - self.positions = torch.tensor([1, 2, 3]) - self.query = torch.randn(3, 1, 32, dtype=torch.float16) - self.key = torch.randn(3, 1, 32, dtype=torch.float16) - self.head_size = 32 - self.rotary_dim = self.head_size - self.max_position = 16 - self.rope_theta = 10000 - self.is_neox_style = True - self.scaling_factor = 1 - self.layer = None - - def _create_layer(self): - self.layer = DeepseekScalingRotaryEmbedding( - self.head_size, self.rotary_dim, self.max_position, - self.rope_theta, self.is_neox_style, self.scaling_factor, - torch.float16) - return self.layer - - @patch("vllm.platforms.current_platform.device_type", - new=torch.device("cpu")) - @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform", - new_callable=PropertyMock) - def test_native_rope_deepseek_forward_base(self, mock_npuplatform): - mock_npuplatform.device_type = torch.device("cpu") - self.layer = self._create_layer() - with patch("vllm_ascend.ops.rotary_embedding._rope_forward_oot", - return_value=(self.query, - self.key)) as mock_rope_forward_oot: - q_pe, k_pe = self.layer.forward(self.positions, self.query, - self.key) - mock_rope_forward_oot.assert_called_once() - assert q_pe.shape == self.query.shape - assert k_pe.shape == self.key.shape - - @patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot') - @patch("vllm.platforms.current_platform.device_type", - new=torch.device("cpu")) - @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform", - new_callable=PropertyMock) - def test_native_rope_deepseek_forward_key_reshaping( - self, mock_npuplatform, mock_rope_forward_oot): - mock_npuplatform.device_type = torch.device("cpu") - self.layer = self._create_layer() - - key = torch.randn(1, 32) - - mock_rope_forward_oot.return_value = (self.query, key) - - q_pe, k_pe = self.layer.forward(self.positions, self.query, key) - mock_rope_forward_oot.assert_called_once() - assert q_pe.shape == self.query.shape - assert k_pe.shape == key.shape - - @patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot') - @patch("vllm.platforms.current_platform.device_type", - new=torch.device("cpu")) - @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform", - new_callable=PropertyMock) - def test_native_rope_deepseek_forward_non_neox_style( - self, mock_npuplatform, mock_rope_forward_oot): - mock_npuplatform.device_type = torch.device("cpu") - self.layer = self._create_layer() - - mock_rope_forward_oot.return_value = (self.query, self.key) - - q_pe, k_pe = self.layer.forward(self.positions, self.query, self.key) - - mock_rope_forward_oot.assert_called_once() - assert q_pe.shape == self.query.shape - assert k_pe.shape == self.key.shape - - @patch("vllm.platforms.current_platform.device_type", - new=torch.device("cpu")) - @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform", - new_callable=PropertyMock) - def test_basic_case(self, mock_npuplatform): - # Test with standard values - mock_npuplatform.device_type = torch.device("cpu") - self.layer = self._create_layer() - num_rotations = 100 - dim = 512 - base = 10000 - max_position_embeddings = 2048 - - result = self.layer._yarn_find_correction_dim(num_rotations, dim, base, - max_position_embeddings) - - # Calculate expected value manually - expected = (dim * torch.log( - torch.tensor(max_position_embeddings) / - (num_rotations * 2 * torch.pi))) / (2 * - torch.log(torch.tensor(base))) - - self.assertTrue(torch.allclose(result, expected)) - - @patch("vllm.platforms.current_platform.device_type", - new=torch.device("cpu")) - @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform", - new_callable=PropertyMock) - def test_yarn_get_mscale(self, mock_npuplatform): - mock_npuplatform.device_type = torch.device("cpu") - self.layer = self._create_layer() - - # test_scale_less_than_or_equal_1 - self.assertEqual(self.layer._yarn_get_mscale(scale=0.5), 1.0) - self.assertEqual(self.layer._yarn_get_mscale(scale=1.0), 1.0) - self.assertEqual(self.layer._yarn_get_mscale(scale=0.999), 1.0) - - # test_scale_greater_than_1: - test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)), - (10.0, 1.0, 1.0 + 0.1 * math.log(10.0)), - (5.0, 2.0, 1.0 + 0.2 * math.log(5.0)), - (math.e, 1.0, 1.0 + 0.1)] - - for scale, mscale, expected in test_cases: - result = self.layer._yarn_get_mscale(scale, mscale) - self.assertAlmostEqual( - result, - expected, - places=6, - msg=f"Failed for scale={scale}, mscale={mscale}") - - -class TestAscendMRotaryEmbedding(unittest.TestCase): - - def setUp(self): - # Common setup for tests - self.number_tokens = 3 - self.num_head = 8 - self.num_kvhead = 8 - self.head_size = 128 - self.max_position_embeddings = 128000 - self.is_neox_style = True - self.rope_theta = 1000000.0 - self.positions_1d = torch.tensor([1, 2, 3]) - self.positions_2d = torch.randint(1, 10, (3, self.number_tokens)) - - self.query = torch.randn( - (self.number_tokens, self.num_head * self.head_size), - dtype=torch.bfloat16) - self.key = torch.randn( - (self.number_tokens, self.num_kvhead * self.head_size), - dtype=torch.bfloat16) - - # Qwen2.5-VL mrope section case - self.mrope_section = [16, 24, 24] - - self.layer = MRotaryEmbedding(self.head_size, - self.head_size, - self.max_position_embeddings, - base=self.rope_theta, - is_neox_style=self.is_neox_style, - dtype=torch.bfloat16, - mrope_section=self.mrope_section) - - self.mock_config = MagicMock() - - def _create_vllm_config(self): - vllm_config = VllmConfig() - model_config = ModelConfig(MODEL_VL, - tokenizer=MODEL_VL, - max_model_len=MAX_NUM_BATCHED_TOKEND) - model_config.hf_text_config = PretrainedConfig() - vllm_config.model_config = model_config - return vllm_config - - @patch('torch_npu.npu_mrope') - @patch('vllm_ascend.platform.NPUPlatform.get_cpu_architecture') - @patch('vllm.config.ModelConfig.__post_init__', MagicMock()) - @patch('vllm.config.VllmConfig.__post_init__', MagicMock()) - @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) - @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) - def test_forward_oot_1d_positions(self, mock_cpu_arc, mock_npu_mrope): - mock_cpu_arc.return_value = CpuArchEnum.ARM - - mock_npu_mrope.return_value = (torch.zeros_like(self.query), - torch.zeros_like(self.key)) - - vllm_config = self._create_vllm_config() - with set_ascend_forward_context(None, vllm_config): - result_q, result_k = self.layer.forward_oot( - self.positions_1d, self.query, self.key) - - mock_npu_mrope.assert_called_once() - self.assertFalse(torch.isnan(result_q).any().item()) - self.assertFalse(torch.isnan(result_k).any().item()) - self.assertEqual(result_q.shape, self.query.shape) - - @patch('torch_npu.npu_mrope') - @patch('vllm_ascend.platform.NPUPlatform.get_cpu_architecture') - @patch('vllm.config.ModelConfig.__post_init__', MagicMock()) - @patch('vllm.config.VllmConfig.__post_init__', MagicMock()) - @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) - @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) - def test_forward_oot_2d_positions(self, mock_cpu_arc, mock_npu_mrope): - mock_cpu_arc.return_value = CpuArchEnum.ARM - - mock_npu_mrope.return_value = (torch.zeros_like(self.query), - torch.zeros_like(self.key)) - - vllm_config = self._create_vllm_config() - with set_ascend_forward_context(None, vllm_config): - result_q, result_k = self.layer.forward_oot( - self.positions_2d, self.query, self.key) - - mock_npu_mrope.assert_called_once() - self.assertFalse(torch.isnan(result_q).any().item()) - self.assertFalse(torch.isnan(result_k).any().item()) - self.assertEqual(result_q.shape, self.query.shape) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index e16cc0b532f..5f093e184a4 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -166,8 +166,6 @@ def set_ascend_forward_context( _mc2_tokens_capacity: Optional[int] = None _reserved_mc2_mask: Optional[torch.Tensor] = None -_sin: Optional[torch.Tensor] = None -_cos: Optional[torch.Tensor] = None def set_mc2_tokens_capacity(vllm_config, max_num_reqs, diff --git a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py index ed90c7f8690..8090ef19cb0 100644 --- a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py +++ b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py @@ -45,6 +45,7 @@ def __init__(self, def get_inputs(self): T = 5 + max_position_embeddings = 16384 qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, @@ -55,25 +56,22 @@ def get_inputs(self): k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") - cos = torch.empty(1, - T, - 1, - self.head_dim, - dtype=torch.bfloat16, - device="npu") - sin = torch.empty(1, - T, - 1, - self.head_dim, - dtype=torch.bfloat16, - device="npu") - return [qkv, q_weight, k_weight, cos, sin] + cos_sin_cache = torch.empty(max_position_embeddings, + self.head_dim, + dtype=torch.bfloat16, + device="npu") + positions = torch.ones(T, dtype=torch.int64, device="npu") + return [qkv, q_weight, k_weight, cos_sin_cache, positions] def register(self, pm_pass: PatternMatcherPass): - def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, - k_weight: torch.Tensor, cos: torch.Tensor, - sin: torch.Tensor): + def pattern( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, + ): q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -89,21 +87,21 @@ def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, self.eps) q_flat = q_norm_out.view(q.shape) - q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, - self.head_dim) k_flat = k_norm_out.view(k.shape) - k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, - self.head_dim) - - q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb( - q_reshape, k_reshape, cos, sin) + q_rope, k_rope = torch.ops.vllm.rope_forward_oot( + positions, q_flat, k_flat, cos_sin_cache, self.head_dim, + self.head_dim, True) return q_rope, k_rope, v - def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, - k_weight: torch.Tensor, cos: torch.Tensor, - sin: torch.Tensor): + def replacement( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, + ): results = torch.ops.vllm.qkv_rmsnorm_rope( input=qkv, q_weight=q_weight, @@ -114,9 +112,9 @@ def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, eps=self.eps, q_bias=None, k_bias=None, - sin=sin, - cos=cos) - + cos_sin_cache=cos_sin_cache, + positions=positions, + ) return results pm.register_replacement(pattern, replacement, self.get_inputs(), @@ -142,6 +140,7 @@ def __init__(self, def get_inputs(self): T = 5 + max_position_embeddings = 16384 qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, @@ -154,27 +153,27 @@ def get_inputs(self): device="npu") q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") - cos = torch.empty(1, - T, - 1, - self.head_dim, - dtype=torch.bfloat16, - device="npu") - sin = torch.empty(1, - T, - 1, - self.head_dim, - dtype=torch.bfloat16, - device="npu") + cos_sin_cache = torch.empty(max_position_embeddings, + self.head_dim, + dtype=torch.bfloat16, + device="npu") + positions = torch.ones(T, dtype=torch.int64, device="npu") - return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin] + return [ + qkv, q_weight, k_weight, q_bias, k_bias, cos_sin_cache, positions + ] def register(self, pm_pass: PatternMatcherPass): - def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, - k_weight: torch.Tensor, q_bias: torch.Tensor, - k_bias: torch.Tensor, cos: torch.Tensor, - sin: torch.Tensor): + def pattern( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_bias: torch.Tensor, + k_bias: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, + ): q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -191,22 +190,23 @@ def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, k_normed = k_norm_out + k_bias q_flat = q_normed.view(q.shape) - q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, - self.head_dim) k_flat = k_normed.view(k.shape) - k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, - self.head_dim) - - q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb( - q_reshape, k_reshape, cos, sin) + q_rope, k_rope = torch.ops.vllm.rope_forward_oot( + positions, q_flat, k_flat, cos_sin_cache, self.head_dim, + self.head_dim, True) return q_rope, k_rope, v - def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, - k_weight: torch.Tensor, q_bias: torch.Tensor, - k_bias: torch.Tensor, cos: torch.Tensor, - sin: torch.Tensor): + def replacement( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_bias: torch.Tensor, + k_bias: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, + ): results = torch.ops.vllm.qkv_rmsnorm_rope( input=qkv, q_weight=q_weight, @@ -217,8 +217,9 @@ def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, eps=self.eps, q_bias=q_bias, k_bias=k_bias, - cos=cos, - sin=sin) + cos_sin_cache=cos_sin_cache, + positions=positions, + ) return results pm.register_replacement(pattern, replacement, self.get_inputs(), diff --git a/vllm_ascend/meta_registration.py b/vllm_ascend/meta_registration.py index 9a58afd9c87..18f7113159e 100644 --- a/vllm_ascend/meta_registration.py +++ b/vllm_ascend/meta_registration.py @@ -31,7 +31,7 @@ # 3. The registration utility will check if a meta implementation already exists for your op, # and only register if necessary. This avoids duplicate registrations. # -# 4. Example meta implementations are provided below for rotary_embedding and get_masked_input_and_mask. +# 4. Example meta implementations are provided below for get_masked_input_and_mask. # # 5. When developing new custom ops, always provide a meta implementation to enable tracing, # export, and shape inference in PyTorch and vLLM to enable the capture of `torch.compile` @@ -53,21 +53,6 @@ def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""): lib.impl(op_name, fn, "Meta") -def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, head_size: int, - cos_sin_cache: torch.Tensor, is_neox: bool): - - num_tokens = positions.numel() - query_hidden_size = query.numel() // num_tokens - key_hidden_size = key.numel() // num_tokens - num_heads = query_hidden_size // head_size - num_kv_heads = key_hidden_size // head_size - - query_dst = torch.empty_like(query).view(num_tokens, num_heads, head_size) - key_dst = torch.empty_like(key).view(num_tokens, num_kv_heads, head_size) - return query_dst, key_dst - - def get_masked_input_and_mask_meta(input: torch.Tensor, org_vocab_start_index: int, org_vocab_end_index: int, @@ -97,8 +82,6 @@ def sgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor, return y_out -register_meta_if_necessary("_C_ascend", "rotary_embedding", - rotary_embedding_meta) register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask", get_masked_input_and_mask_meta) register_meta_if_necessary("_C_ascend", "bgmv_expand", bgmv_expand_meta) diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index d91741ae3f1..0a08f9cc2fe 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -12,6 +12,7 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_forward_context import MoECommType +from vllm_ascend.ops.rotary_embedding import rope_forward_oot from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.utils import npu_stream_switch, prefetch_stream @@ -312,6 +313,17 @@ def _quantize_impl_fake(in_tensor: torch.Tensor, input_scale: torch.Tensor, input_offset, torch.qint8, -1, False) +def _rope_forward_oot_impl_fake( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + cos_sin_cache: torch.Tensor, + head_dim: int, + rotary_dim: int, + is_neox_style: bool = True) -> tuple[torch.Tensor, torch.Tensor]: + return query, key + + direct_register_custom_op(op_name="maybe_chunk_residual", op_func=_maybe_chunk_residual_impl, fake_impl=lambda x, residual: torch.empty_like(x), @@ -377,3 +389,9 @@ def _quantize_impl_fake(in_tensor: torch.Tensor, input_scale: torch.Tensor, fake_impl=_quantize_impl_fake, mutates_args=[], dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="rope_forward_oot", + op_func=rope_forward_oot, + fake_impl=_rope_forward_oot_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index bd1f925d6d4..79bfa0a2549 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -27,12 +27,14 @@ from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb from vllm.triton_utils import HAS_TRITON +from vllm_ascend.platform import NPUPlatform +from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type, + has_rope, is_vl_model) + if HAS_TRITON: from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope -from vllm_ascend.platform import NPUPlatform -from vllm_ascend.utils import (AscendDeviceType, enable_custom_op, - get_ascend_device_type, has_rope, is_vl_model) + from vllm_ascend.ops.triton.rope import rope_forward_triton # Currently, rope ops used on npu requires detached cos && sin as inputs. # However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable. @@ -169,76 +171,55 @@ def get_cos_and_sin_slice(): return _cos_slice, _sin_slice -def _custom_rotary_embedding_enabled(query, neox_style, head_size): - return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op( - ) - - -def _rope_forward_oot( - self, +def rope_forward_oot( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, + cos_sin_cache: torch.Tensor, + head_size: int, + rotary_dim: int, is_neox_style: bool, offsets: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: query_shape, key_shape = query.shape, key.shape - if self.cos_sin_cache.device != query.device: - self.cos_sin_cache = self.cos_sin_cache.to(query.device) - if self.cos_sin_cache.dtype != query.dtype: - self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) - # adopt custom kernel path for rotary_embedding - if _custom_rotary_embedding_enabled( - query, is_neox_style, self.head_size) and get_ascend_device_type( - ) != AscendDeviceType._310P: - query, key = torch.ops._C_ascend.rotary_embedding( - positions, - query, - key, - self.head_size, - self.cos_sin_cache, - is_neox_style, - ) - return query.view(query_shape), key.view(key_shape) if offsets is not None: raise NotImplementedError( "Batched rotary embedding is currently not supported on NPU.") + if HAS_TRITON: + num_tokens = query.shape[0] + query, key = rope_forward_triton( + query.view(num_tokens, -1, head_size), + key.view(num_tokens, -1, head_size), + cos_sin_cache=cos_sin_cache, + positions=positions, + rope_dim=rotary_dim, + is_neox_style=is_neox_style, + ) else: - cos, sin = get_cos_and_sin_slice() - if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[ - -1] == 128 and cos is not None and sin is not None: - # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation. - # This method requires head_size and rotary_dim equal 128 and neox_style is True - query = query.contiguous().view(1, query.shape[0], -1, - self.head_size) - key = key.contiguous().view(1, key.shape[0], -1, self.head_size) - # Although this function modifies in-place, please retain the function's return value. - # Otherwise, the graph fusion operation may fail. - query, key = torch_npu.npu_apply_rotary_pos_emb( - query, key, cos, sin) - elif self.rotary_dim < self.head_size: + if rotary_dim < head_size: num_tokens = query.shape[0] - query = query.view(num_tokens, -1, self.head_size) - key = key.view(num_tokens, -1, self.head_size) - q_rot = query[..., :self.rotary_dim] - q_pass = query[..., self.rotary_dim:] - k_rot = key[..., :self.rotary_dim] - k_pass = key[..., self.rotary_dim:] + query = query.view(num_tokens, -1, head_size) + key = key.view(num_tokens, -1, head_size) + q_rot = query[..., :rotary_dim] + q_pass = query[..., rotary_dim:] + k_rot = key[..., :rotary_dim] + k_pass = key[..., rotary_dim:] q_rot = q_rot.contiguous().view(num_tokens, -1) k_rot = k_rot.contiguous().view(num_tokens, -1) + # only the rotary part is processed here, + # the dimension should be rotary_dim torch_npu._npu_rotary_embedding( positions, q_rot, k_rot, - self.rotary_dim, - self.cos_sin_cache, + rotary_dim, + cos_sin_cache, is_neox_style, ) - q_rot = q_rot.view(num_tokens, -1, self.rotary_dim) - k_rot = k_rot.view(num_tokens, -1, self.rotary_dim) - q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape) - k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape) - return q, k + q_rot = q_rot.view(num_tokens, -1, rotary_dim) + k_rot = k_rot.view(num_tokens, -1, rotary_dim) + query = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape) + key = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape) else: # TODO: Remove the contiguous in the future. query = query.contiguous().view(query.shape[0], -1) @@ -247,11 +228,11 @@ def _rope_forward_oot( positions, query, key, - self.head_size, - self.cos_sin_cache, + head_size, + cos_sin_cache, is_neox_style, ) - return query.view(query_shape), key.view(key_shape) + return query.view(query_shape), key.view(key_shape) class AscendRotaryEmbedding(RotaryEmbedding): @@ -281,8 +262,10 @@ def forward_oot( is_neox_style = self.is_neox_style if is_neox_style_override is not None: is_neox_style = is_neox_style_override - return _rope_forward_oot(self, positions, query, key, is_neox_style, - offsets) + return torch.ops.vllm.rope_forward_oot(positions, query, key, + self.cos_sin_cache, + self.head_size, self.rotary_dim, + is_neox_style) class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding): @@ -524,8 +507,11 @@ def forward(self, b, h_k, d = key.shape key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d) - q_pe, k_pe = _rope_forward_oot(self, positions, query, key, - is_neox_style, offsets) + q_pe, k_pe = torch.ops.vllm.rope_forward_oot(positions, query, key, + self.cos_sin_cache, + self.head_size, + self.rotary_dim, + is_neox_style) return q_pe, k_pe diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py index 6bc2d373249..01ff79bf26b 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py @@ -27,8 +27,8 @@ @triton.jit def split_qkv_rmsnorm_rope_kernel( input_ptr, - sin_ptr, - cos_ptr, + cos_sin_ptr, + pos_ptr, q_ptr, k_ptr, v_ptr, @@ -78,9 +78,11 @@ def split_qkv_rmsnorm_rope_kernel( normalized_values = (normalized_values * weight_values).to( tl.bfloat16) - sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM) - sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM) - cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM) + pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64) + cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM) + sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM) + cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM) + sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM) x1 = tl.extract_slice( normalized_values, offsets=(0, 0), @@ -93,23 +95,25 @@ def split_qkv_rmsnorm_rope_kernel( sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), - dtype=tl.bfloat16) - cat_x = tl.insert_slice( - cat_x, - -x2, + roped_q1 = x1 * cos - x2 * sin + roped_q2 = x2 * cos + x1 * sin + + roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), + dtype=tl.bfloat16) + roped_q = tl.insert_slice( + roped_q, + roped_q1, offsets=(0, 0), sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - cat_x = tl.insert_slice( - cat_x, - x1, + roped_q = tl.insert_slice( + roped_q, + roped_q2, offsets=(0, HALF_HEAD_DIM), sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - roped_q = cat_x * sin + normalized_values * cos tl.store( q_ptr + output_offset + col_indices, roped_q.reshape(Q_BLOCK_SIZE).to(q_ptr.dtype.element_ty), @@ -143,9 +147,12 @@ def split_qkv_rmsnorm_rope_kernel( else: normalized_values = (normalized_values * weight_values).to( tl.bfloat16) - sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM) - sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM) - cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM) + + pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64) + cos_offsets = pos_idx * HEAD_DIM + tl.arange(0, HALF_HEAD_DIM) + sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM) + cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM) + sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM) x1 = tl.extract_slice( normalized_values, offsets=(0, 0), @@ -158,23 +165,25 @@ def split_qkv_rmsnorm_rope_kernel( sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), - dtype=tl.bfloat16) - cat_x = tl.insert_slice( - cat_x, - -x2, + roped_k1 = x1 * cos - x2 * sin + roped_k2 = x2 * cos + x1 * sin + + roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), + dtype=tl.bfloat16) + roped_k = tl.insert_slice( + roped_k, + roped_k1, offsets=(0, 0), sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - cat_x = tl.insert_slice( - cat_x, - x1, + roped_k = tl.insert_slice( + roped_k, + roped_k2, offsets=(0, HALF_HEAD_DIM), sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - roped_k = cat_x * sin + normalized_values * cos tl.store( k_ptr + output_offset + col_indices, @@ -201,8 +210,8 @@ def split_qkv_rmsnorm_rope_kernel( def split_qkv_rmsnorm_rope_impl( input: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, q_hidden_size: int, @@ -238,8 +247,8 @@ def split_qkv_rmsnorm_rope_impl( split_qkv_rmsnorm_rope_kernel[(n_rows, n_cols, 1)]( input, - sin, - cos, + cos_sin_cache, + positions, q_output, k_output, v_output, @@ -263,8 +272,8 @@ def split_qkv_rmsnorm_rope_impl( def split_qkv_rmsnorm_rope_impl_fake( input: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, q_hidden_size: int, diff --git a/vllm_ascend/ops/triton/rope.py b/vllm_ascend/ops/triton/rope.py index 3700e329130..9db17370ea7 100644 --- a/vllm_ascend/ops/triton/rope.py +++ b/vllm_ascend/ops/triton/rope.py @@ -14,6 +14,7 @@ # limitations under the License. # This file is a part of the vllm-ascend project. # +import torch from vllm.triton_utils import tl, triton from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num @@ -28,10 +29,13 @@ def _triton_rope( q_row_stride, k_ptr, k_row_stride, - cos, + cos_ptr, cos_row_stride, - sin, + sin_ptr, sin_row_stride, + cos_sin_ptr, + cos_sin_row_stride, + pos_ptr, num_tokens, n_qh: tl.constexpr, n_kh: tl.constexpr, @@ -42,6 +46,7 @@ def _triton_rope( pad_rope_dim: tl.constexpr, BLOCK_SIZE: tl.constexpr, IS_NEOX_STYLE: tl.constexpr, + USE_COS_SIN: tl.constexpr, ): """ This triton kernel applies rotary embedding on q and k. @@ -82,15 +87,31 @@ def _triton_rope( # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position # m of this program instance # #################################################################### - cos_start_ptr = cos + row_idx * cos_row_stride - sin_start_ptr = sin + row_idx * sin_row_stride + if USE_COS_SIN: + pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64) + cos_start_ptr = cos_sin_ptr + pos_idx * cos_sin_row_stride - cos_offsets = tl.arange(0, pad_rope_dim // 2) - cos_mask = cos_offsets < (rope_dim // 2) - cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, - other=0).to(tl.float32) - sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, - other=0).to(tl.float32) + cos_offsets = tl.arange(0, pad_rope_dim // 2) + sin_offsets = tl.arange(pad_rope_dim // 2, pad_rope_dim) + cos_mask = cos_offsets < (rope_dim // 2) + cos_row = tl.load(cos_start_ptr + cos_offsets, + mask=cos_mask, + other=0).to(tl.float32) + sin_row = tl.load(cos_start_ptr + sin_offsets, + mask=cos_mask, + other=0).to(tl.float32) + else: + cos_start_ptr = cos_ptr + row_idx * cos_row_stride + sin_start_ptr = sin_ptr + row_idx * sin_row_stride + cos_offsets = tl.arange(0, pad_rope_dim // 2) + sin_offsets = tl.arange(pad_rope_dim // 2, pad_rope_dim) + cos_mask = cos_offsets < (rope_dim // 2) + cos_row = tl.load(cos_start_ptr + cos_offsets, + mask=cos_mask, + other=0).to(tl.float32) + sin_row = tl.load(sin_start_ptr + cos_offsets, + mask=cos_mask, + other=0).to(tl.float32) # #################################################################### # Load the left and right half of q and k for the current @@ -157,12 +178,16 @@ def _triton_rope( mask=second_k_mask) -def rope_forward_triton(q, - k, - cos, - sin, - rope_dim: int = -1, - is_neox_style: bool = True): +def rope_forward_triton( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor = None, + sin: torch.Tensor = None, + cos_sin_cache: torch.Tensor = None, + positions: torch.Tensor = None, + rope_dim: int = -1, + is_neox_style: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: if not q.is_contiguous(): q = q.contiguous() if not k.is_contiguous(): @@ -170,12 +195,6 @@ def rope_forward_triton(q, num_tokens, n_q_head, head_dim = q.shape n_kv_head = k.shape[1] - cos = cos.view(num_tokens, -1) - sin = sin.view(num_tokens, -1) - if rope_dim == -1: - # If rope_dim is not specified, we assume that input cos/sin is not - # duplicated to rope_dim, which means rope_dim == cos.shape[-1] * 2 - rope_dim = cos.shape[-1] * 2 assert rope_dim <= head_dim pad_rope_dim = triton.next_power_of_2(rope_dim) pad_n_q_head = triton.next_power_of_2(n_q_head) @@ -184,24 +203,68 @@ def rope_forward_triton(q, num_vectorcore = get_vectorcore_num() n_row = min(num_tokens, num_vectorcore) - _triton_rope[(n_row, )]( - q, - q.stride(0), - k, - k.stride(0), - cos, - cos.stride(0), - sin, - sin.stride(0), - num_tokens, - n_q_head, - n_kv_head, - head_dim, - rope_dim, - pad_n_q_head, - pad_n_kv_head, - pad_rope_dim, - BLOCK_SIZE=BLOCK_SIZE, - IS_NEOX_STYLE=is_neox_style, - ) + if cos_sin_cache is not None and positions is not None: + assert positions.shape[0] == num_tokens + _triton_rope[(n_row, )]( + q, + q.stride(0), + k, + k.stride(0), + None, + None, + None, + None, + cos_sin_cache, + cos_sin_cache.stride(0), + positions, + num_tokens, + n_q_head, + n_kv_head, + head_dim, + rope_dim, + pad_n_q_head, + pad_n_kv_head, + pad_rope_dim, + BLOCK_SIZE=BLOCK_SIZE, + IS_NEOX_STYLE=is_neox_style, + USE_COS_SIN=True, + ) + elif cos is not None and sin is not None: + assert cos.shape[0] == num_tokens and sin.shape[0] == num_tokens + cos = cos.view(num_tokens, -1) + sin = sin.view(num_tokens, -1) + if rope_dim == -1: + # If rope_dim is not specified, we assume that input cos/sin is not + # duplicated to rope_dim, which means rope_dim == cos.shape[-1] * 2 + rope_dim = cos.shape[-1] * 2 + _triton_rope[(n_row, )]( + q, + q.stride(0), + k, + k.stride(0), + cos, + cos.stride(0), + sin, + sin.stride(0), + None, + None, + None, + num_tokens, + n_q_head, + n_kv_head, + head_dim, + rope_dim, + pad_n_q_head, + pad_n_kv_head, + pad_rope_dim, + BLOCK_SIZE=BLOCK_SIZE, + IS_NEOX_STYLE=is_neox_style, + USE_COS_SIN=False, + ) + else: + raise ValueError( + "Currently, rope_forward_triton supports passing:\n" + "1. positions and original cos_sin_cache.\n" + "2. cos and sin which are already selected by positions\n" + "Please check whether you call rope_forward_triton correctly.") return q, k diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 2f89965214a..8eed3b775ac 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -40,7 +40,6 @@ update_attn_params, update_mla_attn_dcp_pcp_params, update_mla_attn_params) -from vllm_ascend.ops.rotary_embedding import update_cos_sin from vllm_ascend.ops.triton.spec_decode.utils import \ prepare_inputs_padded_kernel from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num @@ -325,8 +324,6 @@ def dummy_run(self, batch_descriptor=None, dummy_compute_logits=lambda hidden_states: None, is_profile=False): - # update global cos, sin - update_cos_sin(self.positions[:num_tokens]) multi_steps_attn_metadata = [] if not self.use_cuda_graph: @@ -506,9 +503,6 @@ def _propose( attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model()) - # update global cos, sin - update_cos_sin(self.positions[:num_input_tokens]) - used_update_positions = target_positions[last_token_indices] per_layer_attn_metadata = dict() # The first step of speculative. @@ -654,9 +648,6 @@ def _run_merged_draft( self.positions[:batch_size] = clamped_positions self.hidden_states[:batch_size] = hidden_states - # update global cos, sin - update_cos_sin(self.positions[:input_batch_size]) - # Run the model. # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.