diff --git a/CMakeLists.txt b/CMakeLists.txt index be2870f30a6a..e59bfef6fc68 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -311,14 +311,9 @@ set(VLLM_EXT_SRC "csrc/attention/paged_attention_v1.cu" "csrc/attention/paged_attention_v2.cu" "csrc/attention/merge_attn_states.cu" - "csrc/pos_encoding_kernels.cu" - "csrc/layernorm_kernels.cu" - "csrc/fused_qknorm_rope_kernel.cu" - "csrc/layernorm_quant_kernels.cu" "csrc/sampler.cu" "csrc/topk.cu" "csrc/cuda_view.cu" - "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu" "csrc/quantization/activation_kernels.cu" "csrc/cuda_utils_kernels.cu" @@ -633,7 +628,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") "csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu" "csrc/libtorch_stable/quantization/w8a8/fp8/common.cu" "csrc/libtorch_stable/quantization/gptq/q_gemm.cu" - "csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu") + "csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu" + "csrc/libtorch_stable/pos_encoding_kernels.cu" + "csrc/libtorch_stable/fused_qknorm_rope_kernel.cu" + "csrc/libtorch_stable/layernorm_kernels.cu" + "csrc/libtorch_stable/layernorm_quant_kernels.cu" + "csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_STABLE_EXT_SRC diff --git a/csrc/libtorch_stable/dispatch_utils.h b/csrc/libtorch_stable/dispatch_utils.h index ffc2ca031260..e9478236a0e1 100644 --- a/csrc/libtorch_stable/dispatch_utils.h +++ b/csrc/libtorch_stable/dispatch_utils.h @@ -58,6 +58,35 @@ THO_DISPATCH_SWITCH(TYPE, NAME, \ VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__)) +// Quant type dispatch (FP8 + INT8) +#ifdef USE_ROCM + #define VLLM_STABLE_DISPATCH_CASE_QUANT_TYPES(...) \ + THO_DISPATCH_CASE(torch::headeronly::ScalarType::Float8_e4m3fn, \ + __VA_ARGS__) \ + THO_DISPATCH_CASE(torch::headeronly::ScalarType::Float8_e4m3fnuz, \ + __VA_ARGS__) \ + THO_DISPATCH_CASE(torch::headeronly::ScalarType::Char, __VA_ARGS__) +#else + #define VLLM_STABLE_DISPATCH_CASE_QUANT_TYPES(...) \ + THO_DISPATCH_CASE(torch::headeronly::ScalarType::Float8_e4m3fn, \ + __VA_ARGS__) \ + THO_DISPATCH_CASE(torch::headeronly::ScalarType::Char, __VA_ARGS__) +#endif + +#define VLLM_STABLE_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ + THO_DISPATCH_SWITCH(TYPE, NAME, \ + VLLM_STABLE_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) + +// Group size dispatch (pure C++ if/else, no ATen dependency) +#define VLLM_STABLE_DISPATCH_GROUP_SIZE(group_size, const_group_size, ...) \ + if (group_size == 128) { \ + constexpr int const_group_size = 128; \ + __VA_ARGS__(); \ + } else if (group_size == 64) { \ + constexpr int const_group_size = 64; \ + __VA_ARGS__(); \ + } + // Boolean dispatch #define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \ if (expr) { \ @@ -67,3 +96,56 @@ constexpr bool const_expr = false; \ __VA_ARGS__(); \ } + +// Vec size dispatch (pure C++ switch, no ATen dependency) +#define VLLM_STABLE_DISPATCH_VEC_SIZE(VEC_SIZE, ...) \ + switch (VEC_SIZE) { \ + case 16: { \ + constexpr int vec_size = 16; \ + __VA_ARGS__(); \ + break; \ + } \ + case 8: { \ + constexpr int vec_size = 8; \ + __VA_ARGS__(); \ + break; \ + } \ + case 4: { \ + constexpr int vec_size = 4; \ + __VA_ARGS__(); \ + break; \ + } \ + case 2: { \ + constexpr int vec_size = 2; \ + __VA_ARGS__(); \ + break; \ + } \ + default: { \ + constexpr int vec_size = 1; \ + __VA_ARGS__(); \ + break; \ + } \ + } + +// Tensor rank dispatch (2D, 3D, 4D) +#define VLLM_STABLE_DISPATCH_RANK234(NUM_DIMS, ...) \ + switch (NUM_DIMS) { \ + case 2: { \ + constexpr int tensor_rank = 2; \ + __VA_ARGS__(); \ + break; \ + } \ + case 3: { \ + constexpr int tensor_rank = 3; \ + __VA_ARGS__(); \ + break; \ + } \ + case 4: { \ + constexpr int tensor_rank = 4; \ + __VA_ARGS__(); \ + break; \ + } \ + default: \ + STD_TORCH_CHECK( \ + false, "Expects rank 2, 3 or 4 tensors but got unsupported rank"); \ + } diff --git a/csrc/fused_qknorm_rope_kernel.cu b/csrc/libtorch_stable/fused_qknorm_rope_kernel.cu similarity index 87% rename from csrc/fused_qknorm_rope_kernel.cu rename to csrc/libtorch_stable/fused_qknorm_rope_kernel.cu index 0bf48fd3e831..bcf0ae585478 100644 --- a/csrc/fused_qknorm_rope_kernel.cu +++ b/csrc/libtorch_stable/fused_qknorm_rope_kernel.cu @@ -18,21 +18,20 @@ #include #include -#include -#include -#include +#include "torch_utils.h" -#include "async_util.cuh" -#include "cuda_compat.h" +#include "../async_util.cuh" +#include "../cuda_compat.h" +#include "../type_convert.cuh" #include "dispatch_utils.h" -#include "type_convert.cuh" -#define CHECK_TYPE(x, st) \ - TORCH_CHECK(x.scalar_type() == st, #x " dtype is ", x.scalar_type(), \ - ", while ", st, " is expected") -#define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_TYPE(x, st) \ + STD_TORCH_CHECK(x.scalar_type() == st, #x " dtype is ", x.scalar_type(), \ + ", while ", st, " is expected") +#define CHECK_TH_CUDA(x) \ + STD_TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + STD_TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) \ CHECK_TH_CUDA(x); \ CHECK_CONTIGUOUS(x) @@ -589,8 +588,8 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, }); break; default: - TORCH_CHECK(false, - "Unsupported head dimension for fusedQKNormRope: ", head_dim); + STD_TORCH_CHECK( + false, "Unsupported head dimension for fusedQKNormRope: ", head_dim); } } @@ -604,10 +603,10 @@ void launchFusedQKNormRopeNTokenHeads( void const* k_weight, void const* cos_sin_cache, bool const interleave, int64_t const* position_ids, int const token_heads_per_warp, cudaStream_t stream) { - TORCH_CHECK(token_heads_per_warp == 1 || token_heads_per_warp == 2 || - token_heads_per_warp == 4 || token_heads_per_warp == 8, - "token_heads_per_warp must be 1, 2, 4, or 8, got ", - token_heads_per_warp); + STD_TORCH_CHECK(token_heads_per_warp == 1 || token_heads_per_warp == 2 || + token_heads_per_warp == 4 || token_heads_per_warp == 8, + "token_heads_per_warp must be 1, 2, 4, or 8, got ", + token_heads_per_warp); // token_heads_per_warp == 1: delegate to the 1-head baseline kernel. if (token_heads_per_warp == 1) { @@ -691,7 +690,7 @@ void launchFusedQKNormRopeNTokenHeads( }); \ break; \ default: \ - TORCH_CHECK(false, "Unsupported head dimension: ", head_dim); \ + STD_TORCH_CHECK(false, "Unsupported head dimension: ", head_dim); \ } \ } while (0) @@ -708,19 +707,21 @@ void launchFusedQKNormRopeNTokenHeads( } // namespace tensorrt_llm::kernels void fused_qk_norm_rope( - torch::Tensor& qkv, // Combined QKV tensor [num_tokens, - // (num_heads_q+num_heads_k+num_heads_v)*head_dim] - int64_t num_heads_q, // Number of query heads - int64_t num_heads_k, // Number of key heads - int64_t num_heads_v, // Number of value heads - int64_t head_dim, // Dimension per head - double eps, // Epsilon for RMS normalization - torch::Tensor& q_weight, // RMSNorm weights for query [head_dim] - torch::Tensor& k_weight, // RMSNorm weights for key [head_dim] - torch::Tensor& cos_sin_cache, // Cos/sin cache [max_position, head_dim] - bool is_neox, // Whether RoPE is applied in Neox style - torch::Tensor& position_ids, // Position IDs for RoPE [num_tokens] - int64_t forced_token_heads_per_warp // -1 = auto-select, >0 = forced value + torch::stable::Tensor& + qkv, // Combined QKV tensor [num_tokens, + // (num_heads_q+num_heads_k+num_heads_v)*head_dim] + int64_t num_heads_q, // Number of query heads + int64_t num_heads_k, // Number of key heads + int64_t num_heads_v, // Number of value heads + int64_t head_dim, // Dimension per head + double eps, // Epsilon for RMS normalization + torch::stable::Tensor& q_weight, // RMSNorm weights for query [head_dim] + torch::stable::Tensor& k_weight, // RMSNorm weights for key [head_dim] + torch::stable::Tensor& cos_sin_cache, // Cos/sin cache [max_position, + // head_dim] + bool is_neox, // Whether RoPE is applied in Neox style + torch::stable::Tensor& position_ids, // Position IDs for RoPE [num_tokens] + int64_t forced_token_heads_per_warp // -1 = auto-select, >0 = forced value ) { // Input validation CHECK_INPUT(qkv); @@ -728,40 +729,42 @@ void fused_qk_norm_rope( CHECK_INPUT(q_weight); CHECK_INPUT(k_weight); CHECK_INPUT(cos_sin_cache); - CHECK_TYPE(position_ids, torch::kInt64); - - TORCH_CHECK(qkv.dim() == 2, - "QKV tensor must be 2D: [num_tokens, " - "(num_heads_q+num_heads_k+num_heads_v)*head_dim]"); - TORCH_CHECK(position_ids.dim() == 1, "Position IDs must be 1D: [num_tokens]"); - TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]"); - TORCH_CHECK(k_weight.dim() == 1, "Key weights must be 1D: [head_dim]"); - TORCH_CHECK(cos_sin_cache.dim() == 2, - "Cos/sin cache must be 2D: [max_position, head_dim]"); - TORCH_CHECK(q_weight.size(0) == head_dim, - "Query weights size must match head dimension"); - TORCH_CHECK(k_weight.size(0) == head_dim, - "Key weights size must match head dimension"); - - TORCH_CHECK(cos_sin_cache.size(1) % 2 == 0, "rotary_dim must be even"); - TORCH_CHECK(cos_sin_cache.size(1) <= head_dim, - "rotary_dim must be less than or equal to head_dim"); - - TORCH_CHECK(qkv.scalar_type() == q_weight.scalar_type() && - qkv.scalar_type() == k_weight.scalar_type(), - "qkv, q_weight and k_weight must have the same dtype"); + CHECK_TYPE(position_ids, torch::headeronly::ScalarType::Long); + + STD_TORCH_CHECK(qkv.dim() == 2, + "QKV tensor must be 2D: [num_tokens, " + "(num_heads_q+num_heads_k+num_heads_v)*head_dim]"); + STD_TORCH_CHECK(position_ids.dim() == 1, + "Position IDs must be 1D: [num_tokens]"); + STD_TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]"); + STD_TORCH_CHECK(k_weight.dim() == 1, "Key weights must be 1D: [head_dim]"); + STD_TORCH_CHECK(cos_sin_cache.dim() == 2, + "Cos/sin cache must be 2D: [max_position, head_dim]"); + STD_TORCH_CHECK(q_weight.size(0) == head_dim, + "Query weights size must match head dimension"); + STD_TORCH_CHECK(k_weight.size(0) == head_dim, + "Key weights size must match head dimension"); + + STD_TORCH_CHECK(cos_sin_cache.size(1) % 2 == 0, "rotary_dim must be even"); + STD_TORCH_CHECK(cos_sin_cache.size(1) <= head_dim, + "rotary_dim must be less than or equal to head_dim"); + + STD_TORCH_CHECK(qkv.scalar_type() == q_weight.scalar_type() && + qkv.scalar_type() == k_weight.scalar_type(), + "qkv, q_weight and k_weight must have the same dtype"); int64_t num_tokens = qkv.size(0); - TORCH_CHECK(position_ids.size(0) == num_tokens, - "Number of tokens in position_ids must match QKV"); + STD_TORCH_CHECK(position_ids.size(0) == num_tokens, + "Number of tokens in position_ids must match QKV"); int64_t total_heads = num_heads_q + num_heads_k + num_heads_v; - TORCH_CHECK( + STD_TORCH_CHECK( qkv.size(1) == total_heads * head_dim, "QKV tensor size must match total number of heads and head dimension"); - auto device_id = qkv.get_device(); - auto stream = at::cuda::getCurrentCUDAStream(device_id); + const torch::stable::accelerator::DeviceGuard device_guard( + qkv.get_device_index()); + auto stream = get_current_cuda_stream(qkv.get_device_index()); // Select token_heads_per_warp: forced value if >0, else auto-select. // Auto thresholds are calibrated on SM 9.0 (H100). On other architectures, @@ -771,8 +774,7 @@ void fused_qk_norm_rope( token_heads_per_warp = static_cast(forced_token_heads_per_warp); } else { token_heads_per_warp = 1; - auto* dev_prop = at::cuda::getDeviceProperties(device_id); - int sm_version = dev_prop->major * 10 + dev_prop->minor; + int sm_version = get_device_prop()->major * 10 + get_device_prop()->minor; int64_t total_qk_units = num_tokens * (num_heads_q + num_heads_k); if (sm_version == 90) { if (head_dim >= 256) { @@ -795,21 +797,22 @@ void fused_qk_norm_rope( } } - VLLM_DISPATCH_HALF_TYPES(qkv.scalar_type(), "fused_qk_norm_rope_kernel", [&] { - using qkv_scalar_t = scalar_t; - VLLM_DISPATCH_FLOATING_TYPES( - cos_sin_cache.scalar_type(), "fused_qk_norm_rope_kernel", [&] { - using cache_scalar_t = scalar_t; - tensorrt_llm::kernels::launchFusedQKNormRopeNTokenHeads< - qkv_scalar_t, cache_scalar_t>( - qkv.data_ptr(), static_cast(num_tokens), - static_cast(num_heads_q), static_cast(num_heads_k), - static_cast(num_heads_v), static_cast(head_dim), - static_cast(cos_sin_cache.size(1)), static_cast(eps), - q_weight.data_ptr(), k_weight.data_ptr(), - cos_sin_cache.data_ptr(), !is_neox, - reinterpret_cast(position_ids.data_ptr()), - token_heads_per_warp, stream); - }); - }); + VLLM_STABLE_DISPATCH_HALF_TYPES( + qkv.scalar_type(), "fused_qk_norm_rope_kernel", [&] { + using qkv_scalar_t = scalar_t; + VLLM_STABLE_DISPATCH_FLOATING_TYPES( + cos_sin_cache.scalar_type(), "fused_qk_norm_rope_kernel", [&] { + using cache_scalar_t = scalar_t; + tensorrt_llm::kernels::launchFusedQKNormRopeNTokenHeads< + qkv_scalar_t, cache_scalar_t>( + qkv.data_ptr(), static_cast(num_tokens), + static_cast(num_heads_q), static_cast(num_heads_k), + static_cast(num_heads_v), static_cast(head_dim), + static_cast(cos_sin_cache.size(1)), + static_cast(eps), q_weight.data_ptr(), + k_weight.data_ptr(), cos_sin_cache.data_ptr(), !is_neox, + reinterpret_cast(position_ids.data_ptr()), + token_heads_per_warp, stream); + }); + }); } diff --git a/csrc/layernorm_kernels.cu b/csrc/libtorch_stable/layernorm_kernels.cu similarity index 79% rename from csrc/layernorm_kernels.cu rename to csrc/libtorch_stable/layernorm_kernels.cu index e617e45dc58b..fb714b1b1e07 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/libtorch_stable/layernorm_kernels.cu @@ -1,11 +1,12 @@ -#include "type_convert.cuh" -#include "dispatch_utils.h" -#include "cub_helpers.h" -#include "core/batch_invariant.hpp" -#include "libtorch_stable/quantization/vectorization_utils.cuh" +#include + +#include "torch_utils.h" -#include -#include +#include "../cub_helpers.h" +#include "../core/batch_invariant.hpp" +#include "../type_convert.cuh" +#include "dispatch_utils.h" +#include "quantization/vectorization_utils.cuh" namespace vllm { @@ -189,16 +190,16 @@ fused_add_rms_norm_kernel( } // namespace vllm -void rms_norm(torch::Tensor& out, // [..., hidden_size] - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] +void rms_norm(torch::stable::Tensor& out, // [..., hidden_size] + torch::stable::Tensor& input, // [..., hidden_size] + torch::stable::Tensor& weight, // [hidden_size] double epsilon) { - TORCH_CHECK(out.is_contiguous()); + STD_TORCH_CHECK(out.is_contiguous()); if (input.stride(-1) != 1) { - input = input.contiguous(); + input = torch::stable::contiguous(input); } - TORCH_CHECK(input.stride(-1) == 1); - TORCH_CHECK(weight.is_contiguous()); + STD_TORCH_CHECK(input.stride(-1) == 1); + STD_TORCH_CHECK(weight.is_contiguous()); int hidden_size = input.size(-1); @@ -213,45 +214,49 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] // For large num_tokens, use smaller blocks to increase SM concurrency. const int max_block_size = (num_tokens < 256) ? 1024 : 256; dim3 grid(num_tokens); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_RANK234(num_dims, [&] { - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { - const int calculated_vec_size = - std::gcd(16 / sizeof(scalar_t), hidden_size); - const int block_size = - std::min(hidden_size / calculated_vec_size, max_block_size); - dim3 block(block_size); - VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] { - vllm::rms_norm_kernel - <<>>( - out.data_ptr(), input.data_ptr(), - input_stride_d2, input_stride_d3, input_stride_d4, - input_shape_d2, input_shape_d3, weight.data_ptr(), - epsilon, num_tokens, hidden_size); - }); - }); + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(); + VLLM_STABLE_DISPATCH_RANK234(num_dims, [&] { + VLLM_STABLE_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "rms_norm_kernel", [&] { + const int calculated_vec_size = + std::gcd(16 / sizeof(scalar_t), hidden_size); + const int block_size = + std::min(hidden_size / calculated_vec_size, max_block_size); + dim3 block(block_size); + VLLM_STABLE_DISPATCH_VEC_SIZE(calculated_vec_size, [&] { + vllm::rms_norm_kernel + <<>>( + out.mutable_data_ptr(), + input.const_data_ptr(), input_stride_d2, + input_stride_d3, input_stride_d4, input_shape_d2, + input_shape_d3, weight.const_data_ptr(), epsilon, + num_tokens, hidden_size); + }); + }); }); } -#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ - vllm::fused_add_rms_norm_kernel \ - <<>>( \ - input.data_ptr(), input_stride, \ - residual.data_ptr(), weight.data_ptr(), \ - epsilon, num_tokens, hidden_size); \ +#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ + vllm::fused_add_rms_norm_kernel \ + <<>>( \ + input.mutable_data_ptr(), input_stride, \ + residual.mutable_data_ptr(), \ + weight.const_data_ptr(), epsilon, num_tokens, \ + hidden_size); \ }); -void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] - torch::Tensor& residual, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] +void fused_add_rms_norm(torch::stable::Tensor& input, // [..., hidden_size] + torch::stable::Tensor& residual, // [..., hidden_size] + torch::stable::Tensor& weight, // [hidden_size] double epsilon) { - TORCH_CHECK(weight.scalar_type() == input.scalar_type()); - TORCH_CHECK(input.scalar_type() == residual.scalar_type()); - TORCH_CHECK(residual.is_contiguous()); - TORCH_CHECK(weight.is_contiguous()); + STD_TORCH_CHECK(weight.scalar_type() == input.scalar_type()); + STD_TORCH_CHECK(input.scalar_type() == residual.scalar_type()); + STD_TORCH_CHECK(residual.is_contiguous()); + STD_TORCH_CHECK(weight.is_contiguous()); int hidden_size = input.size(-1); int64_t input_stride = input.stride(-2); int num_tokens = input.numel() / hidden_size; @@ -263,8 +268,9 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] hiding on global mem ops. */ const int max_block_size = (num_tokens < 256) ? 1024 : 256; dim3 block(std::min(hidden_size, max_block_size)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(); /*If the tensor types are FP16/BF16, try to use the optimized kernel with packed + vectorized ops. Max optimization is achieved with a width-8 vector of FP16/BF16s diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/libtorch_stable/layernorm_quant_kernels.cu similarity index 82% rename from csrc/layernorm_quant_kernels.cu rename to csrc/libtorch_stable/layernorm_quant_kernels.cu index b2f546b3db1d..26ffa76d6e14 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/libtorch_stable/layernorm_quant_kernels.cu @@ -5,15 +5,16 @@ * Currently, only static fp8 quantization is supported. */ -#include "type_convert.cuh" -#include "quantization/w8a8/fp8/common.cuh" -#include "dispatch_utils.h" -#include "cub_helpers.h" -#include "core/batch_invariant.hpp" -#include "libtorch_stable/quantization/vectorization_utils.cuh" +#include + +#include "torch_utils.h" -#include -#include +#include "../cub_helpers.h" +#include "../core/batch_invariant.hpp" +#include "../quantization/w8a8/fp8/common.cuh" +#include "../type_convert.cuh" +#include "dispatch_utils.h" +#include "quantization/vectorization_utils.cuh" namespace vllm { @@ -202,12 +203,13 @@ fused_add_rms_norm_static_fp8_quant_kernel( } // namespace vllm -void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size] - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] - torch::Tensor& scale, // [1] - double epsilon) { - TORCH_CHECK(out.is_contiguous()); +void rms_norm_static_fp8_quant( + torch::stable::Tensor& out, // [..., hidden_size] + torch::stable::Tensor& input, // [..., hidden_size] + torch::stable::Tensor& weight, // [hidden_size] + torch::stable::Tensor& scale, // [1] + double epsilon) { + STD_TORCH_CHECK(out.is_contiguous()); int hidden_size = input.size(-1); int input_stride = input.stride(-2); int num_tokens = input.numel() / hidden_size; @@ -215,24 +217,26 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size] // For large num_tokens, use smaller blocks to increase SM concurrency. const int max_block_size = (num_tokens < 256) ? 1024 : 256; dim3 grid(num_tokens); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(); + VLLM_STABLE_DISPATCH_FLOATING_TYPES( input.scalar_type(), "rms_norm_kernel_scalar_type", [&] { - VLLM_DISPATCH_FP8_TYPES( + VLLM_STABLE_DISPATCH_FP8_TYPES( out.scalar_type(), "rms_norm_kernel_fp8_type", [&] { const int calculated_vec_size = std::gcd(16 / sizeof(scalar_t), hidden_size); const int block_size = std::min(hidden_size / calculated_vec_size, max_block_size); dim3 block(block_size); - VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] { + VLLM_STABLE_DISPATCH_VEC_SIZE(calculated_vec_size, [&] { vllm::rms_norm_static_fp8_quant_kernel <<>>( - out.data_ptr(), input.data_ptr(), - input_stride, weight.data_ptr(), - scale.data_ptr(), epsilon, num_tokens, + out.mutable_data_ptr(), + input.const_data_ptr(), input_stride, + weight.const_data_ptr(), + scale.const_data_ptr(), epsilon, num_tokens, hidden_size); }); }); @@ -240,30 +244,32 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size] } #define LAUNCH_FUSED_ADD_RMS_NORM(width) \ - VLLM_DISPATCH_FLOATING_TYPES( \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES( \ input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \ - VLLM_DISPATCH_FP8_TYPES( \ + VLLM_STABLE_DISPATCH_FP8_TYPES( \ out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \ vllm::fused_add_rms_norm_static_fp8_quant_kernel \ <<>>( \ - out.data_ptr(), input.data_ptr(), \ - input_stride, residual.data_ptr(), \ - weight.data_ptr(), scale.data_ptr(), \ - epsilon, num_tokens, hidden_size); \ + out.mutable_data_ptr(), \ + input.mutable_data_ptr(), input_stride, \ + residual.mutable_data_ptr(), \ + weight.const_data_ptr(), \ + scale.const_data_ptr(), epsilon, num_tokens, \ + hidden_size); \ }); \ }); void fused_add_rms_norm_static_fp8_quant( - torch::Tensor& out, // [..., hidden_size], - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& residual, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] - torch::Tensor& scale, // [1] + torch::stable::Tensor& out, // [..., hidden_size], + torch::stable::Tensor& input, // [..., hidden_size] + torch::stable::Tensor& residual, // [..., hidden_size] + torch::stable::Tensor& weight, // [hidden_size] + torch::stable::Tensor& scale, // [1] double epsilon) { - TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(residual.is_contiguous()); - TORCH_CHECK(residual.scalar_type() == input.scalar_type()); - TORCH_CHECK(weight.scalar_type() == input.scalar_type()); + STD_TORCH_CHECK(out.is_contiguous()); + STD_TORCH_CHECK(residual.is_contiguous()); + STD_TORCH_CHECK(residual.scalar_type() == input.scalar_type()); + STD_TORCH_CHECK(weight.scalar_type() == input.scalar_type()); int hidden_size = input.size(-1); int input_stride = input.stride(-2); int num_tokens = input.numel() / hidden_size; @@ -275,8 +281,9 @@ void fused_add_rms_norm_static_fp8_quant( hiding on global mem ops. */ const int max_block_size = (num_tokens < 256) ? 1024 : 256; dim3 block(std::min(hidden_size, max_block_size)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(); /*If the tensor types are FP16/BF16, try to use the optimized kernel with packed + vectorized ops. Max optimization is achieved with a width-8 vector of FP16/BF16s diff --git a/csrc/libtorch_stable/ops.h b/csrc/libtorch_stable/ops.h index 5ebcb2034f53..f99ff1d1db5b 100644 --- a/csrc/libtorch_stable/ops.h +++ b/csrc/libtorch_stable/ops.h @@ -167,6 +167,59 @@ torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel, torch::stable::Tensor hadacore_transform(torch::stable::Tensor& x, bool inplace); +// Layernorm kernels (shared CUDA/ROCm) +void rms_norm(torch::stable::Tensor& out, torch::stable::Tensor& input, + torch::stable::Tensor& weight, double epsilon); + +void fused_add_rms_norm(torch::stable::Tensor& input, + torch::stable::Tensor& residual, + torch::stable::Tensor& weight, double epsilon); + +// Layernorm-quant kernels (shared CUDA/ROCm) +void rms_norm_static_fp8_quant(torch::stable::Tensor& out, + torch::stable::Tensor& input, + torch::stable::Tensor& weight, + torch::stable::Tensor& scale, double epsilon); + +void fused_add_rms_norm_static_fp8_quant(torch::stable::Tensor& out, + torch::stable::Tensor& input, + torch::stable::Tensor& residual, + torch::stable::Tensor& weight, + torch::stable::Tensor& scale, + double epsilon); + +// Fused layernorm + dynamic per-token quant kernels (shared CUDA/ROCm) +void rms_norm_dynamic_per_token_quant( + torch::stable::Tensor& out, torch::stable::Tensor const& input, + torch::stable::Tensor const& weight, torch::stable::Tensor& scales, + double const var_epsilon, std::optional scale_ub, + std::optional residual); + +void rms_norm_per_block_quant(torch::stable::Tensor& out, + torch::stable::Tensor const& input, + torch::stable::Tensor const& weight, + torch::stable::Tensor& scales, + double const var_epsilon, + std::optional scale_ub, + std::optional residual, + int64_t group_size, bool is_scale_transposed); + +// Positional encoding kernels (shared CUDA/ROCm) +void rotary_embedding(torch::stable::Tensor& positions, + torch::stable::Tensor& query, + std::optional key, + int64_t head_size, torch::stable::Tensor& cos_sin_cache, + bool is_neox, int64_t rope_dim_offset, bool inverse); + +void fused_qk_norm_rope(torch::stable::Tensor& qkv, int64_t num_heads_q, + int64_t num_heads_k, int64_t num_heads_v, + int64_t head_dim, double eps, + torch::stable::Tensor& q_weight, + torch::stable::Tensor& k_weight, + torch::stable::Tensor& cos_sin_cache, bool is_neox, + torch::stable::Tensor& position_ids, + int64_t forced_token_heads_per_warp); + // Activation kernels (shared CUDA/ROCm) void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input); void silu_and_mul_clamp(torch::stable::Tensor& out, diff --git a/csrc/pos_encoding_kernels.cu b/csrc/libtorch_stable/pos_encoding_kernels.cu similarity index 70% rename from csrc/pos_encoding_kernels.cu rename to csrc/libtorch_stable/pos_encoding_kernels.cu index d03c6a5cf0dd..74af743b0961 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/libtorch_stable/pos_encoding_kernels.cu @@ -1,8 +1,6 @@ -#include -#include -#include +#include "torch_utils.h" -#include "cuda_compat.h" +#include "../cuda_compat.h" #include "dispatch_utils.h" namespace vllm { @@ -103,35 +101,37 @@ __global__ void rotary_embedding_kernel( } // namespace vllm void rotary_embedding( - torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] - torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or - // [num_tokens, num_heads * head_size] or - // [batch_size, seq_len, num_heads, head_size] or - // [num_tokens, num_heads, head_size] - std::optional key, + torch::stable::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::stable::Tensor& + query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] + std::optional key, // null or // [batch_size, seq_len, num_kv_heads * head_size] or // [num_tokens, num_kv_heads * head_size] or // [batch_size, seq_len, num_heads, head_size] or // [num_tokens, num_heads, head_size] int64_t head_size, - torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + torch::stable::Tensor& cos_sin_cache, // [max_position, rot_dim] bool is_neox, int64_t rope_dim_offset, bool inverse) { // num_tokens = batch_size * seq_len int64_t num_tokens = positions.numel(); int positions_ndim = positions.dim(); // Make sure num_tokens dim is consistent across positions, query, and key - TORCH_CHECK( + STD_TORCH_CHECK( positions_ndim == 1 || positions_ndim == 2, "positions must have shape [num_tokens] or [batch_size, seq_len]"); if (positions_ndim == 1) { - TORCH_CHECK(query.size(0) == positions.size(0) && - (!key.has_value() || key->size(0) == positions.size(0)), - "query, key and positions must have the same number of tokens"); + STD_TORCH_CHECK( + query.size(0) == positions.size(0) && + (!key.has_value() || key->size(0) == positions.size(0)), + "query, key and positions must have the same number of tokens"); } if (positions_ndim == 2) { - TORCH_CHECK( + STD_TORCH_CHECK( query.size(0) == positions.size(0) && (!key.has_value() || key->size(0) == positions.size(0)) && query.size(1) == positions.size(1) && @@ -143,20 +143,20 @@ void rotary_embedding( // hidden_size = num_heads * head_size int query_hidden_size = query.numel() / num_tokens; int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0; - TORCH_CHECK(query_hidden_size % head_size == 0); - TORCH_CHECK(key_hidden_size % head_size == 0); + STD_TORCH_CHECK(query_hidden_size % head_size == 0); + STD_TORCH_CHECK(key_hidden_size % head_size == 0); // Make sure query and key have consistent number of heads int num_heads = query_hidden_size / head_size; int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads; - TORCH_CHECK(num_heads % num_kv_heads == 0); + STD_TORCH_CHECK(num_heads % num_kv_heads == 0); int rot_dim = cos_sin_cache.size(1); int seq_dim_idx = positions_ndim - 1; int64_t query_stride = query.stride(seq_dim_idx); int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0; - TORCH_CHECK((rot_dim + rope_dim_offset) <= head_size); + STD_TORCH_CHECK((rot_dim + rope_dim_offset) <= head_size); // Determine head stride: for [*, heads, head_size] use stride of last dim; // for flat [*, heads*head_size], heads blocks are contiguous of size // head_size @@ -166,30 +166,36 @@ void rotary_embedding( dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { - using query_t = scalar_t; - VLLM_DISPATCH_FLOATING_TYPES( - cos_sin_cache.scalar_type(), "rotary_embedding_cache", [&] { - using cache_t = scalar_t; - if (is_neox) { - vllm::rotary_embedding_kernel - <<>>( - positions.data_ptr(), query.data_ptr(), - key.has_value() ? key->data_ptr() : nullptr, - cos_sin_cache.data_ptr(), rot_dim, query_stride, - key_stride, head_stride, num_heads, num_kv_heads, head_size, - rope_dim_offset, inverse); - } else { - vllm::rotary_embedding_kernel - <<>>( - positions.data_ptr(), query.data_ptr(), - key.has_value() ? key->data_ptr() : nullptr, - cos_sin_cache.data_ptr(), rot_dim, query_stride, - key_stride, head_stride, num_heads, num_kv_heads, head_size, - rope_dim_offset, inverse); - } - }); - }); + const torch::stable::accelerator::DeviceGuard device_guard( + query.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(); + VLLM_STABLE_DISPATCH_FLOATING_TYPES( + query.scalar_type(), "rotary_embedding", [&] { + using query_t = scalar_t; + VLLM_STABLE_DISPATCH_FLOATING_TYPES( + cos_sin_cache.scalar_type(), "rotary_embedding_cache", [&] { + using cache_t = scalar_t; + if (is_neox) { + vllm::rotary_embedding_kernel + <<>>( + positions.const_data_ptr(), + query.mutable_data_ptr(), + key.has_value() ? key->mutable_data_ptr() + : nullptr, + cos_sin_cache.const_data_ptr(), rot_dim, + query_stride, key_stride, head_stride, num_heads, + num_kv_heads, head_size, rope_dim_offset, inverse); + } else { + vllm::rotary_embedding_kernel + <<>>( + positions.const_data_ptr(), + query.mutable_data_ptr(), + key.has_value() ? key->mutable_data_ptr() + : nullptr, + cos_sin_cache.const_data_ptr(), rot_dim, + query_stride, key_stride, head_stride, num_heads, + num_kv_heads, head_size, rope_dim_offset, inverse); + } + }); + }); } diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu similarity index 54% rename from csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu rename to csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 723ca8142b82..2152e64dc962 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -1,6 +1,5 @@ -#include -#include +#include "../../torch_utils.h" #include "../../dispatch_utils.h" #include "layernorm_utils.cuh" @@ -134,63 +133,71 @@ __global__ void rms_norm_per_block_quant_kernel( // Residual add + RMS norm + dynamic per token template void rms_norm_dynamic_per_token_quant_dispatch( - torch::Tensor& out, // [..., hidden_size] - torch::Tensor const& input, // [..., hidden_size] - torch::Tensor const& weight, // [hidden_size] - torch::Tensor& scales, // [num_tokens] - double const var_epsilon, // Variance epsilon used in norm calculation - std::optional const& scale_ub, - std::optional& residual) { + torch::stable::Tensor& out, // [..., hidden_size] + torch::stable::Tensor const& input, // [..., hidden_size] + torch::stable::Tensor const& weight, // [hidden_size] + torch::stable::Tensor& scales, // [num_tokens] + double const var_epsilon, // Variance epsilon used in norm calculation + std::optional const& scale_ub, + std::optional& residual) { int32_t hidden_size = input.size(-1); - int32_t input_stride = input.view({-1, hidden_size}).stride(0); + int32_t input_stride = + torch::stable::view(input, {-1, hidden_size}).stride(0); auto num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(); - VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] { - VLLM_DISPATCH_QUANT_TYPES( + VLLM_STABLE_DISPATCH_BOOL(residual.has_value(), has_residual, [&] { + VLLM_STABLE_DISPATCH_QUANT_TYPES( out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] { vllm::rms_norm_dynamic_per_token_quant_kernel <<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), weight.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, + out.mutable_data_ptr(), + scales.mutable_data_ptr(), + input.const_data_ptr(), + weight.const_data_ptr(), + scale_ub.has_value() ? scale_ub->const_data_ptr() + : nullptr, var_epsilon, hidden_size, input_stride, - has_residual ? residual->data_ptr() : nullptr); + has_residual ? residual->mutable_data_ptr() + : nullptr); }); }); } void rms_norm_dynamic_per_token_quant( - torch::Tensor& out, // [..., hidden_size] - torch::Tensor const& input, // [..., hidden_size] - torch::Tensor const& weight, // [hidden_size] - torch::Tensor& scales, // [num_tokens] - double const var_epsilon, // Variance epsilon used in norm calculation - std::optional scale_ub, std::optional residual) { - static c10::ScalarType kFp8Type = is_fp8_ocp() - ? c10::ScalarType::Float8_e4m3fn - : c10::ScalarType::Float8_e4m3fnuz; - TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); - TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(input.stride(-1) == 1, - "Input must be contiguous in the last dimension"); + torch::stable::Tensor& out, // [..., hidden_size] + torch::stable::Tensor const& input, // [..., hidden_size] + torch::stable::Tensor const& weight, // [hidden_size] + torch::stable::Tensor& scales, // [num_tokens] + double const var_epsilon, // Variance epsilon used in norm calculation + std::optional scale_ub, + std::optional residual) { + static torch::headeronly::ScalarType kFp8Type = + is_fp8_ocp() ? torch::headeronly::ScalarType::Float8_e4m3fn + : torch::headeronly::ScalarType::Float8_e4m3fnuz; + STD_TORCH_CHECK(out.scalar_type() == kFp8Type || + out.scalar_type() == torch::headeronly::ScalarType::Char); + STD_TORCH_CHECK(out.is_contiguous()); + STD_TORCH_CHECK(input.stride(-1) == 1, + "Input must be contiguous in the last dimension"); if (scale_ub.has_value()) { - TORCH_CHECK(out.dtype() == kFp8Type); + STD_TORCH_CHECK(out.scalar_type() == kFp8Type); } - TORCH_CHECK(weight.dtype() == input.dtype()); - TORCH_CHECK(scales.dtype() == torch::kFloat32); + STD_TORCH_CHECK(weight.scalar_type() == input.scalar_type()); + STD_TORCH_CHECK(scales.scalar_type() == torch::headeronly::ScalarType::Float); if (residual) { - TORCH_CHECK(residual->scalar_type() == input.scalar_type()); - TORCH_CHECK(residual->is_contiguous()); + STD_TORCH_CHECK(residual->scalar_type() == input.scalar_type()); + STD_TORCH_CHECK(residual->is_contiguous()); } - VLLM_DISPATCH_FLOATING_TYPES( + VLLM_STABLE_DISPATCH_FLOATING_TYPES( input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] { rms_norm_dynamic_per_token_quant_dispatch( out, input, weight, scales, var_epsilon, scale_ub, residual); @@ -199,103 +206,115 @@ void rms_norm_dynamic_per_token_quant( // Residual add + RMS norm + dynamic per token void rms_norm_per_block_quant_dispatch( - torch::Tensor& out, // [..., hidden_size] - torch::Tensor const& input, // [..., hidden_size] - torch::Tensor const& weight, // [hidden_size] - torch::Tensor& scales, // [num_tokens, hidden_size / group_size] or - // [hidden_size / group_size, num_tokens] + torch::stable::Tensor& out, // [..., hidden_size] + torch::stable::Tensor const& input, // [..., hidden_size] + torch::stable::Tensor const& weight, // [hidden_size] + torch::stable::Tensor& scales, // [num_tokens, hidden_size / + // group_size] or + // [hidden_size / group_size, + // num_tokens] int32_t group_size, double const var_epsilon, // Variance epsilon used in norm calculation - std::optional const& scale_ub, - std::optional& residual, bool is_scale_transposed) { + std::optional const& scale_ub, + std::optional& residual, bool is_scale_transposed) { int32_t hidden_size = input.size(-1); - int32_t input_stride = input.view({-1, hidden_size}).stride(0); + int32_t input_stride = + torch::stable::view(input, {-1, hidden_size}).stride(0); - TORCH_CHECK(hidden_size % 4 == 0, - "Hidden size must be divisible by 4 for vectorized access"); - TORCH_CHECK(input_stride % 4 == 0, - "Input stride must be divisible by 4 for vectorized access"); - TORCH_CHECK(group_size % 4 == 0, - "Group size must be divisible by 4 for vectorized access"); + STD_TORCH_CHECK(hidden_size % 4 == 0, + "Hidden size must be divisible by 4 for vectorized access"); + STD_TORCH_CHECK(input_stride % 4 == 0, + "Input stride must be divisible by 4 for vectorized access"); + STD_TORCH_CHECK(group_size % 4 == 0, + "Group size must be divisible by 4 for vectorized access"); auto num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); const int max_block_size = (num_tokens <= 256) ? 512 : 256; dim3 block(std::min(hidden_size, max_block_size)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(); - VLLM_DISPATCH_FLOATING_TYPES( + VLLM_STABLE_DISPATCH_FLOATING_TYPES( input.scalar_type(), "rms_norm_per_block_quant_fp_dispatch", [&] { using scalar_in_t = scalar_t; - VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] { - VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] { - VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] { - VLLM_DISPATCH_QUANT_TYPES( - out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { - vllm::rms_norm_per_block_quant_kernel - <<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() - : nullptr, + VLLM_STABLE_DISPATCH_GROUP_SIZE(group_size, gs, [&] { + VLLM_STABLE_DISPATCH_BOOL(residual.has_value(), has_residual, [&] { + VLLM_STABLE_DISPATCH_BOOL( + is_scale_transposed, transpose_scale, [&] { + VLLM_STABLE_DISPATCH_QUANT_TYPES( + out.scalar_type(), "rms_norm_per_block_quant_kernel", + [&] { + vllm::rms_norm_per_block_quant_kernel< + scalar_in_t, scalar_t, has_residual, + transpose_scale, gs><<>>( + out.mutable_data_ptr(), + scales.mutable_data_ptr(), + input.const_data_ptr(), + weight.const_data_ptr(), + scale_ub.has_value() + ? scale_ub->const_data_ptr() + : nullptr, var_epsilon, hidden_size, input_stride, - has_residual ? residual->data_ptr() - : nullptr, + has_residual + ? residual->mutable_data_ptr() + : nullptr, scales.stride(1)); - }); - }); + }); + }); }); }); }); } -void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor const& weight, - torch::Tensor& scales, double const var_epsilon, - std::optional scale_ub, - std::optional residual, +void rms_norm_per_block_quant(torch::stable::Tensor& out, + torch::stable::Tensor const& input, + torch::stable::Tensor const& weight, + torch::stable::Tensor& scales, + double const var_epsilon, + std::optional scale_ub, + std::optional residual, int64_t group_size, bool is_scale_transposed) { - static c10::ScalarType kFp8Type = is_fp8_ocp() - ? c10::ScalarType::Float8_e4m3fn - : c10::ScalarType::Float8_e4m3fnuz; - TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); - TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(input.stride(-1) == 1, - "Input must be contiguous in the last dimension"); + static torch::headeronly::ScalarType kFp8Type = + is_fp8_ocp() ? torch::headeronly::ScalarType::Float8_e4m3fn + : torch::headeronly::ScalarType::Float8_e4m3fnuz; + STD_TORCH_CHECK(out.scalar_type() == kFp8Type || + out.scalar_type() == torch::headeronly::ScalarType::Char); + STD_TORCH_CHECK(out.is_contiguous()); + STD_TORCH_CHECK(input.stride(-1) == 1, + "Input must be contiguous in the last dimension"); if (scale_ub.has_value()) { - TORCH_CHECK(out.dtype() == kFp8Type); + STD_TORCH_CHECK(out.scalar_type() == kFp8Type); } - TORCH_CHECK(weight.dtype() == input.dtype()); - TORCH_CHECK(scales.dtype() == torch::kFloat32); + STD_TORCH_CHECK(weight.scalar_type() == input.scalar_type()); + STD_TORCH_CHECK(scales.scalar_type() == torch::headeronly::ScalarType::Float); if (residual) { - TORCH_CHECK(residual->scalar_type() == input.scalar_type()); - TORCH_CHECK(residual->is_contiguous()); + STD_TORCH_CHECK(residual->scalar_type() == input.scalar_type()); + STD_TORCH_CHECK(residual->is_contiguous()); } - TORCH_CHECK(group_size == 128 || group_size == 64, - "Unsupported group size: ", group_size); + STD_TORCH_CHECK(group_size == 128 || group_size == 64, + "Unsupported group size: ", group_size); if (scales.stride(1) > 1) { - TORCH_CHECK(is_scale_transposed, - "Outer scale stride must be 1 when scales are not transposed"); + STD_TORCH_CHECK( + is_scale_transposed, + "Outer scale stride must be 1 when scales are not transposed"); } int64_t hidden_size = input.size(-1); - TORCH_CHECK(hidden_size > 0 && hidden_size % group_size == 0, - "hidden_size must be a positive multiple of group_size"); + STD_TORCH_CHECK(hidden_size > 0 && hidden_size % group_size == 0, + "hidden_size must be a positive multiple of group_size"); int64_t num_tokens = input.numel() / hidden_size; int64_t num_groups = hidden_size / group_size; - TORCH_CHECK(scales.numel() >= num_tokens * num_groups, - "scales buffer too small: need ", num_tokens * num_groups, - " elements, got ", scales.numel()); + STD_TORCH_CHECK(scales.numel() >= num_tokens * num_groups, + "scales buffer too small: need ", num_tokens * num_groups, + " elements, got ", scales.numel()); rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size, var_epsilon, scale_ub, residual, is_scale_transposed); -} \ No newline at end of file +} diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/libtorch_stable/quantization/fused_kernels/layernorm_utils.cuh similarity index 99% rename from csrc/quantization/fused_kernels/layernorm_utils.cuh rename to csrc/libtorch_stable/quantization/fused_kernels/layernorm_utils.cuh index 48b615ebdd95..290abedcf940 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/libtorch_stable/quantization/fused_kernels/layernorm_utils.cuh @@ -8,8 +8,8 @@ #include "quantization/utils.cuh" #include "quant_conversions.cuh" -#include "../../cub_helpers.h" -#include "../../cuda_compat.h" +#include "../../../cub_helpers.h" +#include "../../../cuda_compat.h" namespace vllm { diff --git a/csrc/quantization/fused_kernels/quant_conversions.cuh b/csrc/libtorch_stable/quantization/fused_kernels/quant_conversions.cuh similarity index 98% rename from csrc/quantization/fused_kernels/quant_conversions.cuh rename to csrc/libtorch_stable/quantization/fused_kernels/quant_conversions.cuh index fc60643777e0..dbe38092f956 100644 --- a/csrc/quantization/fused_kernels/quant_conversions.cuh +++ b/csrc/libtorch_stable/quantization/fused_kernels/quant_conversions.cuh @@ -6,7 +6,7 @@ #include "libtorch_stable/quantization/vectorization.cuh" // TODO(luka/varun):refactor common.cuh to use this file instead -#include "../w8a8/fp8/common.cuh" +#include "../../../quantization/w8a8/fp8/common.cuh" namespace vllm { diff --git a/csrc/libtorch_stable/torch_bindings.cpp b/csrc/libtorch_stable/torch_bindings.cpp index ee0af3da560c..1601c3bd5bfa 100644 --- a/csrc/libtorch_stable/torch_bindings.cpp +++ b/csrc/libtorch_stable/torch_bindings.cpp @@ -267,9 +267,62 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) { // conditionally compiled so impl registration is in source file ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor"); + // Apply Root Mean Square (RMS) Normalization to the input tensor. + ops.def( + "rms_norm(Tensor! result, Tensor input, Tensor weight, float epsilon) -> " + "()"); + + // In-place fused Add and RMS Normalization. + ops.def( + "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, " + "float epsilon) -> ()"); + + // Layernorm-quant + // Apply Root Mean Square (RMS) Normalization to the input tensor. + ops.def( + "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, " + "Tensor scale, float epsilon) -> " + "()"); + + // In-place fused Add and RMS Normalization. + ops.def( + "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, " + "Tensor! residual, Tensor weight, " + "Tensor scale, float epsilon) -> ()"); + + // Fused Layernorm + Quant kernels + ops.def( + "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, " + "Tensor weight, Tensor! scale, float epsilon, " + "Tensor? scale_ub, Tensor!? residual) -> ()"); + + // Fused Layernorm + Block quant kernels + ops.def( + "rms_norm_per_block_quant(Tensor! result, Tensor input, " + "Tensor weight, Tensor! scale, float epsilon, " + "Tensor? scale_ub, Tensor!? residual, int group_size, " + "bool is_scale_transposed) -> ()"); + + // Rotary embedding + // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. + ops.def( + "rotary_embedding(Tensor positions, Tensor! query," + " Tensor!? key, int head_size," + " Tensor cos_sin_cache, bool is_neox, int " + "rope_dim_offset=0, bool inverse=False) -> ()"); + + // Function for fused QK Norm and RoPE + ops.def( + "fused_qk_norm_rope(Tensor! qkv, int num_heads_q, " + "int num_heads_k, int num_heads_v, int head_dim, float eps, " + "Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, " + "bool is_neox, Tensor position_ids, " + "int forced_token_heads_per_warp=-1) -> ()"); + // Activation ops // Activation function used in SwiGLU. ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()"); + ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); // SwiGLU activation with input clamping. @@ -416,6 +469,24 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { // files (allspark_repack.cu and allspark_qgemm_w8a16.cu) #endif + // Layernorm kernels (shared CUDA/ROCm) + ops.impl("rms_norm", TORCH_BOX(&rms_norm)); + ops.impl("fused_add_rms_norm", TORCH_BOX(&fused_add_rms_norm)); + + // Layernorm-quant kernels (shared CUDA/ROCm) + ops.impl("rms_norm_static_fp8_quant", TORCH_BOX(&rms_norm_static_fp8_quant)); + ops.impl("fused_add_rms_norm_static_fp8_quant", + TORCH_BOX(&fused_add_rms_norm_static_fp8_quant)); + + // Fused layernorm + dynamic per-token quant kernels (shared CUDA/ROCm) + ops.impl("rms_norm_dynamic_per_token_quant", + TORCH_BOX(&rms_norm_dynamic_per_token_quant)); + ops.impl("rms_norm_per_block_quant", TORCH_BOX(&rms_norm_per_block_quant)); + + // Positional encoding kernels (shared CUDA/ROCm) + ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding)); + ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope)); + // Activation kernels (shared CUDA/ROCm) ops.impl("silu_and_mul", TORCH_BOX(&silu_and_mul)); ops.impl("mul_and_silu", TORCH_BOX(&mul_and_silu)); diff --git a/csrc/ops.h b/csrc/ops.h index d3db38aebfbd..3e2faac5b7a4 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -61,19 +61,15 @@ void merge_attn_states( const std::optional prefill_tokens_with_context, const std::optional& output_scale = std::nullopt); +// rms_norm and fused_add_rms_norm declarations also exist in +// csrc/libtorch_stable/ops.h (torch::stable ABI for CUDA). They remain here +// because the CPU build still uses these torch::Tensor declarations. void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double epsilon); void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); -void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q, - int64_t num_heads_k, int64_t num_heads_v, - int64_t head_dim, double eps, torch::Tensor& q_weight, - torch::Tensor& k_weight, torch::Tensor& cos_sin_cache, - bool is_neox, torch::Tensor& position_ids, - int64_t forced_token_heads_per_warp); - void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( torch::Tensor& q, torch::Tensor const& kv, torch::Tensor& k_cache, torch::Tensor const& slot_mapping, torch::Tensor const& position_ids, @@ -99,37 +95,15 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths, torch::Tensor& output, torch::Tensor& workspace, int64_t k, int64_t max_seq_len); -void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, - torch::Tensor& weight, torch::Tensor& scale, - double epsilon); - -void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& residual, - torch::Tensor& weight, - torch::Tensor& scale, double epsilon); - -void rms_norm_dynamic_per_token_quant(torch::Tensor& out, - torch::Tensor const& input, - torch::Tensor const& weight, - torch::Tensor& scales, - double const epsilon, - std::optional scale_ub, - std::optional residual); - -void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor const& weight, - torch::Tensor& scales, double const epsilon, - std::optional scale_ub, - std::optional residual, - int64_t group_size, bool is_scale_transposed); - void silu_and_mul_per_block_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales, int64_t group_size, std::optional scale_ub, bool is_scale_transposed); +// rotary_embedding also exist in csrc/libtorch_stable/ops.h (torch::stable +// ABI for CUDA). It remains here because the CPU build still uses these +// torch::Tensor declarations. void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox, diff --git a/csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu b/csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu index 993ee641b5d6..d5c76232599e 100644 --- a/csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu +++ b/csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu @@ -5,8 +5,7 @@ #include #include "../../dispatch_utils.h" -#include "quant_conversions.cuh" -#include "../w8a8/fp8/common.cuh" +#include "libtorch_stable/quantization/fused_kernels/quant_conversions.cuh" namespace vllm { diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b88e2bb4e68f..78c2875644c8 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -94,28 +94,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("silu_and_mul_per_block_quant", torch::kCUDA, &silu_and_mul_per_block_quant); - // Layernorm - // Apply Root Mean Square (RMS) Normalization to the input tensor. - ops.def( - "rms_norm(Tensor! result, Tensor input, Tensor weight, float epsilon) -> " - "()"); - ops.impl("rms_norm", torch::kCUDA, &rms_norm); - - // In-place fused Add and RMS Normalization. - ops.def( - "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, " - "float epsilon) -> ()"); - ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); - - // Function for fused QK Norm and RoPE - ops.def( - "fused_qk_norm_rope(Tensor! qkv, int num_heads_q, " - "int num_heads_k, int num_heads_v, int head_dim, float eps, " - "Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, " - "bool is_neox, Tensor position_ids, " - "int forced_token_heads_per_warp=-1) -> ()"); - ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope); - // Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and // GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one // kernel launch. @@ -152,48 +130,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor workspace, int k, int max_seq_len) -> ()"); ops.impl("persistent_topk", torch::kCUDA, &persistent_topk); - // Layernorm-quant - // Apply Root Mean Square (RMS) Normalization to the input tensor. - ops.def( - "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, " - "Tensor scale, float epsilon) -> " - "()"); - ops.impl("rms_norm_static_fp8_quant", torch::kCUDA, - &rms_norm_static_fp8_quant); - - // In-place fused Add and RMS Normalization. - ops.def( - "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, " - "Tensor! residual, Tensor weight, " - "Tensor scale, float epsilon) -> ()"); - ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA, - &fused_add_rms_norm_static_fp8_quant); - - // Fused Layernorm + Quant kernels - ops.def( - "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, " - "Tensor weight, Tensor! scale, float epsilon, " - "Tensor? scale_ub, Tensor!? residual) -> ()"); - ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA, - &rms_norm_dynamic_per_token_quant); - - // Fused Layernorm + Block quant kernels - ops.def( - "rms_norm_per_block_quant(Tensor! result, Tensor input, " - "Tensor weight, Tensor! scale, float epsilon, " - "Tensor? scale_ub, Tensor!? residual, int group_size, " - "bool is_scale_transposed) -> ()"); - ops.impl("rms_norm_per_block_quant", torch::kCUDA, &rms_norm_per_block_quant); - - // Rotary embedding - // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. - ops.def( - "rotary_embedding(Tensor positions, Tensor! query," - " Tensor!? key, int head_size," - " Tensor cos_sin_cache, bool is_neox, int " - "rope_dim_offset=0, bool inverse=False) -> ()"); - ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); - // Quantization ops #ifndef USE_ROCM diff --git a/csrc/type_convert.cuh b/csrc/type_convert.cuh index 2678f69e19b6..9d939bb828fc 100644 --- a/csrc/type_convert.cuh +++ b/csrc/type_convert.cuh @@ -1,8 +1,10 @@ #pragma once -#include +#include +#include #ifndef USE_ROCM + #include #include #include #else @@ -191,4 +193,4 @@ struct alignas(16) _f16Vec { return result; } }; -} // namespace vllm \ No newline at end of file +} // namespace vllm