diff --git a/CMakeLists.txt b/CMakeLists.txt index dd6ebce34be0..40f05447d05c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -288,18 +288,13 @@ set(VLLM_EXT_SRC "csrc/attention/merge_attn_states.cu" "csrc/attention/vertical_slash_index.cu" "csrc/pos_encoding_kernels.cu" - "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" "csrc/fused_qknorm_rope_kernel.cu" "csrc/layernorm_quant_kernels.cu" "csrc/sampler.cu" "csrc/topk.cu" "csrc/cuda_view.cu" - "csrc/quantization/gptq/q_gemm.cu" - "csrc/quantization/w8a8/int8/scaled_quant.cu" - "csrc/quantization/w8a8/fp8/common.cu" "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" - "csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/activation_kernels.cu" "csrc/cuda_utils_kernels.cu" "csrc/custom_all_reduce.cu" @@ -339,7 +334,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_MakeAvailable(cutlass) list(APPEND VLLM_EXT_SRC - "csrc/quantization/awq/gemm_kernels.cu" "csrc/cutlass_extensions/common.cpp" "csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu") @@ -472,46 +466,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") " in CUDA target architectures") endif() - # Only build AllSpark kernels if we are building for at least some compatible archs. - cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}") - if (ALLSPARK_ARCHS) - set(ALLSPARK_SRCS - "csrc/quantization/gptq_allspark/allspark_repack.cu" - "csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu") - set_gencode_flags_for_srcs( - SRCS "${ALLSPARK_SRCS}" - CUDA_ARCHS "${ALLSPARK_ARCHS}") - list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}") - message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}") - else() - message(STATUS "Not building AllSpark kernels as no compatible archs found" - " in CUDA target architectures") - endif() - - # CUTLASS MLA Archs and flags - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(MLA_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS) - set(SRCS - "csrc/attention/mla/sm100_cutlass_mla_kernel.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${MLA_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1") - # Add MLA-specific include directories only to MLA source files - set_source_files_properties(${SRCS} - PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common") - message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}") - else() - message(STATUS "Not building CUTLASS MLA as no compatible archs were found.") - # clear MLA_ARCHS - set(MLA_ARCHS) - endif() - # Expert-specialization MXFP8 blockscaled grouped kernels (SM100+). if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") @@ -539,24 +493,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() - # DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later) - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_FUSED_A_GEMM_ARCHS) - set(DSV3_FUSED_A_GEMM_SRC "csrc/dsv3_fused_a_gemm.cu") - set_gencode_flags_for_srcs( - SRCS "${DSV3_FUSED_A_GEMM_SRC}" - CUDA_ARCHS "${DSV3_FUSED_A_GEMM_ARCHS}") - list(APPEND VLLM_EXT_SRC ${DSV3_FUSED_A_GEMM_SRC}) - message(STATUS "Building dsv3_fused_a_gemm for archs: ${DSV3_FUSED_A_GEMM_ARCHS}") - else() - message(STATUS "Not building dsv3_fused_a_gemm as no compatible archs found " - "in CUDA target architectures.") - endif() - # # Machete kernels @@ -628,16 +564,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() - # Hadacore kernels - cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") - if(HADACORE_ARCHS) - set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${HADACORE_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - message(STATUS "Building hadacore") - endif() # if CUDA endif endif() @@ -669,31 +595,66 @@ define_extension_target( # Setting this variable sidesteps the issue by calling the driver directly. target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) -# add OR VLLM_GPU_LANG STREQUAL "HIP" here once -# https://github.com/vllm-project/vllm/issues/35163 is resolved -if(VLLM_GPU_LANG STREQUAL "CUDA") +if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") # # _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY) # set(VLLM_STABLE_EXT_SRC "csrc/libtorch_stable/torch_bindings.cpp" - "csrc/cutlass_extensions/common.cpp" - "csrc/cuda_utils_kernels.cu" - "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu" - "csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu" - "csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu") + "csrc/libtorch_stable/activation_kernels.cu" + "csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu" + "csrc/libtorch_stable/quantization/w8a8/fp8/common.cu" + "csrc/libtorch_stable/quantization/gptq/q_gemm.cu" + "csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_STABLE_EXT_SRC + "csrc/cuda_utils_kernels.cu" + "csrc/cutlass_extensions/common.cpp" + "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu" + "csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu" + "csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/libtorch_stable/permute_cols.cu" "csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu" - "csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu") - endif() + "csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu" + "csrc/libtorch_stable/quantization/awq/gemm_kernels.cu") - if(VLLM_GPU_LANG STREQUAL "CUDA") set_gencode_flags_for_srcs( SRCS "${VLLM_STABLE_EXT_SRC}" CUDA_ARCHS "${CUDA_ARCHS}") + + # DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later) + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_FUSED_A_GEMM_ARCHS) + set(SRCS "csrc/libtorch_stable/dsv3_fused_a_gemm.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${DSV3_FUSED_A_GEMM_ARCHS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + message(STATUS "Building dsv3_fused_a_gemm for archs: ${DSV3_FUSED_A_GEMM_ARCHS}") + else() + message(STATUS "Not building dsv3_fused_a_gemm as no compatible archs found " + "in CUDA target architectures.") + endif() + + # Only build AllSpark kernels if we are building for at least some compatible archs. + cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}") + if (ALLSPARK_ARCHS) + set(SRCS + "csrc/libtorch_stable/quantization/gptq_allspark/allspark_repack.cu" + "csrc/libtorch_stable/quantization/gptq_allspark/allspark_qgemm_w8a16.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${ALLSPARK_ARCHS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}") + else() + message(STATUS "Not building AllSpark kernels as no compatible archs found" + " in CUDA target architectures") endif() # @@ -989,6 +950,44 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() + # CUTLASS MLA Archs and flags + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(MLA_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS) + set(SRCS + "csrc/libtorch_stable/attention/mla/sm100_cutlass_mla_kernel.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${MLA_ARCHS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1") + # Add MLA-specific include directories only to MLA source files + set_source_files_properties(${SRCS} + PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common") + message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}") + else() + message(STATUS "Not building CUTLASS MLA as no compatible archs were found.") + # clear MLA_ARCHS + set(MLA_ARCHS) + endif() + + # Hadacore kernels + cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") + if(HADACORE_ARCHS) + set(SRCS "csrc/libtorch_stable/quantization/hadamard/hadacore/hadamard_transform_cuda.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${HADACORE_ARCHS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + message(STATUS "Building hadacore") + endif() + + # if CUDA endif + endif() + message(STATUS "Enabling C_stable extension.") define_extension_target( _C_stable_libtorch @@ -1008,13 +1007,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") target_compile_definitions(_C_stable_libtorch PRIVATE TORCH_TARGET_VERSION=0x020A000000000000ULL) - # Needed to use cuda APIs from C-shim - target_compile_definitions(_C_stable_libtorch PRIVATE - USE_CUDA) + # Needed to use cuda/hip APIs from C-shim + if(VLLM_GPU_LANG STREQUAL "CUDA") + target_compile_definitions(_C_stable_libtorch PRIVATE USE_CUDA) + # Needed by CUTLASS kernels + target_compile_definitions(_C_stable_libtorch PRIVATE + CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) + elseif(VLLM_GPU_LANG STREQUAL "HIP") + target_compile_definitions(_C_stable_libtorch PRIVATE USE_ROCM) + endif() - # Needed by CUTLASS kernels - target_compile_definitions(_C_stable_libtorch PRIVATE - CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) + # On ROCm, _C_stable_libtorch calls raw HIP APIs (e.g. hipGetDevice in + # get_device_prop()) which must resolve to the same libamdhip64.so that + # PyTorch uses. When PyTorch bundles its own copy (pip/conda wheels), + # the raw HIP calls would otherwise resolve to the system ROCm copy, + # initializing a second HIP runtime that corrupts device state (wrong + # device on DeviceGuard, core dumps on multi-GPU tests). + # + # If PyTorch doesn't bundle libamdhip64 (built from source against system + # ROCm), there is only one copy in the process and no action is needed — + # the HIP compiler already links the system libamdhip64 automatically. + if(VLLM_GPU_LANG STREQUAL "HIP") + find_library(_STABLE_TORCH_AMDHIP64 amdhip64 + PATHS "${TORCH_INSTALL_PREFIX}/lib" NO_DEFAULT_PATH) + if(_STABLE_TORCH_AMDHIP64) + message(STATUS "Found PyTorch-bundled libamdhip64 at ${_STABLE_TORCH_AMDHIP64}") + target_link_libraries(_C_stable_libtorch PRIVATE ${_STABLE_TORCH_AMDHIP64}) + endif() + endif() endif() # diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index 68a8750f583b..b6f39ed795f3 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -1,7 +1,13 @@ #pragma once -// For TORCH_CHECK -#include +#include +#include +#include +#include +#include + +// For STD_TORCH_CHECK +#include namespace vllm { @@ -45,7 +51,7 @@ class ScalarType { // IEEE 754 compliant floating point type static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) { - TORCH_CHECK(mantissa > 0 && exponent > 0); + STD_TORCH_CHECK(mantissa > 0 && exponent > 0); return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); } @@ -53,11 +59,12 @@ class ScalarType { static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) { - TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); - TORCH_CHECK(mantissa > 0 && exponent > 0); - TORCH_CHECK(nan_repr != NAN_IEEE_754, - "use `float_IEEE754` constructor for floating point types that " - "follow IEEE 754 conventions"); + STD_TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); + STD_TORCH_CHECK(mantissa > 0 && exponent > 0); + STD_TORCH_CHECK( + nan_repr != NAN_IEEE_754, + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions"); return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr); } @@ -176,8 +183,8 @@ class ScalarType { private: double _floating_point_max() const { - TORCH_CHECK(mantissa <= 52 && exponent <= 11, - "Cannot represent max/min as a double for type ", str()); + STD_TORCH_CHECK(mantissa <= 52 && exponent <= 11, + "Cannot represent max/min as a double for type ", str()); uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { @@ -186,8 +193,8 @@ class ScalarType { uint64_t max_exponent = (uint64_t(1) << exponent) - 2; if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { - TORCH_CHECK(exponent < 11, - "Cannot represent max/min as a double for type ", str()); + STD_TORCH_CHECK(exponent < 11, + "Cannot represent max/min as a double for type ", str()); max_exponent += 1; } @@ -216,16 +223,17 @@ class ScalarType { if (is_floating_point()) { return {_floating_point_max()}; } else { - TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), - "Cannot represent max as a int64_t"); + STD_TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), + "Cannot represent max as a int64_t"); return {(int64_t(1) << mantissa) - 1}; } } constexpr std::variant _raw_min() const { if (is_floating_point()) { - TORCH_CHECK(is_signed(), - "We currently assume all floating point types are signed"); + STD_TORCH_CHECK( + is_signed(), + "We currently assume all floating point types are signed"); constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); double max = _floating_point_max(); @@ -233,8 +241,8 @@ class ScalarType { uint64_t min_raw = max_raw | sign_bit_double; return {*reinterpret_cast(&min_raw)}; } else { - TORCH_CHECK(!is_signed() || size_bits() <= 64, - "Cannot represent min as a int64_t"); + STD_TORCH_CHECK(!is_signed() || size_bits() <= 64, + "Cannot represent min as a int64_t"); if (is_signed()) { // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 // then perform an arithmetic shift right to set all the bits above diff --git a/csrc/cuda_vec_utils.cuh b/csrc/cuda_vec_utils.cuh index 91e181c5856d..efbb09994d25 100644 --- a/csrc/cuda_vec_utils.cuh +++ b/csrc/cuda_vec_utils.cuh @@ -9,6 +9,8 @@ #ifdef USE_ROCM #include + #include + #include #else #include #include diff --git a/csrc/activation_kernels.cu b/csrc/libtorch_stable/activation_kernels.cu similarity index 78% rename from csrc/activation_kernels.cu rename to csrc/libtorch_stable/activation_kernels.cu index 758a77795553..ec5fc2a8cc3e 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/libtorch_stable/activation_kernels.cu @@ -1,12 +1,12 @@ -#include -#include -#include +#include +#include #include -#include "cuda_compat.h" -#include "cuda_vec_utils.cuh" +#include "../cuda_compat.h" +#include "../cuda_vec_utils.cuh" #include "dispatch_utils.h" +#include "torch_utils.h" namespace vllm { @@ -160,57 +160,61 @@ packed_gelu_tanh_kernel(const packed_t& val) { return; \ } \ dim3 grid(num_tokens); \ - int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ + int cc_major = get_device_prop()->major; \ int support_vec = \ (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ - int vec_size = support_vec / at::elementSize(dtype); \ + int vec_size = support_vec / input.element_size(); \ const bool use_vec = (d % vec_size == 0); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + const torch::stable::accelerator::DeviceGuard device_guard( \ + input.get_device_index()); \ + const cudaStream_t stream = get_current_cuda_stream(); \ if (use_vec) { \ dim3 block(std::min(d / vec_size, 1024)); \ if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ vllm::act_and_mul_kernel< \ scalar_t, typename vllm::PackedTypeConverter::Type, \ KERNEL, \ PACKED_KERNEL::Type>, \ - ACT_FIRST, true, true><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ + ACT_FIRST, true, true> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d); \ }); \ } else { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ vllm::act_and_mul_kernel< \ scalar_t, typename vllm::PackedTypeConverter::Type, \ KERNEL, \ PACKED_KERNEL::Type>, \ - ACT_FIRST, true, false><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ + ACT_FIRST, true, false> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d); \ }); \ } \ } else { \ dim3 block(std::min(d, 1024)); \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ vllm::act_and_mul_kernel< \ scalar_t, typename vllm::PackedTypeConverter::Type, \ KERNEL, \ PACKED_KERNEL::Type>, \ - ACT_FIRST, false><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ + ACT_FIRST, false> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d); \ }); \ } -void silu_and_mul(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void silu_and_mul(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel, true); } -void mul_and_silu(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void mul_and_silu(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., 2 * d] { // The difference between mul_and_silu and silu_and_mul is that mul_and_silu // applies the silu to the latter half of the input. @@ -218,15 +222,15 @@ void mul_and_silu(torch::Tensor& out, // [..., d] false); } -void gelu_and_mul(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_and_mul(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, vllm::packed_gelu_kernel, true); } -void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_tanh_and_mul(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, vllm::packed_gelu_tanh_kernel, true); @@ -377,19 +381,20 @@ __global__ void swigluoai_and_mul_kernel( return; \ } \ dim3 grid(num_tokens); \ - int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ + int cc_major = get_device_prop()->major; \ int support_vec = \ (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ - int vec_size = support_vec / at::elementSize(dtype); \ + int vec_size = support_vec / input.element_size(); \ const bool use_vec = (d % vec_size == 0); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + const torch::stable::accelerator::DeviceGuard device_guard( \ + input.get_device_index()); \ + const cudaStream_t stream = get_current_cuda_stream(); \ if (use_vec) { \ dim3 block(std::min(d / vec_size, 1024)); \ if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ - VLLM_DISPATCH_FLOATING_TYPES( \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES( \ dtype, "act_and_mul_kernel_with_param", [&] { \ vllm::act_and_mul_kernel_with_param< \ scalar_t, typename vllm::PackedTypeConverter::Type, \ @@ -397,11 +402,11 @@ __global__ void swigluoai_and_mul_kernel( PACKED_KERNEL< \ typename vllm::PackedTypeConverter::Type>, \ true, true><<>>( \ - out.data_ptr(), input.data_ptr(), d, \ - PARAM); \ + out.mutable_data_ptr(), \ + input.const_data_ptr(), d, PARAM); \ }); \ } else { \ - VLLM_DISPATCH_FLOATING_TYPES( \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES( \ dtype, "act_and_mul_kernel_with_param", [&] { \ vllm::act_and_mul_kernel_with_param< \ scalar_t, typename vllm::PackedTypeConverter::Type, \ @@ -409,45 +414,49 @@ __global__ void swigluoai_and_mul_kernel( PACKED_KERNEL< \ typename vllm::PackedTypeConverter::Type>, \ true, false><<>>( \ - out.data_ptr(), input.data_ptr(), d, \ - PARAM); \ + out.mutable_data_ptr(), \ + input.const_data_ptr(), d, PARAM); \ }); \ } \ } else { \ dim3 block(std::min(d, 1024)); \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \ - vllm::act_and_mul_kernel_with_param< \ - scalar_t, typename vllm::PackedTypeConverter::Type, \ - KERNEL, \ - PACKED_KERNEL::Type>, \ - false><<>>( \ - out.data_ptr(), input.data_ptr(), d, PARAM); \ - }); \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES( \ + dtype, "act_and_mul_kernel_with_param", [&] { \ + vllm::act_and_mul_kernel_with_param< \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ + KERNEL, \ + PACKED_KERNEL< \ + typename vllm::PackedTypeConverter::Type>, \ + false><<>>( \ + out.mutable_data_ptr(), \ + input.const_data_ptr(), d, PARAM); \ + }); \ } -#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \ - vllm::swigluoai_and_mul_kernel> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), d, ALPHA, \ - LIMIT); \ +#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const torch::stable::accelerator::DeviceGuard device_guard( \ + input.get_device_index()); \ + const cudaStream_t stream = get_current_cuda_stream(); \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \ + vllm::swigluoai_and_mul_kernel> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d, \ + ALPHA, LIMIT); \ }); -void fatrelu_and_mul(torch::Tensor& out, // [..., d], - torch::Tensor& input, // [..., 2 * d] +void fatrelu_and_mul(torch::stable::Tensor& out, // [..., d], + torch::stable::Tensor& input, // [..., 2 * d] double threshold) { LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM( vllm::fatrelu_kernel, vllm::packed_fatrelu_kernel, threshold); } -void swigluoai_and_mul(torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., 2 * d] +void swigluoai_and_mul(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input, // [..., 2 * d] double alpha, double limit) { LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit); } @@ -502,45 +511,46 @@ __global__ void activation_kernel( } // namespace vllm // Launch element-wise activation kernel. -#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ - auto dtype = input.scalar_type(); \ - int d = input.size(-1); \ - int64_t num_tokens = input.numel() / input.size(-1); \ - if (num_tokens == 0) { \ - return; \ - } \ - dim3 grid(num_tokens); \ - int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ - int support_vec = \ - (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ - ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ - : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ - int vec_size = support_vec / at::elementSize(dtype); \ - const bool use_vec = (d % vec_size == 0); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - if (use_vec) { \ - dim3 block(std::min(d / vec_size, 1024)); \ - if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ - vllm::activation_kernel, true, true> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), d); \ - }); \ - } else { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ - vllm::activation_kernel, true, false> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), d); \ - }); \ - } \ - } else { \ - dim3 block(std::min(d, 1024)); \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ - vllm::activation_kernel, false> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), d); \ - }); \ +#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ + auto dtype = input.scalar_type(); \ + int d = input.size(-1); \ + int64_t num_tokens = input.numel() / input.size(-1); \ + if (num_tokens == 0) { \ + return; \ + } \ + dim3 grid(num_tokens); \ + int cc_major = get_device_prop()->major; \ + int support_vec = \ + (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ + ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ + : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ + int vec_size = support_vec / input.element_size(); \ + const bool use_vec = (d % vec_size == 0); \ + const torch::stable::accelerator::DeviceGuard device_guard( \ + input.get_device_index()); \ + const cudaStream_t stream = get_current_cuda_stream(); \ + if (use_vec) { \ + dim3 block(std::min(d / vec_size, 1024)); \ + if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ + vllm::activation_kernel, true, true> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d); \ + }); \ + } else { \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ + vllm::activation_kernel, true, false> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d); \ + }); \ + } \ + } else { \ + dim3 block(std::min(d, 1024)); \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ + vllm::activation_kernel, false> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d); \ + }); \ } namespace vllm { @@ -568,20 +578,20 @@ __device__ __forceinline__ T gelu_quick_kernel(const T& x) { } // namespace vllm -void gelu_new(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_new(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); } -void gelu_fast(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_fast(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); } -void gelu_quick(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_quick(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel); } diff --git a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp b/csrc/libtorch_stable/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp similarity index 100% rename from csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp rename to csrc/libtorch_stable/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp diff --git a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp b/csrc/libtorch_stable/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp similarity index 100% rename from csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp rename to csrc/libtorch_stable/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp diff --git a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/csrc/libtorch_stable/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp similarity index 100% rename from csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp rename to csrc/libtorch_stable/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp diff --git a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp b/csrc/libtorch_stable/attention/mla/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp similarity index 100% rename from csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp rename to csrc/libtorch_stable/attention/mla/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/libtorch_stable/attention/mla/sm100_cutlass_mla_kernel.cu similarity index 77% rename from csrc/attention/mla/sm100_cutlass_mla_kernel.cu rename to csrc/libtorch_stable/attention/mla/sm100_cutlass_mla_kernel.cu index d1874515cc8f..55d75383476e 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/libtorch_stable/attention/mla/sm100_cutlass_mla_kernel.cu @@ -18,13 +18,12 @@ limitations under the License. * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 * by Alcanderian JieXin Liang */ -#include "core/registration.h" +#include "libtorch_stable/torch_utils.h" + +#include -#include -#include #include #include -#include #include #include @@ -35,27 +34,27 @@ limitations under the License. // clang-format off #if !defined(CUDA_VERSION) || CUDA_VERSION < 12040 void sm100_cutlass_mla_decode( - torch::Tensor const& out, - torch::Tensor const& lse, - torch::Tensor const& q_nope, - torch::Tensor const& q_pe, - torch::Tensor const& kv_c_and_k_pe_cache, - torch::Tensor const& seq_lens, - torch::Tensor const& page_table, - torch::Tensor const& workspace, + torch::stable::Tensor const& out, + torch::stable::Tensor const& lse, + torch::stable::Tensor const& q_nope, + torch::stable::Tensor const& q_pe, + torch::stable::Tensor const& kv_c_and_k_pe_cache, + torch::stable::Tensor const& seq_lens, + torch::stable::Tensor const& page_table, + torch::stable::Tensor const& workspace, double sm_scale, int64_t num_kv_splits) { - TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode"); + STD_TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode"); } int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) { - TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size"); + STD_TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size"); } #else #define CUTLASS_CHECK(status) \ { \ cutlass::Status error = status; \ - TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ + STD_TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ } using namespace cute; @@ -100,23 +99,23 @@ struct MlaSm100 { template typename T::Fmha::Arguments args_from_options( - at::Tensor const& out, - at::Tensor const& lse, - at::Tensor const& q_nope, - at::Tensor const& q_pe, - at::Tensor const& kv_c_and_k_pe_cache, - at::Tensor const& seq_lens, - at::Tensor const& page_table, + torch::stable::Tensor const& out, + torch::stable::Tensor const& lse, + torch::stable::Tensor const& q_nope, + torch::stable::Tensor const& q_pe, + torch::stable::Tensor const& kv_c_and_k_pe_cache, + torch::stable::Tensor const& seq_lens, + torch::stable::Tensor const& page_table, double sm_scale, int64_t num_kv_splits) { cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = q_nope.device().index(); + hw_info.device_id = q_nope.get_device_index(); hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - int batches = q_nope.sizes()[0]; - int page_count_per_seq = page_table.sizes()[1]; - int page_count_total = kv_c_and_k_pe_cache.sizes()[0]; - int page_size = kv_c_and_k_pe_cache.sizes()[1]; + int batches = q_nope.size(0); + int page_count_per_seq = page_table.size(1); + int page_count_total = kv_c_and_k_pe_cache.size(0); + int page_size = kv_c_and_k_pe_cache.size(1); int max_seq_len = page_size * page_count_per_seq; using TileShapeH = typename T::TileShapeH; using TileShapeD = typename T::TileShapeD; @@ -186,14 +185,14 @@ typename T::Fmha::Arguments args_from_options( template void runMla( - at::Tensor const& out, - at::Tensor const& lse, - at::Tensor const& q_nope, - at::Tensor const& q_pe, - at::Tensor const& kv_c_and_k_pe_cache, - at::Tensor const& seq_lens, - at::Tensor const& page_table, - at::Tensor const& workspace, + torch::stable::Tensor const& out, + torch::stable::Tensor const& lse, + torch::stable::Tensor const& q_nope, + torch::stable::Tensor const& q_pe, + torch::stable::Tensor const& kv_c_and_k_pe_cache, + torch::stable::Tensor const& seq_lens, + torch::stable::Tensor const& page_table, + torch::stable::Tensor const& workspace, double sm_scale, int64_t num_kv_splits, cudaStream_t stream) { @@ -220,37 +219,37 @@ void runMla( }() void sm100_cutlass_mla_decode( - torch::Tensor const& out, - torch::Tensor const& lse, - torch::Tensor const& q_nope, - torch::Tensor const& q_pe, - torch::Tensor const& kv_c_and_k_pe_cache, - torch::Tensor const& seq_lens, - torch::Tensor const& page_table, - torch::Tensor const& workspace, + torch::stable::Tensor const& out, + torch::stable::Tensor const& lse, + torch::stable::Tensor const& q_nope, + torch::stable::Tensor const& q_pe, + torch::stable::Tensor const& kv_c_and_k_pe_cache, + torch::stable::Tensor const& seq_lens, + torch::stable::Tensor const& page_table, + torch::stable::Tensor const& workspace, double sm_scale, int64_t num_kv_splits) { - auto in_dtype = q_nope.dtype(); - at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()}; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device()); - const int page_size = kv_c_and_k_pe_cache.sizes()[1]; - + auto in_dtype = q_nope.scalar_type(); + torch::stable::accelerator::DeviceGuard device_guard(q_nope.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(q_nope.get_device_index()); + const int page_size = kv_c_and_k_pe_cache.size(1); + // NOTE(alcanderian): IsPersistent has bug with manual split_kv. // Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8) // Maybe per batch split kv will fix this. DISPATCH_BOOL(page_size == 128, IsPaged128, [&] { DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] { - if (in_dtype == at::ScalarType::Half) { + if (in_dtype == torch::headeronly::ScalarType::Half) { runMla>( out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); - } else if (in_dtype == at::ScalarType::BFloat16) { + } else if (in_dtype == torch::headeronly::ScalarType::BFloat16) { runMla>( out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); - } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { + } else if (in_dtype == torch::headeronly::ScalarType::Float8_e4m3fn) { runMla>( out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else { - TORCH_CHECK(false, "Unsupported input data type of MLA"); + STD_TORCH_CHECK(false, "Unsupported input data type of MLA"); } return true; }); @@ -280,12 +279,12 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba #endif -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode); +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { + m.impl("sm100_cutlass_mla_decode", TORCH_BOX(&sm100_cutlass_mla_decode)); } -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) { - m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size); +STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, m) { + m.impl("sm100_cutlass_mla_get_workspace_size", TORCH_BOX(&sm100_cutlass_mla_get_workspace_size)); } // clang-format on diff --git a/csrc/dsv3_fused_a_gemm.cu b/csrc/libtorch_stable/dsv3_fused_a_gemm.cu similarity index 93% rename from csrc/dsv3_fused_a_gemm.cu rename to csrc/libtorch_stable/dsv3_fused_a_gemm.cu index 65dff9c84bab..bdf749ddfcf9 100644 --- a/csrc/dsv3_fused_a_gemm.cu +++ b/csrc/libtorch_stable/dsv3_fused_a_gemm.cu @@ -20,13 +20,15 @@ * limitations under the License. */ -#include -#include -#include -#include -#include +#include +#include +#include #include "core/registration.h" +#include "libtorch_stable/torch_utils.h" + +#include +#include #include #include @@ -34,7 +36,7 @@ namespace { inline int getSMVersion() { - auto* props = at::cuda::getCurrentDeviceProperties(); + auto* props = get_device_prop(); return props->major * 10 + props->minor; } @@ -700,37 +702,40 @@ template void invokeFusedAGemm<__nv_bfloat16, 7168, 2112, 16>( __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, int num_tokens, cudaStream_t); -void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, - torch::Tensor const& mat_b) { - TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2 && output.dim() == 2); +void dsv3_fused_a_gemm(torch::stable::Tensor& output, + torch::stable::Tensor const& mat_a, + torch::stable::Tensor const& mat_b) { + STD_TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2 && output.dim() == 2); int const num_tokens = mat_a.size(0); int const hd_in = mat_a.size(1); int const hd_out = mat_b.size(1); constexpr int kHdIn = 7168; constexpr int kHdOut = 2112; - TORCH_CHECK(num_tokens >= 1 && num_tokens <= 16, - "required 1 <= mat_a.shape[0] <= 16") - TORCH_CHECK(hd_in == kHdIn, "required mat_a.shape[1] == 7168") - TORCH_CHECK(hd_out == kHdOut, "required mat_b.shape[1] == 2112") - TORCH_CHECK(output.size(0) == num_tokens, - "required output.shape[0] == mat_a.shape[0]") - TORCH_CHECK(output.size(1) == hd_out, - "required output.shape[1] == mat_b.shape[1]") - - TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); - TORCH_CHECK(output.stride(1) == 1, "output must be a row major tensor"); - TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor"); - - TORCH_CHECK(mat_a.scalar_type() == torch::kBFloat16 && - mat_b.scalar_type() == torch::kBFloat16, - "Only BFloat16 input dtype is supported") - TORCH_CHECK(output.scalar_type() == torch::kBFloat16, - "Only BFloat16 output dtype is supported") - - TORCH_CHECK(getSMVersion() >= 90, "required CUDA ARCH >= SM_90"); - - auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device()); + STD_TORCH_CHECK(num_tokens >= 1 && num_tokens <= 16, + "required 1 <= mat_a.shape[0] <= 16"); + STD_TORCH_CHECK(hd_in == kHdIn, "required mat_a.shape[1] == 7168"); + STD_TORCH_CHECK(hd_out == kHdOut, "required mat_b.shape[1] == 2112"); + STD_TORCH_CHECK(output.size(0) == num_tokens, + "required output.shape[0] == mat_a.shape[0]"); + STD_TORCH_CHECK(output.size(1) == hd_out, + "required output.shape[1] == mat_b.shape[1]"); + + STD_TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); + STD_TORCH_CHECK(output.stride(1) == 1, "output must be a row major tensor"); + STD_TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor"); + + STD_TORCH_CHECK( + mat_a.scalar_type() == torch::headeronly::ScalarType::BFloat16 && + mat_b.scalar_type() == torch::headeronly::ScalarType::BFloat16, + "Only BFloat16 input dtype is supported"); + STD_TORCH_CHECK( + output.scalar_type() == torch::headeronly::ScalarType::BFloat16, + "Only BFloat16 output dtype is supported"); + + STD_TORCH_CHECK(getSMVersion() >= 90, "required CUDA ARCH >= SM_90"); + + auto stream = get_current_cuda_stream(mat_a.get_device_index()); if (num_tokens <= 8) { invokeFusedAGemm<__nv_bfloat16, kHdIn, kHdOut, 8>( reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()), @@ -746,6 +751,6 @@ void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, } } -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("dsv3_fused_a_gemm", &dsv3_fused_a_gemm); +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { + m.impl("dsv3_fused_a_gemm", TORCH_BOX(&dsv3_fused_a_gemm)); } diff --git a/csrc/libtorch_stable/ops.h b/csrc/libtorch_stable/ops.h index 8153102c598a..f1522c39a402 100644 --- a/csrc/libtorch_stable/ops.h +++ b/csrc/libtorch_stable/ops.h @@ -134,4 +134,107 @@ void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out, torch::stable::Tensor& input, torch::stable::Tensor& input_global_scale); +// AWQ ops +torch::stable::Tensor awq_gemm(torch::stable::Tensor _in_feats, + torch::stable::Tensor _kernel, + torch::stable::Tensor _scaling_factors, + torch::stable::Tensor _zeros, + int64_t split_k_iters); + +torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel, + torch::stable::Tensor _scaling_factors, + torch::stable::Tensor _zeros, + int64_t split_k_iters, int64_t thx, + int64_t thy); + +// DSV3 fused A GEMM: conditionally compiled so declaration and impl +// registration are in the source file (dsv3_fused_a_gemm.cu) + +// AllSpark ops: declarations are in the source files +// (allspark_repack.cu and allspark_qgemm_w8a16.cu) + #endif + +torch::stable::Tensor hadacore_transform(torch::stable::Tensor& x, + bool inplace); + +// Activation kernels (shared CUDA/ROCm) +void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input); +void mul_and_silu(torch::stable::Tensor& out, torch::stable::Tensor& input); +void gelu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input); +void gelu_tanh_and_mul(torch::stable::Tensor& out, + torch::stable::Tensor& input); +void fatrelu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input, + double threshold); +void swigluoai_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input, + double alpha = 1.702, double limit = 7.0); +void gelu_new(torch::stable::Tensor& out, torch::stable::Tensor& input); +void gelu_fast(torch::stable::Tensor& out, torch::stable::Tensor& input); +void gelu_quick(torch::stable::Tensor& out, torch::stable::Tensor& input); + +// INT8 quantization kernels (shared CUDA/ROCm) +void static_scaled_int8_quant(torch::stable::Tensor& out, + torch::stable::Tensor const& input, + torch::stable::Tensor const& scale, + std::optional const& azp); + +void dynamic_scaled_int8_quant(torch::stable::Tensor& out, + torch::stable::Tensor const& input, + torch::stable::Tensor& scales, + std::optional const& azp); + +// FP8 quantization kernels (shared CUDA/ROCm) +void static_scaled_fp8_quant( + torch::stable::Tensor& out, torch::stable::Tensor const& input, + torch::stable::Tensor const& scale, + std::optional group_shape = + std::nullopt); + +void dynamic_scaled_fp8_quant(torch::stable::Tensor& out, + torch::stable::Tensor const& input, + torch::stable::Tensor& scale); + +void dynamic_per_token_scaled_fp8_quant( + torch::stable::Tensor& out, torch::stable::Tensor const& input, + torch::stable::Tensor& scale, + std::optional const& scale_ub); + +// GPTQ kernels (shared CUDA/ROCm) +torch::stable::Tensor gptq_gemm(torch::stable::Tensor a, + torch::stable::Tensor b_q_weight, + torch::stable::Tensor b_gptq_qzeros, + torch::stable::Tensor b_gptq_scales, + torch::stable::Tensor b_g_idx, bool use_exllama, + bool use_v2_format, int64_t bit); + +void gptq_shuffle(torch::stable::Tensor q_weight, torch::stable::Tensor q_perm, + int64_t bit); + +// GGML kernels (shared CUDA/ROCm) +torch::stable::Tensor ggml_dequantize( + torch::stable::Tensor W, int64_t type, int64_t m, int64_t n, + std::optional const& dtype); + +torch::stable::Tensor ggml_mul_mat_vec_a8(torch::stable::Tensor W, + torch::stable::Tensor X, int64_t type, + int64_t row); + +torch::stable::Tensor ggml_mul_mat_a8(torch::stable::Tensor W, + torch::stable::Tensor X, int64_t type, + int64_t row); + +torch::stable::Tensor ggml_moe_a8(torch::stable::Tensor X, + torch::stable::Tensor W, + torch::stable::Tensor sorted_token_ids, + torch::stable::Tensor expert_ids, + torch::stable::Tensor num_tokens_post_padded, + int64_t type, int64_t row, int64_t top_k, + int64_t tokens); + +torch::stable::Tensor ggml_moe_a8_vec(torch::stable::Tensor X, + torch::stable::Tensor W, + torch::stable::Tensor topk_ids, + int64_t top_k, int64_t type, int64_t row, + int64_t tokens); + +int64_t ggml_moe_get_block_size(int64_t type); diff --git a/csrc/quantization/awq/dequantize.cuh b/csrc/libtorch_stable/quantization/awq/dequantize.cuh similarity index 100% rename from csrc/quantization/awq/dequantize.cuh rename to csrc/libtorch_stable/quantization/awq/dequantize.cuh diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/libtorch_stable/quantization/awq/gemm_kernels.cu similarity index 89% rename from csrc/quantization/awq/gemm_kernels.cu rename to csrc/libtorch_stable/quantization/awq/gemm_kernels.cu index 53c47679cdd7..c3702c52efcb 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/libtorch_stable/quantization/awq/gemm_kernels.cu @@ -7,10 +7,11 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ -#include -#include +#include +#include +#include "libtorch_stable/torch_utils.h" -#include "dequantize.cuh" +#include "libtorch_stable/quantization/awq/dequantize.cuh" #include @@ -410,10 +411,11 @@ __global__ void __launch_bounds__(64) } // namespace awq } // namespace vllm -torch::Tensor awq_dequantize(torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, int64_t split_k_iters, - int64_t thx, int64_t thy) { +torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel, + torch::stable::Tensor _scaling_factors, + torch::stable::Tensor _zeros, + int64_t split_k_iters, int64_t thx, + int64_t thy) { int in_c = _kernel.size(0); int qout_c = _kernel.size(1); int out_c = qout_c * 8; @@ -437,23 +439,24 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, y_blocks = (int)(in_c / 8); } - const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); + const torch::stable::accelerator::DeviceGuard device_guard( + _scaling_factors.get_device_index()); - auto options = torch::TensorOptions() - .dtype(_scaling_factors.dtype()) - .device(_scaling_factors.device()); - at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); + auto _de_kernel = + torch::stable::empty({in_c, out_c}, _scaling_factors.scalar_type(), + std::nullopt, _scaling_factors.device()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); - auto scaling_factors = - reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto kernel = reinterpret_cast(_kernel.mutable_data_ptr()); + auto de_kernel = reinterpret_cast( + _de_kernel.mutable_data_ptr()); + auto scaling_factors = reinterpret_cast( + _scaling_factors.mutable_data_ptr()); + auto zeros = reinterpret_cast(_zeros.mutable_data_ptr()); dim3 num_blocks(x_blocks, y_blocks); dim3 threads_per_block(x_thread, y_thread); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const cudaStream_t stream = get_current_cuda_stream(); vllm::awq::dequantize_weights<<>>( kernel, scaling_factors, zeros, de_kernel, G); @@ -466,27 +469,30 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] // assume that batch_size < 16 for now -torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, - torch::Tensor _scaling_factors, torch::Tensor _zeros, - int64_t split_k_iters) { +torch::stable::Tensor awq_gemm(torch::stable::Tensor _in_feats, + torch::stable::Tensor _kernel, + torch::stable::Tensor _scaling_factors, + torch::stable::Tensor _zeros, + int64_t split_k_iters) { int num_in_feats = _in_feats.size(0); int num_in_channels = _in_feats.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + const torch::stable::accelerator::DeviceGuard device_guard( + _in_feats.get_device_index()); - auto options = torch::TensorOptions() - .dtype(_in_feats.dtype()) - .device(_in_feats.device()); - at::Tensor _out_feats = - torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); + auto _out_feats = torch::stable::empty( + {split_k_iters, num_in_feats, _kernel.size(1) * 8}, + _in_feats.scalar_type(), std::nullopt, _in_feats.device()); int num_out_feats = _out_feats.size(-2); int num_out_channels = _out_feats.size(-1); - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - auto scaling_factors = - reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto in_feats = reinterpret_cast( + _in_feats.mutable_data_ptr()); + auto kernel = reinterpret_cast(_kernel.mutable_data_ptr()); + auto out_feats = reinterpret_cast( + _out_feats.mutable_data_ptr()); + auto scaling_factors = reinterpret_cast( + _scaling_factors.mutable_data_ptr()); + auto zeros = reinterpret_cast(_zeros.mutable_data_ptr()); int group_size = num_in_channels / _scaling_factors.size(0); if (num_out_channels % 64 != 0) @@ -498,7 +504,7 @@ torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, if (num_out_channels % group_size != 0) throw std::invalid_argument("OC is not multiple of Group size"); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const cudaStream_t stream = get_current_cuda_stream(); if (num_out_channels % 128 == 0) { int j_factors1 = num_out_channels / 128 / 1; dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); @@ -522,5 +528,5 @@ torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); } - return _out_feats.sum(0); + return torch::stable::sum(_out_feats, 0); } diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu similarity index 61% rename from csrc/quantization/gguf/gguf_kernel.cu rename to csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu index 76fe73e95040..0fdfcafab8c0 100644 --- a/csrc/quantization/gguf/gguf_kernel.cu +++ b/csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu @@ -1,17 +1,20 @@ #include #include -#include -#include +#include "../../../cuda_compat.h" +#include "../../dispatch_utils.h" +#include "../../torch_utils.h" -#include "../../cuda_compat.h" -#include "dispatch_utils.h" +#include -#include "ggml-common.h" -#include "vecdotq.cuh" -#include "dequantize.cuh" -#include "mmvq.cuh" -#include "mmq.cuh" +// NOTE: These headers are intentionally kept in csrc/quantization/gguf/ (not +// moved to libtorch_stable) to avoid unnecessary reformatting that would break +// git rename detection and pollute blame history. +#include "../../../quantization/gguf/ggml-common.h" +#include "../../../quantization/gguf/vecdotq.cuh" +#include "../../../quantization/gguf/dequantize.cuh" +#include "../../../quantization/gguf/mmvq.cuh" +#include "../../../quantization/gguf/mmq.cuh" #include "moe.cuh" #include "moe_vec.cuh" @@ -71,16 +74,17 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx, } } -torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight - int64_t type, int64_t m, int64_t n, - std::optional const& dtype) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(W)); - auto dtype_ = dtype.value_or(torch::kFloat16); - auto options = torch::TensorOptions().dtype(dtype_).device(W.device()); - at::Tensor DW = torch::empty({m, n}, options); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); +torch::stable::Tensor ggml_dequantize( + torch::stable::Tensor W, // quant weight + int64_t type, int64_t m, int64_t n, + std::optional const& dtype) { + const torch::stable::accelerator::DeviceGuard device_guard( + W.get_device_index()); + auto dtype_ = dtype.value_or(torch::headeronly::ScalarType::Half); + auto DW = torch::stable::empty({m, n}, dtype_, std::nullopt, W.device()); + cudaStream_t stream = get_current_cuda_stream(); - VLLM_DISPATCH_FLOATING_TYPES(DW.scalar_type(), "ggml_dequantize", [&] { + VLLM_STABLE_DISPATCH_FLOATING_TYPES(DW.scalar_type(), "ggml_dequantize", [&] { auto to_cuda = ggml_get_to_cuda(type); to_cuda((void*)W.data_ptr(), (scalar_t*)DW.data_ptr(), m * n, stream); }); @@ -88,135 +92,142 @@ torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight return DW; } -torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight - torch::Tensor X, // input - int64_t type, int64_t row) { +torch::stable::Tensor ggml_mul_mat_vec_a8( + torch::stable::Tensor W, // quant weight + torch::stable::Tensor X, // input + int64_t type, int64_t row) { int col = X.sizes()[1]; int vecs = X.sizes()[0]; const int padded = (col + 512 - 1) / 512 * 512; - const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); - auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device()); - at::Tensor Y = torch::empty({vecs, row}, options); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); - at::Tensor quant_X = torch::empty({vecs, padded / 32 * 9}, options); - VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] { - quantize_row_q8_1_cuda( - (scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, vecs, stream); - switch (type) { - case 2: - mul_mat_vec_q4_0_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 3: - mul_mat_vec_q4_1_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 6: - mul_mat_vec_q5_0_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 7: - mul_mat_vec_q5_1_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 8: - mul_mat_vec_q8_0_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 10: - mul_mat_vec_q2_K_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 11: - mul_mat_vec_q3_K_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 12: - mul_mat_vec_q4_K_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 13: - mul_mat_vec_q5_K_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 14: - mul_mat_vec_q6_K_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 16: - mul_mat_vec_iq2_xxs_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 17: - mul_mat_vec_iq2_xs_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 18: - mul_mat_vec_iq3_xxs_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 19: - mul_mat_vec_iq1_s_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 20: - mul_mat_vec_iq4_nl_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 21: - mul_mat_vec_iq3_s_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 22: - mul_mat_vec_iq2_s_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 23: - mul_mat_vec_iq4_xs_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 29: - mul_mat_vec_iq1_m_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - } - }); + const torch::stable::accelerator::DeviceGuard device_guard( + X.get_device_index()); + auto Y = torch::stable::empty({vecs, row}, X.scalar_type(), std::nullopt, + W.device()); + cudaStream_t stream = get_current_cuda_stream(); + auto quant_X = torch::stable::empty({vecs, padded / 32 * 9}, + torch::headeronly::ScalarType::Int, + std::nullopt, W.device()); + VLLM_STABLE_DISPATCH_FLOATING_TYPES( + X.scalar_type(), "ggml_mul_mat_vec_a8", [&] { + quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), + (void*)quant_X.data_ptr(), col, vecs, + stream); + switch (type) { + case 2: + mul_mat_vec_q4_0_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 3: + mul_mat_vec_q4_1_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 6: + mul_mat_vec_q5_0_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 7: + mul_mat_vec_q5_1_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 8: + mul_mat_vec_q8_0_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 10: + mul_mat_vec_q2_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 11: + mul_mat_vec_q3_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 12: + mul_mat_vec_q4_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 13: + mul_mat_vec_q5_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 14: + mul_mat_vec_q6_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 16: + mul_mat_vec_iq2_xxs_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 17: + mul_mat_vec_iq2_xs_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 18: + mul_mat_vec_iq3_xxs_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 19: + mul_mat_vec_iq1_s_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 20: + mul_mat_vec_iq4_nl_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 21: + mul_mat_vec_iq3_s_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 22: + mul_mat_vec_iq2_s_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 23: + mul_mat_vec_iq4_xs_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 29: + mul_mat_vec_iq1_m_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + } + }); return Y; } -torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight - torch::Tensor X, // input - int64_t type, int64_t row) { +torch::stable::Tensor ggml_mul_mat_a8(torch::stable::Tensor W, // quant weight + torch::stable::Tensor X, // input + int64_t type, int64_t row) { int col = X.sizes()[1]; int padded = (col + 512 - 1) / 512 * 512; int batch = X.sizes()[0]; - const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); - auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device()); - at::Tensor Y = torch::empty({batch, row}, options); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); - at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options); - VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_a8", [&] { + const torch::stable::accelerator::DeviceGuard device_guard( + X.get_device_index()); + auto Y = torch::stable::empty({batch, row}, X.scalar_type(), std::nullopt, + W.device()); + cudaStream_t stream = get_current_cuda_stream(); + auto quant_X = torch::stable::empty({batch, padded / 32 * 9}, + torch::headeronly::ScalarType::Int, + std::nullopt, W.device()); + VLLM_STABLE_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_a8", [&] { quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, batch, stream); @@ -276,21 +287,24 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight return Y; } -torch::Tensor ggml_moe_a8(torch::Tensor X, // input - torch::Tensor W, // expert weights - torch::Tensor sorted_token_ids, - torch::Tensor expert_ids, - torch::Tensor num_tokens_post_padded, int64_t type, - int64_t row, int64_t top_k, int64_t tokens) { +torch::stable::Tensor ggml_moe_a8(torch::stable::Tensor X, // input + torch::stable::Tensor W, // expert weights + torch::stable::Tensor sorted_token_ids, + torch::stable::Tensor expert_ids, + torch::stable::Tensor num_tokens_post_padded, + int64_t type, int64_t row, int64_t top_k, + int64_t tokens) { int col = X.sizes()[1]; int padded = (col + 512 - 1) / 512 * 512; - const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); - auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device()); - at::Tensor Y = torch::empty({tokens * top_k, row}, options); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); - at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options); - VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_a8", [&] { + const torch::stable::accelerator::DeviceGuard device_guard( + X.get_device_index()); + auto Y = torch::stable::empty({tokens * top_k, row}, X.scalar_type(), + std::nullopt, W.device()); + cudaStream_t stream = get_current_cuda_stream(); + auto quant_X = torch::stable::empty({tokens, padded / 32 * 9}, + torch::headeronly::ScalarType::Int, + std::nullopt, W.device()); + VLLM_STABLE_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_a8", [&] { quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, tokens, stream); switch (type) { @@ -379,19 +393,23 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input return Y; } -torch::Tensor ggml_moe_a8_vec(torch::Tensor X, // input - torch::Tensor W, // expert weights - torch::Tensor topk_ids, int64_t top_k, - int64_t type, int64_t row, int64_t tokens) { +torch::stable::Tensor ggml_moe_a8_vec( + torch::stable::Tensor X, // input + torch::stable::Tensor W, // expert weights + torch::stable::Tensor topk_ids, int64_t top_k, int64_t type, int64_t row, + int64_t tokens) { int col = X.sizes()[1]; const int padded = (col + 512 - 1) / 512 * 512; - const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); - auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device()); - at::Tensor Y = torch::zeros({tokens * top_k, row}, options); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); - at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options); - VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_vec_a8", [&] { + const torch::stable::accelerator::DeviceGuard device_guard( + X.get_device_index()); + auto Y = torch::stable::empty({tokens * top_k, row}, X.scalar_type(), + std::nullopt, W.device()); + torch::stable::fill_(Y, 0.0); + cudaStream_t stream = get_current_cuda_stream(); + auto quant_X = torch::stable::empty({tokens, padded / 32 * 9}, + torch::headeronly::ScalarType::Int, + std::nullopt, W.device()); + VLLM_STABLE_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_vec_a8", [&] { quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, tokens, stream); diff --git a/csrc/quantization/gguf/moe.cuh b/csrc/libtorch_stable/quantization/gguf/moe.cuh similarity index 100% rename from csrc/quantization/gguf/moe.cuh rename to csrc/libtorch_stable/quantization/gguf/moe.cuh diff --git a/csrc/quantization/gguf/moe_vec.cuh b/csrc/libtorch_stable/quantization/gguf/moe_vec.cuh similarity index 100% rename from csrc/quantization/gguf/moe_vec.cuh rename to csrc/libtorch_stable/quantization/gguf/moe_vec.cuh diff --git a/csrc/quantization/gptq/compat.cuh b/csrc/libtorch_stable/quantization/gptq/compat.cuh similarity index 100% rename from csrc/quantization/gptq/compat.cuh rename to csrc/libtorch_stable/quantization/gptq/compat.cuh diff --git a/csrc/quantization/gptq/matrix_view.cuh b/csrc/libtorch_stable/quantization/gptq/matrix_view.cuh similarity index 100% rename from csrc/quantization/gptq/matrix_view.cuh rename to csrc/libtorch_stable/quantization/gptq/matrix_view.cuh diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/libtorch_stable/quantization/gptq/q_gemm.cu similarity index 97% rename from csrc/quantization/gptq/q_gemm.cu rename to csrc/libtorch_stable/quantization/gptq/q_gemm.cu index 8a29ad5ab2dd..e3f79c5a6b8e 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/libtorch_stable/quantization/gptq/q_gemm.cu @@ -6,9 +6,8 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa #include #include -#include -#include -#include +#include "../../torch_utils.h" +#include #include #include @@ -735,7 +734,7 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight, fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const cudaStream_t stream = get_current_cuda_stream(); kernel<<>>( a, b_q_weight, b_gptq_qzeros, b_gptq_scales, c, size_m, size_n, size_k, groups, use_v2_format, b_q_perm); @@ -1164,7 +1163,7 @@ void reconstruct_exllama(const uint32_t* b_q_weight, reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel; } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const cudaStream_t stream = get_current_cuda_stream(); reconstruct_exllama_kernel<<>>( b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups, use_v2_format, out); @@ -1376,7 +1375,7 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight, kernel = gemm_half_q_half_alt_8bit_kernel; } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const cudaStream_t stream = get_current_cuda_stream(); kernel<<>>( (const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx, size_m, size_k / 32 * bit, size_n, use_v2_format); @@ -1485,7 +1484,7 @@ void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, gridDim.y = DIVIDE(height, 32); } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const cudaStream_t stream = get_current_cuda_stream(); kernel<<>>(b_q_weight, b_gptq_scales, b_gptq_qzeros, b_g_idx, height, width, groups, use_v2_format, out); @@ -1794,7 +1793,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height, } else if (bit == 8) { kernel = make_sequential_8bit_kernel; } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const cudaStream_t stream = get_current_cuda_stream(); kernel<<>>(q_weight, new_qweight, q_perm, width); // Replace qweights @@ -1818,29 +1817,34 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height, } else if (bit == 8) { shuffle_kernel = shuffle_8bit_kernel; } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const cudaStream_t stream = get_current_cuda_stream(); shuffle_kernel<<>>(q_weight, height, width); } } // namespace gptq } // namespace vllm -torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, bool use_v2_format, int64_t bit) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - at::Tensor c = torch::zeros({a.size(0), b_q_weight.size(1)}, options); - at::Tensor temp_dq = torch::empty( - {b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options); +torch::stable::Tensor gptq_gemm(torch::stable::Tensor a, + torch::stable::Tensor b_q_weight, + torch::stable::Tensor b_gptq_qzeros, + torch::stable::Tensor b_gptq_scales, + torch::stable::Tensor b_g_idx, bool use_exllama, + bool use_v2_format, int64_t bit) { + const torch::stable::accelerator::DeviceGuard device_guard( + a.get_device_index()); + auto c = torch::stable::new_zeros(a, {a.size(0), b_q_weight.size(1)}); + auto temp_dq = + torch::stable::empty({b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, + a.scalar_type(), std::nullopt, a.device()); vllm::gptq::gemm_half_q_half_cuda( - at::cuda::getCurrentCUDABlasHandle(), (const half*)a.data_ptr(), + get_current_cuda_blas_handle(), (const half*)a.data_ptr(), (const uint32_t*)b_q_weight.data_ptr(), (const uint32_t*)b_gptq_qzeros.data_ptr(), (const half*)b_gptq_scales.data_ptr(), - b_g_idx.device().is_meta() ? NULL : (const int*)b_g_idx.data_ptr(), + b_g_idx.device().type() == torch::stable::DeviceType::Meta + ? NULL + : (const int*)b_g_idx.data_ptr(), (half*)c.data_ptr(), (half*)temp_dq.data_ptr(), c.size(0), // m c.size(1), // n @@ -1850,11 +1854,14 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, return c; } -void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); +void gptq_shuffle(torch::stable::Tensor q_weight, torch::stable::Tensor q_perm, + int64_t bit) { + const torch::stable::accelerator::DeviceGuard device_guard( + q_weight.get_device_index()); vllm::gptq::shuffle_exllama_weight( (uint32_t*)q_weight.data_ptr(), - q_perm.device().is_meta() || q_perm.numel() == 0 + q_perm.device().type() == torch::stable::DeviceType::Meta || + q_perm.numel() == 0 ? NULL : (int*)q_perm.data_ptr(), q_weight.size(0) * 32 / bit, q_weight.size(1), bit); diff --git a/csrc/quantization/gptq/qdq_2.cuh b/csrc/libtorch_stable/quantization/gptq/qdq_2.cuh similarity index 100% rename from csrc/quantization/gptq/qdq_2.cuh rename to csrc/libtorch_stable/quantization/gptq/qdq_2.cuh diff --git a/csrc/quantization/gptq/qdq_3.cuh b/csrc/libtorch_stable/quantization/gptq/qdq_3.cuh similarity index 100% rename from csrc/quantization/gptq/qdq_3.cuh rename to csrc/libtorch_stable/quantization/gptq/qdq_3.cuh diff --git a/csrc/quantization/gptq/qdq_4.cuh b/csrc/libtorch_stable/quantization/gptq/qdq_4.cuh similarity index 100% rename from csrc/quantization/gptq/qdq_4.cuh rename to csrc/libtorch_stable/quantization/gptq/qdq_4.cuh diff --git a/csrc/quantization/gptq/qdq_8.cuh b/csrc/libtorch_stable/quantization/gptq/qdq_8.cuh similarity index 100% rename from csrc/quantization/gptq/qdq_8.cuh rename to csrc/libtorch_stable/quantization/gptq/qdq_8.cuh diff --git a/csrc/quantization/gptq/qdq_util.cuh b/csrc/libtorch_stable/quantization/gptq/qdq_util.cuh similarity index 100% rename from csrc/quantization/gptq/qdq_util.cuh rename to csrc/libtorch_stable/quantization/gptq/qdq_util.cuh diff --git a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu b/csrc/libtorch_stable/quantization/gptq_allspark/allspark_qgemm_w8a16.cu similarity index 92% rename from csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu rename to csrc/libtorch_stable/quantization/gptq_allspark/allspark_qgemm_w8a16.cu index e306ff02605b..96dc3ecfc860 100644 --- a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu +++ b/csrc/libtorch_stable/quantization/gptq_allspark/allspark_qgemm_w8a16.cu @@ -1,20 +1,28 @@ #include "allspark_utils.cuh" -#include -#include "core/registration.h" + +#include +#include +#include +#include + #include -at::Tensor as_g_workspace; +#include "core/registration.h" +#include "libtorch_stable/torch_utils.h" + +torch::stable::Tensor as_g_workspace; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -torch::Tensor allspark_w8a16_gemm( - torch::Tensor const& a, torch::Tensor const& b_qweight, - torch::Tensor const& b_scales, std::optional const& b_qzeros, - int64_t n, int64_t group_size, int64_t sm_count, int64_t sm_version, +torch::stable::Tensor allspark_w8a16_gemm( + torch::stable::Tensor const& a, torch::stable::Tensor const& b_qweight, + torch::stable::Tensor const& b_scales, + std::optional const& b_qzeros, int64_t n, + int64_t group_size, int64_t sm_count, int64_t sm_version, int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) { - TORCH_CHECK_NOT_IMPLEMENTED( + STD_TORCH_CHECK_NOT_IMPLEMENTED( false, "allspark_w8a16_gemm(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); + return torch::stable::empty({1, 1}); } #else @@ -848,8 +856,8 @@ void restore_N32_K16_dequantize_rhs_w8a16(const QT* qdata, const FT* scales, const int N_32align, const int N, const int K, const int GroupSize, cudaStream_t stream) { - TORCH_CHECK(N % 8 == 0 && K % 16 == 0 && N_32align % 32 == 0, - "Unsupported shape"); + STD_TORCH_CHECK(N % 8 == 0 && K % 16 == 0 && N_32align % 32 == 0, + "Unsupported shape"); if (GroupSize == -1) { const int BLOCK = 128; dim3 grid(N_32align / 32, ((K / 16) + 3) / 4); @@ -859,7 +867,7 @@ void restore_N32_K16_dequantize_rhs_w8a16(const QT* qdata, const FT* scales, } // TODO: Support SubChannel else { - TORCH_CHECK(false, "Now only support PerChannel"); + STD_TORCH_CHECK(false, "Now only support PerChannel"); } } @@ -916,24 +924,27 @@ void allspark_qgemm_w8a16_perc_ampere( } // namespace allspark -torch::Tensor allspark_w8a16_gemm( - torch::Tensor const& a, torch::Tensor const& b_qweight, - torch::Tensor const& b_scales, std::optional const& b_qzeros, - int64_t n, int64_t group_size, int64_t sm_count, int64_t sm_version, +torch::stable::Tensor allspark_w8a16_gemm( + torch::stable::Tensor const& a, torch::stable::Tensor const& b_qweight, + torch::stable::Tensor const& b_scales, + std::optional const& b_qzeros, int64_t n, + int64_t group_size, int64_t sm_count, int64_t sm_version, int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) { // Verify device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + STD_TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + STD_TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU"); - TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous"); + STD_TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU"); + STD_TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous"); - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + STD_TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + STD_TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); if (has_zp) { - TORCH_CHECK(b_qzeros.value().device().is_cuda(), "b_qzeros is not on GPU"); - TORCH_CHECK(b_qzeros.value().is_contiguous(), "b_qzeros is not contiguous"); + STD_TORCH_CHECK(b_qzeros.value().device().is_cuda(), + "b_qzeros is not on GPU"); + STD_TORCH_CHECK(b_qzeros.value().is_contiguous(), + "b_qzeros is not contiguous"); } int m = a.size(0); @@ -941,16 +952,17 @@ torch::Tensor allspark_w8a16_gemm( int k = a.size(1); // Verify shape - TORCH_CHECK(b_qweight.size(0) == n_32align, - "Shape mismatch: b_qweight.size(0) = ", b_qweight.size(0), - ", n_32align = ", n_32align); - TORCH_CHECK(b_qweight.size(1) == k, - "Shape mismatch: b_qweight.size(1) = ", b_qweight.size(1), - ", k = ", k); + STD_TORCH_CHECK(b_qweight.size(0) == n_32align, + "Shape mismatch: b_qweight.size(0) = ", b_qweight.size(0), + ", n_32align = ", n_32align); + STD_TORCH_CHECK(b_qweight.size(1) == k, + "Shape mismatch: b_qweight.size(1) = ", b_qweight.size(1), + ", k = ", k); - TORCH_CHECK(group_size == -1, "Currently only supports group_size = -1"); + STD_TORCH_CHECK(group_size == -1, "Currently only supports group_size = -1"); - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + const torch::stable::accelerator::DeviceGuard device_guard( + a.get_device_index()); const void* a_ptr = reinterpret_cast(a.data_ptr()); const uint8_t* b_ptr = reinterpret_cast(b_qweight.data_ptr()); const void* b_scale_ptr = reinterpret_cast(b_scales.data_ptr()); @@ -959,12 +971,12 @@ torch::Tensor allspark_w8a16_gemm( b_zero_ptr = reinterpret_cast(b_qzeros.value().data_ptr()); } - auto c_options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({m, n}, c_options); - void* c_ptr = reinterpret_cast(c.data_ptr()); + auto c = + torch::stable::empty({m, n}, a.scalar_type(), std::nullopt, a.device()); + void* c_ptr = reinterpret_cast(c.mutable_data_ptr()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cudaStream_t stream = get_current_cuda_stream(); + cublasHandle_t handle = get_current_cuda_blas_handle(); allspark::BlockTileSplitkParams fused_gemm_params; @@ -976,14 +988,15 @@ torch::Tensor allspark_w8a16_gemm( m, n, k, sm_count, fused_gemm_params); } - auto ws_options = torch::TensorOptions().dtype(at::kChar).device(a.device()); if (as_g_workspace.numel() < ws_size) { // ws_options: kChar, so numel() is bytes - as_g_workspace = torch::empty({long(ws_size)}, ws_options); + as_g_workspace = torch::stable::empty({static_cast(ws_size)}, + torch::headeronly::ScalarType::Char, + std::nullopt, a.device()); } void* ws = reinterpret_cast(as_g_workspace.data_ptr()); - if (a.dtype() == at::ScalarType::Half) { + if (a.scalar_type() == torch::headeronly::ScalarType::Half) { allspark::allspark_qgemm_w8a16_perc_ampere<__half, uint8_t>( reinterpret_cast(a_ptr), b_ptr, reinterpret_cast(b_scale_ptr), @@ -991,7 +1004,7 @@ torch::Tensor allspark_w8a16_gemm( reinterpret_cast<__half*>(c_ptr), m, n_32align, n, k, ws, fused_gemm_params, group_size, CUBLAS_M_THRESHOLD, sm_version, stream, handle); - } else if (a.dtype() == at::ScalarType::BFloat16) { + } else if (a.scalar_type() == torch::headeronly::ScalarType::BFloat16) { allspark::allspark_qgemm_w8a16_perc_ampere<__nv_bfloat16, uint8_t>( reinterpret_cast(a_ptr), b_ptr, reinterpret_cast(b_scale_ptr), @@ -1006,6 +1019,6 @@ torch::Tensor allspark_w8a16_gemm( #endif -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("allspark_w8a16_gemm", &allspark_w8a16_gemm); +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { + m.impl("allspark_w8a16_gemm", TORCH_BOX(&allspark_w8a16_gemm)); } diff --git a/csrc/quantization/gptq_allspark/allspark_repack.cu b/csrc/libtorch_stable/quantization/gptq_allspark/allspark_repack.cu similarity index 67% rename from csrc/quantization/gptq_allspark/allspark_repack.cu rename to csrc/libtorch_stable/quantization/gptq_allspark/allspark_repack.cu index 7a5b2f95cc2e..b325d30a041a 100644 --- a/csrc/quantization/gptq_allspark/allspark_repack.cu +++ b/csrc/libtorch_stable/quantization/gptq_allspark/allspark_repack.cu @@ -1,6 +1,11 @@ #include "allspark_utils.cuh" -#include + +#include +#include +#include + #include "core/registration.h" +#include "libtorch_stable/torch_utils.h" namespace allspark { @@ -99,36 +104,40 @@ void rearrange_kn_weight_as_n32k16_order_ldg16( } // namespace allspark void rearrange_kn_weight_as_n32k16_order( - torch::Tensor const& b_qweight, torch::Tensor const& b_scales, - std::optional const& b_zeros, bool has_zp, - torch::Tensor& b_qweight_reorder, torch::Tensor& b_scales_reorder, - std::optional const& b_zeros_reorder, const int64_t K, - const int64_t N, const int64_t N_32align) { + torch::stable::Tensor const& b_qweight, + torch::stable::Tensor const& b_scales, + std::optional const& b_zeros, bool has_zp, + torch::stable::Tensor& b_qweight_reorder, + torch::stable::Tensor& b_scales_reorder, + std::optional const& b_zeros_reorder, + const int64_t K, const int64_t N, const int64_t N_32align) { // Verify device and strides - TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU"); - TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous"); + STD_TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU"); + STD_TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous"); - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + STD_TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + STD_TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - TORCH_CHECK(b_qweight_reorder.device().is_cuda(), - "b_qweight_reorder is not on GPU"); - TORCH_CHECK(b_qweight_reorder.is_contiguous(), - "b_qweight_reorder is not contiguous"); + STD_TORCH_CHECK(b_qweight_reorder.device().is_cuda(), + "b_qweight_reorder is not on GPU"); + STD_TORCH_CHECK(b_qweight_reorder.is_contiguous(), + "b_qweight_reorder is not contiguous"); - TORCH_CHECK(b_scales_reorder.device().is_cuda(), - "b_scales_reorder is not on GPU"); - TORCH_CHECK(b_scales_reorder.is_contiguous(), - "b_scales_reorder is not contiguous"); + STD_TORCH_CHECK(b_scales_reorder.device().is_cuda(), + "b_scales_reorder is not on GPU"); + STD_TORCH_CHECK(b_scales_reorder.is_contiguous(), + "b_scales_reorder is not contiguous"); if (has_zp) { - TORCH_CHECK(b_zeros.value().device().is_cuda(), "b_zeros is not on GPU"); - TORCH_CHECK(b_zeros.value().is_contiguous(), "b_zeros is not contiguous"); - - TORCH_CHECK(b_zeros_reorder.value().device().is_cuda(), - "b_zeros_reorder is not on GPU"); - TORCH_CHECK(b_zeros_reorder.value().is_contiguous(), - "b_zeros_reorder is not contiguous"); + STD_TORCH_CHECK(b_zeros.value().device().is_cuda(), + "b_zeros is not on GPU"); + STD_TORCH_CHECK(b_zeros.value().is_contiguous(), + "b_zeros is not contiguous"); + + STD_TORCH_CHECK(b_zeros_reorder.value().device().is_cuda(), + "b_zeros_reorder is not on GPU"); + STD_TORCH_CHECK(b_zeros_reorder.value().is_contiguous(), + "b_zeros_reorder is not contiguous"); } const uint8_t* matB = reinterpret_cast(b_qweight.data_ptr()); @@ -136,18 +145,20 @@ void rearrange_kn_weight_as_n32k16_order( const void* b_zero = has_zp ? b_zeros.value().data_ptr() : nullptr; uint8_t* matB_reorder = - reinterpret_cast(b_qweight_reorder.data_ptr()); - void* b_scale_reorder = b_scales_reorder.data_ptr(); - void* b_zero_reorder = has_zp ? b_zeros_reorder.value().data_ptr() : nullptr; + reinterpret_cast(b_qweight_reorder.mutable_data_ptr()); + void* b_scale_reorder = b_scales_reorder.mutable_data_ptr(); + void* b_zero_reorder = + has_zp ? b_zeros_reorder.value().mutable_data_ptr() : nullptr; - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (b_scales.dtype() == at::ScalarType::Half) { + cudaStream_t stream = get_current_cuda_stream(); + if (b_scales.scalar_type() == torch::headeronly::ScalarType::Half) { allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__half>( matB, reinterpret_cast(b_scale), reinterpret_cast(b_zero), matB_reorder, reinterpret_cast<__half*>(b_scale_reorder), reinterpret_cast<__half*>(b_zero_reorder), K, N, N_32align, stream); - } else if (b_scales.dtype() == at::ScalarType::BFloat16) { + } else if (b_scales.scalar_type() == + torch::headeronly::ScalarType::BFloat16) { allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__nv_bfloat16>( matB, reinterpret_cast(b_scale), reinterpret_cast(b_zero), matB_reorder, @@ -157,7 +168,7 @@ void rearrange_kn_weight_as_n32k16_order( } } -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { m.impl("rearrange_kn_weight_as_n32k16_order", - &rearrange_kn_weight_as_n32k16_order); + TORCH_BOX(&rearrange_kn_weight_as_n32k16_order)); } diff --git a/csrc/quantization/gptq_allspark/allspark_utils.cuh b/csrc/libtorch_stable/quantization/gptq_allspark/allspark_utils.cuh similarity index 99% rename from csrc/quantization/gptq_allspark/allspark_utils.cuh rename to csrc/libtorch_stable/quantization/gptq_allspark/allspark_utils.cuh index c7a6e96aff4b..ce96c2d11fea 100644 --- a/csrc/quantization/gptq_allspark/allspark_utils.cuh +++ b/csrc/libtorch_stable/quantization/gptq_allspark/allspark_utils.cuh @@ -1,13 +1,12 @@ #pragma once -#include -#include -#include -#include -#include #include +#include +#include + #include -#include "../marlin/marlin_dtypes.cuh" + +#include "quantization/marlin/marlin_dtypes.cuh" using marlin::MarlinScalarType2; namespace allspark { diff --git a/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu b/csrc/libtorch_stable/quantization/hadamard/hadacore/hadamard_transform_cuda.cu similarity index 93% rename from csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu rename to csrc/libtorch_stable/quantization/hadamard/hadacore/hadamard_transform_cuda.cu index aff11326d78e..665585caa46c 100644 --- a/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu +++ b/csrc/libtorch_stable/quantization/hadamard/hadacore/hadamard_transform_cuda.cu @@ -11,18 +11,16 @@ Redistribution and use in source and binary forms, with or without modification, THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ***********/ -#include +#include "libtorch_stable/torch_utils.h" +#include "libtorch_stable/dispatch_utils.h" + +#include +#include + #include #include #include #include -#include - -#include -#include - -#include "core/registration.h" -#include "dispatch_utils.h" namespace hadacore { @@ -65,12 +63,12 @@ constexpr int launch_configs_big[7][3] = { }; // a 4x2, b 2x2, c 2x2 -template +template __device__ __forceinline__ void mma_m16_n8_k16_b16_b16_b16_noacc(b32 a0, b32 a1, b32 a2, b32 a3, b32 b0, b32 b1, b32& c0, b32& c1){ - static_assert(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16); + static_assert(dtype == torch::headeronly::ScalarType::Half || dtype == torch::headeronly::ScalarType::BFloat16); // d, a, b, c b32 zero = 0; - if constexpr(dtype == torch::ScalarType::Half) { + if constexpr(dtype == torch::headeronly::ScalarType::Half) { asm ( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n\t" @@ -89,7 +87,7 @@ __device__ __forceinline__ void mma_m16_n8_k16_b16_b16_b16_noacc(b32 a0, b32 a1, } // a 4x2, b 4x2, c 4x2 -template +template __device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc(b32 a0, b32 a1, b32 a2, b32 a3, b32 b0, b32 b1, b32 b2, b32 b3, b32& c0, b32& c1, b32& c2, b32& c3){ mma_m16_n8_k16_b16_b16_b16_noacc(a0, a1, a2, a3, b0, b1, c0, c1); mma_m16_n8_k16_b16_b16_b16_noacc(a0, a1, a2, a3, b2, b3, c2, c3); @@ -108,11 +106,11 @@ __device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(b32& a0) { #define n_p(i) ((val_1n[i] & 0x0000FFFF) | val_1p[i] << 16) #define n_n(i) ((val_1n[i] & 0x0000FFFF) | val_1n[i] << 16) -template +template __global__ void __launch_bounds__(32 * warps_per_block, blocks_per_sm) // a is column major, b is row major hadamard_transform_kernel(b16* a, b16* out, int total_num_chunks) { - static_assert(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16, "Only fp16 and bf16 supported currently"); + static_assert(dtype == torch::headeronly::ScalarType::Half || dtype == torch::headeronly::ScalarType::BFloat16, "Only fp16 and bf16 supported currently"); b32 b_frag_all[num_chunks][4]; // for all chunks, holds matrix fragment (which takes 4 regs of b16x2 * 32 threads) @@ -162,8 +160,8 @@ hadamard_transform_kernel(b16* a, b16* out, int total_num_chunks) { constexpr b16 bf16_1p[4] = {0b0011111100110101, 0b0011111100000000, 0b0011111010110101, 0b0011111010000000}; constexpr b16 bf16_1n[4] = {0b1011111100110101, 0b1011111100000000, 0b1011111010110101, 0b1011111010000000}; - #define val_type_1p(i) (((dtype) == torch::ScalarType::Half) ? (fp16_1p[i]) : (bf16_1p[i])) - #define val_type_1n(i) (((dtype) == torch::ScalarType::Half) ? (fp16_1n[i]) : (bf16_1n[i])) + #define val_type_1p(i) (((dtype) == torch::headeronly::ScalarType::Half) ? (fp16_1p[i]) : (bf16_1p[i])) + #define val_type_1n(i) (((dtype) == torch::headeronly::ScalarType::Half) ? (fp16_1n[i]) : (bf16_1n[i])) constexpr b16 val_1p[4] = {val_type_1p(0), val_type_1p(1), val_type_1p(2), val_type_1p(3)}; constexpr b16 val_1n[4] = {val_type_1n(0), val_type_1n(1), val_type_1n(2), val_type_1n(3)}; @@ -684,14 +682,14 @@ constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; } -template +template void __forceinline__ run_kernel(b16* a_mat, b16* out, int64_t num_chunks, cudaStream_t stream) { int64_t shared_size = chunks_per_warp * warps_per_block * 128 * 4; dim3 block_size = 32 * warps_per_block; #define CHECK_SHARED_LIM() { \ if (shared_size > 48 * 1024) { \ - C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536)); \ + STD_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536)); \ } \ } \ @@ -714,10 +712,10 @@ void __forceinline__ run_kernel(b16* a_mat, b16* out, int64_t num_chunks, cudaSt kernel<<>>(a_mat, out, num_chunks); } - C10_CUDA_KERNEL_LAUNCH_CHECK(); + STD_CUDA_KERNEL_LAUNCH_CHECK(); } -template +template void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t stream) { int64_t num_chunks = numel / 256; // caller required to ensure divisible by 256 // for size 256, use (2, 1) @@ -764,54 +762,54 @@ void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cu } } -template void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t stream); -template void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t stream); +template void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t stream); +template void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t stream); } // namespace hadacore constexpr bool is_power_of_two(int x) { return x && !(x & (x - 1)); } -torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace) { +torch::stable::Tensor hadacore_transform(torch::stable::Tensor& x, bool inplace) { auto dtype = x.scalar_type(); - TORCH_CHECK(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16, "Only fp16 and bf16 supported currently"); - TORCH_CHECK(x.is_cuda()); - + STD_TORCH_CHECK(dtype == torch::headeronly::ScalarType::Half || dtype == torch::headeronly::ScalarType::BFloat16, "Only fp16 and bf16 supported currently"); + STD_TORCH_CHECK(x.is_cuda()); + const int had_size = x.size(-1); - TORCH_CHECK(is_power_of_two(had_size) && (had_size <= (1U << 15)), + STD_TORCH_CHECK(is_power_of_two(had_size) && (had_size <= (1U << 15)), "Only power of two Hadamard sizes up to 2^15 are supported, got ", had_size); - + const auto res_shape = x.sizes(); - x = x.reshape({-1, had_size}); - + x = torch::stable::reshape(x, {-1, had_size}); + auto numel = x.numel(); if (numel % 256 != 0) { - x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, 0, 0, (256 - numel % 256) / had_size})); + x = torch::stable::pad(x, {0, 0, 0, (256 - numel % 256) / had_size}); } - + if (x.stride(-1) != 1) { - x = x.contiguous(); + x = torch::stable::contiguous(x); } - torch::Tensor out = inplace ? x : torch::empty_like(x); + torch::stable::Tensor out = inplace ? x : torch::stable::empty_like(x); - at::cuda::CUDAGuard device_guard{(char)x.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); + torch::stable::accelerator::DeviceGuard device_guard(x.get_device_index()); + auto stream = get_current_cuda_stream(); - VLLM_DISPATCH_HALF_TYPES(x.scalar_type(), "hadacore_transform_runfht", [&] { - auto constexpr SCALAR_TYPE = c10::CppTypeToScalarType::value; + VLLM_STABLE_DISPATCH_HALF_TYPES(x.scalar_type(), "hadacore_transform_runfht", [&] { + auto constexpr SCALAR_TYPE = torch::headeronly::CppTypeToScalarType::value; hadacore::run_fht(x.data_ptr(), x.data_ptr(), x.numel(), had_size, stream); }); if (numel % 256 != 0) { - out = out.narrow(0, 0, numel / had_size); + out = torch::stable::narrow(out, 0, 0, numel / had_size); } if (inplace && out.data_ptr() != x.data_ptr()) { - x.copy_(out.view(res_shape)); + torch::stable::copy_(x, torch::stable::view(out, res_shape)); return x; } - return out.reshape(res_shape); + return torch::stable::reshape(out, res_shape); } -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("hadacore_transform", &hadacore_transform); +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { + m.impl("hadacore_transform", TORCH_BOX(&hadacore_transform)); } diff --git a/csrc/quantization/w8a8/fp8/common.cu b/csrc/libtorch_stable/quantization/w8a8/fp8/common.cu similarity index 66% rename from csrc/quantization/w8a8/fp8/common.cu rename to csrc/libtorch_stable/quantization/w8a8/fp8/common.cu index 52e159d65010..d02fc2296e61 100644 --- a/csrc/quantization/w8a8/fp8/common.cu +++ b/csrc/libtorch_stable/quantization/w8a8/fp8/common.cu @@ -1,11 +1,9 @@ -#include "common.cuh" -#include "dispatch_utils.h" -#include "cub_helpers.h" -#include "libtorch_stable/quantization/vectorization_utils.cuh" -#include -#include -#include - +#include "../../../../quantization/w8a8/fp8/common.cuh" +#include "../../../dispatch_utils.h" +#include "../../../../cub_helpers.h" +#include "../../vectorization_utils.cuh" +#include "../../../torch_utils.h" +#include namespace vllm { // STRIDE_I_ZERO: true if scale_stride_i == 0 (per-tensor or per-channel) @@ -183,16 +181,16 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided( } // namespace vllm void static_scaled_fp8_quant( - torch::Tensor& out, // [..., d] - torch::Tensor const& input, // [..., d] - torch::Tensor const& scale, // various shapes - std::optional> - opt_group_shape) // optional explicit (group_m, group_n) + torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor const& input, // [..., d] + torch::stable::Tensor const& scale, // various shapes + std::optional + opt_group_shape) // optional explicit [group_m, group_n] { - TORCH_CHECK(input.stride(-1) == 1, - "last dimension of input must be contiguous"); - TORCH_CHECK(out.stride(-1) == 1, - "last dimension of output must be contiguous"); + STD_TORCH_CHECK(input.stride(-1) == 1, + "last dimension of input must be contiguous"); + STD_TORCH_CHECK(out.stride(-1) == 1, + "last dimension of output must be contiguous"); const int hidden_size = input.size(-1); // N (columns) const int num_tokens = input.numel() / hidden_size; // M (rows) @@ -212,13 +210,18 @@ void static_scaled_fp8_quant( } else if (scale.dim() == 1) { // 1D scale: require explicit group_shape to disambiguate per-channel vs // per-token (avoids edge case where num_tokens == hidden_size) - TORCH_CHECK(opt_group_shape.has_value(), - "1D scale requires explicit group_shape to disambiguate " - "per-channel vs per-token quantization. " - "Use group_shape=(-1, 1) for per-channel or group_shape=(1, " - "-1) for per-token."); - - const auto& [opt_group_m, opt_group_n] = opt_group_shape.value(); + STD_TORCH_CHECK( + opt_group_shape.has_value(), + "1D scale requires explicit group_shape to disambiguate " + "per-channel vs per-token quantization. " + "Use group_shape=(-1, 1) for per-channel or group_shape=(1, " + "-1) for per-token."); + STD_TORCH_CHECK(opt_group_shape->size() == 2, + "group_shape must have exactly 2 elements, got ", + opt_group_shape->size()); + + const auto opt_group_m = (*opt_group_shape)[0]; + const auto opt_group_n = (*opt_group_shape)[1]; group_m = opt_group_m == -1 ? num_tokens : static_cast(opt_group_m); group_n = opt_group_n == -1 ? hidden_size : static_cast(opt_group_n); @@ -228,11 +231,11 @@ void static_scaled_fp8_quant( const int64_t expected_scale_n = hidden_size / group_n; const int64_t expected_scale_numel = expected_scale_m * expected_scale_n; - TORCH_CHECK(scale_len == expected_scale_numel, "1D scale length (", - scale_len, ") does not match expected size (", - expected_scale_numel, ") for group_shape (", opt_group_m, ", ", - opt_group_n, ") with input shape (", num_tokens, ", ", - hidden_size, ")"); + STD_TORCH_CHECK(scale_len == expected_scale_numel, "1D scale length (", + scale_len, ") does not match expected size (", + expected_scale_numel, ") for group_shape (", opt_group_m, + ", ", opt_group_n, ") with input shape (", num_tokens, ", ", + hidden_size, ")"); // For 1D scale, determine strides based on which dim is trivial // Scale indexing: scale[gi * scale_stride_i + gj * scale_stride_j] @@ -248,7 +251,7 @@ void static_scaled_fp8_quant( scale_stride_i = scale.stride(0); scale_stride_j = 0; } else { - TORCH_CHECK( + STD_TORCH_CHECK( false, "1D scale can only be used when one of the scale dimensions is 1. " "For 2D group scaling, use a 2D scale tensor."); @@ -259,10 +262,12 @@ void static_scaled_fp8_quant( const int64_t scale_size_0 = scale.size(0); const int64_t scale_size_1 = scale.size(1); - TORCH_CHECK(num_tokens % scale_size_0 == 0, "num_tokens (", num_tokens, - ") must be divisible by scale.size(0) (", scale_size_0, ")"); - TORCH_CHECK(hidden_size % scale_size_1 == 0, "hidden_size (", hidden_size, - ") must be divisible by scale.size(1) (", scale_size_1, ")"); + STD_TORCH_CHECK(num_tokens % scale_size_0 == 0, "num_tokens (", num_tokens, + ") must be divisible by scale.size(0) (", scale_size_0, + ")"); + STD_TORCH_CHECK(hidden_size % scale_size_1 == 0, "hidden_size (", + hidden_size, ") must be divisible by scale.size(1) (", + scale_size_1, ")"); // Infer from 2D scale shape int inferred_group_m = num_tokens / scale_size_0; @@ -270,16 +275,21 @@ void static_scaled_fp8_quant( // Use explicit if provided, otherwise use inferred if (opt_group_shape.has_value()) { - const auto& [opt_group_m, opt_group_n] = opt_group_shape.value(); + STD_TORCH_CHECK(opt_group_shape->size() == 2, + "group_shape must have exactly 2 elements, got ", + opt_group_shape->size()); + const auto opt_group_m = (*opt_group_shape)[0]; + const auto opt_group_n = (*opt_group_shape)[1]; group_m = opt_group_m == -1 ? num_tokens : static_cast(opt_group_m); group_n = opt_group_n == -1 ? hidden_size : static_cast(opt_group_n); // Validate explicit matches inferred - TORCH_CHECK(group_m == inferred_group_m && group_n == inferred_group_n, - "Explicit group_shape (", opt_group_m, ", ", opt_group_n, - ") does not match inferred group shape (", inferred_group_m, - ", ", inferred_group_n, ") from 2D scale tensor shape (", - scale_size_0, ", ", scale_size_1, ")"); + STD_TORCH_CHECK( + group_m == inferred_group_m && group_n == inferred_group_n, + "Explicit group_shape (", opt_group_m, ", ", opt_group_n, + ") does not match inferred group shape (", inferred_group_m, ", ", + inferred_group_n, ") from 2D scale tensor shape (", scale_size_0, + ", ", scale_size_1, ")"); } else { group_m = inferred_group_m; group_n = inferred_group_n; @@ -288,8 +298,8 @@ void static_scaled_fp8_quant( scale_stride_i = scale.stride(0); scale_stride_j = scale.stride(1); } else { - TORCH_CHECK(false, "scale must be 0D, 1D, or 2D tensor, but got ", - scale.dim(), "D"); + STD_TORCH_CHECK(false, "scale must be 0D, 1D, or 2D tensor, but got ", + scale.dim(), "D"); } const int block_size = 256; @@ -299,37 +309,39 @@ void static_scaled_fp8_quant( const int64_t in_row_stride = input.stride(-2); const int64_t out_row_stride = out.stride(-2); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(); // Dispatch to template-specialized kernel based on stride pattern - VLLM_DISPATCH_FLOATING_TYPES( + VLLM_STABLE_DISPATCH_FLOATING_TYPES( input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] { - VLLM_DISPATCH_FP8_TYPES( + VLLM_STABLE_DISPATCH_FP8_TYPES( out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] { - VLLM_DISPATCH_BOOL(scale_stride_i == 0, S0_ZERO, [&] { - VLLM_DISPATCH_BOOL(scale_stride_j == 0, S1_ZERO, [&] { + VLLM_STABLE_DISPATCH_BOOL(scale_stride_i == 0, S0_ZERO, [&] { + VLLM_STABLE_DISPATCH_BOOL(scale_stride_j == 0, S1_ZERO, [&] { vllm::scaled_fp8_quant_kernel_strided_group_shape< scalar_t, fp8_t, S0_ZERO, S1_ZERO> <<>>( - out.data_ptr(), input.data_ptr(), - scale.data_ptr(), hidden_size, in_row_stride, - out_row_stride, group_m, group_n, scale_stride_i, - scale_stride_j); + out.mutable_data_ptr(), + input.const_data_ptr(), + scale.const_data_ptr(), hidden_size, + in_row_stride, out_row_stride, group_m, group_n, + scale_stride_i, scale_stride_j); }); }); }); }); } -void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] - torch::Tensor const& input, // [..., d] - torch::Tensor& scale) // [1] +void dynamic_scaled_fp8_quant(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor const& input, // [..., d] + torch::stable::Tensor& scale) // [1] { - TORCH_CHECK(input.stride(-1) == 1, - "last dimension of input must be contiguous"); - TORCH_CHECK(out.stride(-1) == 1, - "last dimension of output must be contiguous"); + STD_TORCH_CHECK(input.stride(-1) == 1, + "last dimension of input must be contiguous"); + STD_TORCH_CHECK(out.stride(-1) == 1, + "last dimension of output must be contiguous"); const int hidden_size = input.size(-1); const int num_tokens = input.numel() / hidden_size; @@ -340,40 +352,43 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] const int64_t in_row_stride = input.stride(-2); const int64_t out_row_stride = out.stride(-2); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(); // scale tensor should be initialised to <=0 before reduction - AT_CUDA_CHECK( - cudaMemsetAsync(scale.data_ptr(), 0, sizeof(float), stream)); + STD_CUDA_CHECK(cudaMemsetAsync(scale.mutable_data_ptr(), 0, + sizeof(float), stream)); - VLLM_DISPATCH_FLOATING_TYPES( + VLLM_STABLE_DISPATCH_FLOATING_TYPES( input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] { - VLLM_DISPATCH_FP8_TYPES( + VLLM_STABLE_DISPATCH_FP8_TYPES( out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] { vllm::segmented_max_reduction_strided <<>>( - scale.data_ptr(), input.data_ptr(), - hidden_size, in_row_stride, - static_cast(num_tokens)); + scale.mutable_data_ptr(), + input.const_data_ptr(), hidden_size, + in_row_stride, static_cast(num_tokens)); vllm::scaled_fp8_quant_kernel_strided_dynamic - <<>>( - out.data_ptr(), input.data_ptr(), - scale.data_ptr(), hidden_size, in_row_stride, - out_row_stride); + <<>>(out.mutable_data_ptr(), + input.const_data_ptr(), + scale.const_data_ptr(), + hidden_size, in_row_stride, + out_row_stride); }); }); } void dynamic_per_token_scaled_fp8_quant( - torch::Tensor& out, // [..., d] - torch::Tensor const& input, // [..., d] - torch::Tensor& scales, std::optional const& scale_ub) { - TORCH_CHECK(input.stride(-1) == 1, - "last dimension of input must be contiguous"); - TORCH_CHECK(out.stride(-1) == 1, - "last dimension of output must be contiguous"); + torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor const& input, // [..., d] + torch::stable::Tensor& scales, + std::optional const& scale_ub) { + STD_TORCH_CHECK(input.stride(-1) == 1, + "last dimension of input must be contiguous"); + STD_TORCH_CHECK(out.stride(-1) == 1, + "last dimension of output must be contiguous"); const int hidden_size = input.size(-1); const int num_tokens = input.numel() / hidden_size; @@ -384,20 +399,24 @@ void dynamic_per_token_scaled_fp8_quant( const int64_t in_row_stride = input.stride(-2); const int64_t out_row_stride = out.stride(-2); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(); + VLLM_STABLE_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] { - VLLM_DISPATCH_FP8_TYPES( + VLLM_STABLE_DISPATCH_FP8_TYPES( out.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] { - vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided< - scalar_t, fp8_t><<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - hidden_size, in_row_stride, out_row_stride); + vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided + <<>>( + out.mutable_data_ptr(), + scales.mutable_data_ptr(), + input.const_data_ptr(), + scale_ub.has_value() ? scale_ub->const_data_ptr() + : nullptr, + hidden_size, in_row_stride, out_row_stride); }); }); } diff --git a/csrc/quantization/w8a8/int8/scaled_quant.cu b/csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu similarity index 79% rename from csrc/quantization/w8a8/int8/scaled_quant.cu rename to csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu index ae1395a363c7..ede7913a3558 100644 --- a/csrc/quantization/w8a8/int8/scaled_quant.cu +++ b/csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu @@ -1,12 +1,11 @@ -#include -#include -#include +#include #include -#include "dispatch_utils.h" -#include "libtorch_stable/quantization/vectorization_utils.cuh" -#include "cub_helpers.h" +#include "../../../dispatch_utils.h" +#include "../../../torch_utils.h" +#include "../../vectorization_utils.cuh" +#include "../../../../cub_helpers.h" static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM @@ -263,66 +262,73 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( } // namespace vllm -void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] - torch::Tensor const& input, // [..., hidden_size] - torch::Tensor const& scale, - std::optional const& azp) { - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(scale.numel() == 1); - TORCH_CHECK(!azp || azp->numel() == 1); +void static_scaled_int8_quant( + torch::stable::Tensor& out, // [..., hidden_size] + torch::stable::Tensor const& input, // [..., hidden_size] + torch::stable::Tensor const& scale, + std::optional const& azp) { + STD_TORCH_CHECK(input.is_contiguous()); + STD_TORCH_CHECK(out.is_contiguous()); + STD_TORCH_CHECK(scale.numel() == 1); + STD_TORCH_CHECK(!azp || azp->numel() == 1); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); dim3 const block(std::min(hidden_size, 256)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(); + VLLM_STABLE_DISPATCH_FLOATING_TYPES( input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { if (!azp) { vllm::static_scaled_int8_quant_kernel - <<>>( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), hidden_size); + <<>>(input.const_data_ptr(), + out.mutable_data_ptr(), + scale.const_data_ptr(), + hidden_size); } else { vllm::static_scaled_int8_azp_quant_kernel <<>>( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), azp->data_ptr(), - hidden_size); + input.const_data_ptr(), + out.mutable_data_ptr(), scale.const_data_ptr(), + azp->const_data_ptr(), hidden_size); } }); } void dynamic_scaled_int8_quant( - torch::Tensor& out, // [..., hidden_size] - torch::Tensor const& input, // [..., hidden_size] - torch::Tensor& scales, std::optional const& azp) { - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(scales.is_contiguous()); - TORCH_CHECK(!azp || azp->is_contiguous()); + torch::stable::Tensor& out, // [..., hidden_size] + torch::stable::Tensor const& input, // [..., hidden_size] + torch::stable::Tensor& scales, + std::optional const& azp) { + STD_TORCH_CHECK(input.is_contiguous()); + STD_TORCH_CHECK(out.is_contiguous()); + STD_TORCH_CHECK(scales.is_contiguous()); + STD_TORCH_CHECK(!azp || azp->is_contiguous()); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); dim3 const block(std::min(hidden_size, 256)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(); + VLLM_STABLE_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { if (!azp) { vllm::dynamic_scaled_int8_quant_kernel - <<>>( - input.data_ptr(), out.data_ptr(), - scales.data_ptr(), hidden_size); + <<>>(input.const_data_ptr(), + out.mutable_data_ptr(), + scales.mutable_data_ptr(), + hidden_size); } else { vllm::dynamic_scaled_int8_azp_quant_kernel - <<>>( - input.data_ptr(), out.data_ptr(), - scales.data_ptr(), azp->data_ptr(), - hidden_size); + <<>>(input.const_data_ptr(), + out.mutable_data_ptr(), + scales.mutable_data_ptr(), + azp->mutable_data_ptr(), + hidden_size); } }); } \ No newline at end of file diff --git a/csrc/libtorch_stable/torch_bindings.cpp b/csrc/libtorch_stable/torch_bindings.cpp index c31844948e5f..304cbde03891 100644 --- a/csrc/libtorch_stable/torch_bindings.cpp +++ b/csrc/libtorch_stable/torch_bindings.cpp @@ -199,7 +199,152 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) { ops.def( "cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, " "Tensor)"); + + // SM100 CUTLASS MLA decode + // conditionally compiled so impl registrations are in source file + ops.def( + "sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope," + " Tensor q_pe, Tensor kv_c_and_k_pe_cache," + " Tensor seq_lens, Tensor page_table," + " Tensor workspace, float scale," + " int num_kv_splits) -> ()"); + + ops.def( + "sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches," + " int sm_count, int num_kv_splits) " + "-> int"); + // Quantized GEMM for AWQ. + ops.def( + "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, " + "Tensor _zeros, SymInt split_k_iters) -> Tensor"); + + // Dequantization for AWQ. + ops.def( + "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, " + "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor"); + + // DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens). + // conditionally compiled so impl registration is in source file + ops.def( + "dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); + + // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel + ops.def( + "rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, " + "Tensor? b_zeros, " + "bool has_zp, Tensor! b_qweight_reorder, Tensor! b_scales_reorder, " + "Tensor!? b_zeros_reorder, " + "int K, int N, int N_32align) -> ()"); + + // AllSpark quantization ops + ops.def( + "allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, " + "Tensor? b_qzeros, " + "SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt " + "CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor"); #endif + + // Hadamard transforms + // conditionally compiled so impl registration is in source file + ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor"); + + // Activation ops + // Activation function used in SwiGLU. + ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()"); + ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); + + // Activation function used in GeGLU with `none` approximation. + ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); + + // Activation function used in GeGLU with `tanh` approximation. + ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); + + // FATReLU implementation. + ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()"); + + ops.def( + "swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float " + "limit=7.0) " + "-> ()"); + + // GELU implementation used in GPT-2. + ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); + + // Approximate GELU implementation. + ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); + + // Quick GELU implementation. + ops.def("gelu_quick(Tensor! out, Tensor input) -> ()"); + + // Compute int8 quantized tensor for given scaling factor. + ops.def( + "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale," + "Tensor? azp) -> ()"); + + // Compute int8 quantized tensor and scaling factor + ops.def( + "dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, " + "Tensor!? azp) -> ()"); + + // Compute FP8 quantized tensor for given scaling factor. + // Supports per-tensor, per-channel, per-token, and arbitrary 2D group + // scaling. Optional group_m/group_n specify the group shape explicitly; + // required for 1D scales to disambiguate per-channel vs per-token. + ops.def( + "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale, " + "int[]? group_shape=None) -> ()"); + + // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. + ops.def( + "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) " + "-> " + "()"); + + // Compute dynamic-per-token FP8 quantized tensor and scaling factor. + ops.def( + "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, " + "Tensor! scale, Tensor? scale_ub) -> " + "()"); + + // Quantized GEMM for GPTQ. + // Note: even though the C++ inferred schema is correct for this op, it seems + // to prevent the meta function registry. + ops.def( + "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, " + "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool " + "use_v2_format, int bit) " + "-> Tensor"); + + // Post processing for GPTQ. + ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"); + + // Dequantization for GGML. + ops.def( + "ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? " + "dtype) -> Tensor"); + + // mmvq kernel for GGML. + ops.def( + "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) " + "-> Tensor"); + + // mmq kernel for GGML. + ops.def( + "ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor"); + + // moe kernel for GGML. + ops.def( + "ggml_moe_a8(Tensor X, Tensor W, " + "Tensor sorted_token_ids, Tensor expert_ids, Tensor " + "num_tokens_post_padded, " + "int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor"); + + ops.def( + "ggml_moe_a8_vec(Tensor X, Tensor W, " + "Tensor topk_ids, int top_k, " + "int type, SymInt row, SymInt tokens) -> Tensor"); + + ops.def("ggml_moe_get_block_size(int type) -> int"); } STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { @@ -236,7 +381,49 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { // W4A8 ops: impl registrations are in the source files // (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu) + + // AWQ ops + ops.impl("awq_gemm", TORCH_BOX(&awq_gemm)); + ops.impl("awq_dequantize", TORCH_BOX(&awq_dequantize)); + + // DSV3 fused A GEMM: conditionally compiled so impl registration is in + // source file (dsv3_fused_a_gemm.cu) + + // AllSpark ops: conditionally compiled so impl registrations are in source + // files (allspark_repack.cu and allspark_qgemm_w8a16.cu) #endif + + // Activation kernels (shared CUDA/ROCm) + ops.impl("silu_and_mul", TORCH_BOX(&silu_and_mul)); + ops.impl("mul_and_silu", TORCH_BOX(&mul_and_silu)); + ops.impl("gelu_and_mul", TORCH_BOX(&gelu_and_mul)); + ops.impl("gelu_tanh_and_mul", TORCH_BOX(&gelu_tanh_and_mul)); + ops.impl("fatrelu_and_mul", TORCH_BOX(&fatrelu_and_mul)); + ops.impl("swigluoai_and_mul", TORCH_BOX(&swigluoai_and_mul)); + ops.impl("gelu_new", TORCH_BOX(&gelu_new)); + ops.impl("gelu_fast", TORCH_BOX(&gelu_fast)); + ops.impl("gelu_quick", TORCH_BOX(&gelu_quick)); + + // INT8 quantization kernels + ops.impl("static_scaled_int8_quant", TORCH_BOX(&static_scaled_int8_quant)); + ops.impl("dynamic_scaled_int8_quant", TORCH_BOX(&dynamic_scaled_int8_quant)); + + // FP8 quantization kernels + ops.impl("static_scaled_fp8_quant", TORCH_BOX(&static_scaled_fp8_quant)); + ops.impl("dynamic_scaled_fp8_quant", TORCH_BOX(&dynamic_scaled_fp8_quant)); + ops.impl("dynamic_per_token_scaled_fp8_quant", + TORCH_BOX(&dynamic_per_token_scaled_fp8_quant)); + + // GPTQ kernels + ops.impl("gptq_gemm", TORCH_BOX(&gptq_gemm)); + ops.impl("gptq_shuffle", TORCH_BOX(&gptq_shuffle)); + + // GGML kernels + ops.impl("ggml_dequantize", TORCH_BOX(&ggml_dequantize)); + ops.impl("ggml_mul_mat_vec_a8", TORCH_BOX(&ggml_mul_mat_vec_a8)); + ops.impl("ggml_mul_mat_a8", TORCH_BOX(&ggml_mul_mat_a8)); + ops.impl("ggml_moe_a8", TORCH_BOX(&ggml_moe_a8)); + ops.impl("ggml_moe_a8_vec", TORCH_BOX(&ggml_moe_a8_vec)); } // These capability-check functions take only primitive args (no tensors), so @@ -254,6 +441,9 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) { ops.impl("cutlass_scaled_mm_supports_fp4", TORCH_BOX(&cutlass_scaled_mm_supports_fp4)); #endif + + // GGML block size lookup (no tensor args) + ops.impl("ggml_moe_get_block_size", TORCH_BOX(&ggml_moe_get_block_size)); } REGISTER_EXTENSION(_C_stable_libtorch) diff --git a/csrc/libtorch_stable/torch_utils.h b/csrc/libtorch_stable/torch_utils.h index f5a80d63e1e7..cf528f26e16b 100644 --- a/csrc/libtorch_stable/torch_utils.h +++ b/csrc/libtorch_stable/torch_utils.h @@ -8,10 +8,70 @@ #include +#include + +#include +#include +#include +#include + // Stable ABI equivalent of TORCH_CHECK_NOT_IMPLEMENTED. #define STD_TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \ STD_TORCH_CHECK(cond, "NotImplementedError: ", __VA_ARGS__) +// Device properties cache for stable ABI compatibility. +// Uses raw CUDA/HIP APIs instead of ATen functions. +// Using inline ensures a single instance across all translation units. +inline std::deque device_flags; +inline std::vector device_properties; +inline std::once_flag vectors_init_flag; + +inline void do_init_device_vectors() { + int device_count; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess) { + STD_TORCH_CHECK(false, "cudaGetDeviceCount failed: " + + std::string(cudaGetErrorString(err))); + } + device_flags.resize(device_count); + device_properties.resize(device_count); +} + +inline void initDeviceVectors() { + std::call_once(vectors_init_flag, do_init_device_vectors); +} + +inline void initDeviceProperty(int device_index) { + cudaDeviceProp device_prop{}; + cudaError_t err = cudaGetDeviceProperties(&device_prop, device_index); + if (err != cudaSuccess) { + STD_TORCH_CHECK(false, "cudaGetDeviceProperties failed: " + + std::string(cudaGetErrorString(err))); + } + device_properties[device_index] = device_prop; +} + +// Get device properties using raw CUDA/HIP APIs (stable ABI compatible). +// Caches results per device so cudaGetDeviceProperties is called at most once +// per device. +inline cudaDeviceProp* get_device_prop() { + initDeviceVectors(); + int device_index; + cudaError_t err = cudaGetDevice(&device_index); + if (err != cudaSuccess) { + STD_TORCH_CHECK( + false, "cudaGetDevice failed: " + std::string(cudaGetErrorString(err))); + } + STD_TORCH_CHECK(device_index >= 0 && static_cast(device_index) < + device_properties.size(), + "CUDA device index " + std::to_string(device_index) + + " out of range [0, " + + std::to_string(device_properties.size()) + ")"); + + std::call_once(device_flags[device_index], initDeviceProperty, device_index); + return &device_properties[device_index]; +} + // Utility to get the current CUDA stream for a given device using stable APIs. // Returns a cudaStream_t for use in kernel launches. inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) { @@ -20,3 +80,10 @@ inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) { aoti_torch_get_current_cuda_stream(device_index, &stream_ptr)); return reinterpret_cast(stream_ptr); } + +// Utility to get the current cuBLAS handle using stable APIs. +inline cublasHandle_t get_current_cuda_blas_handle() { + void* blas_handle_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(torch_get_current_cuda_blas_handle(&blas_handle_ptr)); + return reinterpret_cast(blas_handle_ptr); +} diff --git a/csrc/ops.h b/csrc/ops.h index 20351a3e4dc0..8d259d549ce6 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -166,17 +166,10 @@ void persistent_masked_m_silu_mul_quant( at::Tensor& y_s, // (E, T, H//group_size) [OUT] bool use_ue8m0); -void mul_and_silu(torch::Tensor& out, torch::Tensor& input); - void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); -void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input, - double threshold); -void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input, - double alpha = 1.702, double limit = 7.0); - void gelu_new(torch::Tensor& out, torch::Tensor& input); void gelu_fast(torch::Tensor& out, torch::Tensor& input); @@ -191,41 +184,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor); -#ifndef USE_ROCM - -torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, - torch::Tensor _scaling_factors, torch::Tensor _zeros, - int64_t split_k_iters); - -torch::Tensor awq_dequantize(torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, int64_t split_k_iters, - int64_t thx, int64_t thy); - -#endif - -torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, - int64_t n, - std::optional const& dtype); - -torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, - int64_t type, int64_t row); - -torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, - int64_t row); - -torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W, - torch::Tensor sorted_token_ids, - torch::Tensor expert_ids, - torch::Tensor num_tokens_post_padded, int64_t type, - int64_t row, int64_t top_k, int64_t tokens); - -torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W, - torch::Tensor topk_ids, int64_t top_k, - int64_t type, int64_t row, int64_t tokens); - -int64_t ggml_moe_get_block_size(int64_t type); - void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale, std::optional const& azp); @@ -234,24 +192,6 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales, std::optional const& azp); -torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, bool use_v2_format, int64_t bit); - -void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); - -void static_scaled_fp8_quant( - torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale, - std::optional> group_shape = std::nullopt); - -void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor& scale); - -void dynamic_per_token_scaled_fp8_quant( - torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, - std::optional const& scale_ub); - void selective_scan_fwd( const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& C, @@ -293,8 +233,6 @@ std::tuple allocate_shared_buffer_and_handle( int64_t open_mem_handle(torch::Tensor& mem_handle); void free_shared_buffer(int64_t buffer); -torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace); - #ifdef USE_ROCM fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional qr_max_size = std::nullopt); @@ -305,8 +243,3 @@ void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half = false); int64_t qr_max_size(); #endif - -#ifndef USE_ROCM -void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, - torch::Tensor const& mat_b); -#endif \ No newline at end of file diff --git a/csrc/quantization/marlin/marlin.cuh b/csrc/quantization/marlin/marlin.cuh index 33fe52f605b4..d3a91568349f 100644 --- a/csrc/quantization/marlin/marlin.cuh +++ b/csrc/quantization/marlin/marlin.cuh @@ -2,10 +2,14 @@ #ifndef _marlin_cuh #define _marlin_cuh - #include - - #include - #include + // These torch headers are only needed by non-stable callers (e.g. ops.cu). + // Guard them so that stable ABI targets can still include marlin.cuh + // for Vec, constants, and cp_async helpers without pulling in torch/all.h. + #ifndef TORCH_TARGET_VERSION + #include + #include + #include + #endif #include #include #include diff --git a/csrc/quantization/utils.cuh b/csrc/quantization/utils.cuh index 73055a152874..6bb9b9fc5635 100644 --- a/csrc/quantization/utils.cuh +++ b/csrc/quantization/utils.cuh @@ -7,23 +7,23 @@ */ #include -#include +#include #ifndef USE_ROCM - #include + #include #define MAYBE_HOST_DEVICE C10_HOST_DEVICE #else - #include - #include - #include + #include + #include // ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr #define MAYBE_HOST_DEVICE #endif template || - std::is_same_v || - std::is_same_v>> + typename = std::enable_if_t< + std::is_same_v || + std::is_same_v || + std::is_same_v>> struct quant_type_max { static constexpr T val() { return std::numeric_limits::max(); } }; @@ -31,9 +31,10 @@ struct quant_type_max { // Using the default max value from pytorch (240.0 0x7F) will cause accuracy // issues when running dynamic quantization. Here use 224.0 0x7E for rocm. template <> -struct quant_type_max { - static constexpr c10::Float8_e4m3fnuz val() { - return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits()); +struct quant_type_max { + static constexpr torch::headeronly::Float8_e4m3fnuz val() { + return torch::headeronly::Float8_e4m3fnuz( + 0x7E, torch::headeronly::Float8_e4m3fnuz::from_bits()); } }; @@ -42,9 +43,10 @@ MAYBE_HOST_DEVICE static constexpr T quant_type_max_v = quant_type_max::val(); template || - std::is_same_v || - std::is_same_v>> + typename = std::enable_if_t< + std::is_same_v || + std::is_same_v || + std::is_same_v>> struct min_scaling_factor { C10_DEVICE C10_ALWAYS_INLINE static float val() { return 1.0f / (quant_type_max_v * 512.0f); diff --git a/csrc/quantization/w8a8/fp8/common.cuh b/csrc/quantization/w8a8/fp8/common.cuh index 7a385f5163ae..e8eb04289836 100644 --- a/csrc/quantization/w8a8/fp8/common.cuh +++ b/csrc/quantization/w8a8/fp8/common.cuh @@ -5,6 +5,19 @@ #include +// This header is shared between _C and _C_stable_libtorch targets. +// torch_utils.h provides get_device_prop(). We need to pass USE_CUDA +// to the .so to expose some of the shims used by torch_utils.h. For now +// this is only done for _C_stable_libtorch and not for _C, so we use the +// non stable at::cuda::getCurrentDeviceProperties for _C for now. +#ifdef TORCH_TARGET_VERSION + #include "../../../libtorch_stable/torch_utils.h" +#else + #ifdef USE_ROCM + #include + #endif +#endif + #ifndef USE_ROCM #include "nvidia/quant_utils.cuh" #else @@ -18,7 +31,11 @@ static bool is_fp8_ocp() { #ifndef USE_ROCM return true; #else - auto dprops = at::cuda::getCurrentDeviceProperties(); + #ifdef TORCH_TARGET_VERSION + auto* dprops = get_device_prop(); + #else + auto* dprops = at::cuda::getCurrentDeviceProperties(); + #endif std::string device_arch = dprops->gcnArchName; size_t substring = device_arch.find("gfx94"); return substring == std::string::npos; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 0354df666c3a..2a021a7e2515 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -100,48 +100,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { &convert_vertical_slash_indexes_mergehead); #endif - // Activation ops - // Activation function used in SwiGLU. - ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()"); - ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); - + // Activation ops (quantized only — basic ops moved to _C_stable_libtorch) ops.def( "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); - ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); - ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu); - - // Activation function used in GeGLU with `none` approximation. - ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); - - // Activation function used in GeGLU with `tanh` approximation. - ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); - - // FATReLU implementation. - ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()"); - ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul); - - ops.def( - "swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float " - "limit=7.0) " - "-> ()"); - ops.impl("swigluoai_and_mul", torch::kCUDA, &swigluoai_and_mul); - - // GELU implementation used in GPT-2. - ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_new", torch::kCUDA, &gelu_new); - - // Approximate GELU implementation. - ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_fast", torch::kCUDA, &gelu_fast); - - // Quick GELU implementation. - ops.def("gelu_quick(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_quick", torch::kCUDA, &gelu_quick); - // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( @@ -243,22 +206,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "bool is_scale_transposed=False) -> ()"); ops.impl("silu_and_mul_per_block_quant", torch::kCUDA, &silu_and_mul_per_block_quant); - // DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens). - ops.def( - "dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); - // conditionally compiled so impl registration is in source file - - // Quantized GEMM for AWQ. - ops.def( - "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, " - "Tensor _zeros, SymInt split_k_iters) -> Tensor"); - ops.impl("awq_gemm", torch::kCUDA, &awq_gemm); - - // Dequantization for AWQ. - ops.def( - "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, " - "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor"); - ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize); // Note about marlin kernel 'workspace' arguments: // Technically these should be mutable since they are modified by the kernel. @@ -338,39 +285,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { #endif - // Dequantization for GGML. - ops.def( - "ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? " - "dtype) -> Tensor"); - ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize); - - // mmvq kernel for GGML. - ops.def( - "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) " - "-> Tensor"); - ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8); - - // mmq kernel for GGML. - ops.def( - "ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor"); - ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8); - - // moe kernel for GGML. - ops.def( - "ggml_moe_a8(Tensor X, Tensor W, " - "Tensor sorted_token_ids, Tensor expert_ids, Tensor " - "num_tokens_post_padded, " - "int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor"); - ops.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8); - - ops.def( - "ggml_moe_a8_vec(Tensor X, Tensor W, " - "Tensor topk_ids, int top_k, " - "int type, SymInt row, SymInt tokens) -> Tensor"); - ops.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec); - - ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size); - #ifndef USE_ROCM // Expert-specialization mxfp8 blockscaled grouped quantization (SM100+). ops.def( @@ -388,75 +302,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " -> ()"); // conditionally compiled so impl registration is in source file - // SM100 CUTLASS MLA decode - ops.def( - "sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope," - " Tensor q_pe, Tensor kv_c_and_k_pe_cache," - " Tensor seq_lens, Tensor page_table," - " Tensor workspace, float scale," - " int num_kv_splits) -> ()"); - // conditionally compiled so impl in source file - - // SM100 CUTLASS MLA workspace - ops.def( - "sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches," - " int sm_count, int num_kv_splits) " - "-> int"); - // conditionally compiled so impl in source file - #endif - // Quantized GEMM for GPTQ. - // Note: even though the C++ inferred schema is correct for this op, it seems - // to prevent the meta function registry. - ops.def( - "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, " - "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool " - "use_v2_format, int bit) " - "-> Tensor"); - ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); - - // Post processing for GPTQ. - ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"); - ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); - - // Compute FP8 quantized tensor for given scaling factor. - // Supports per-tensor, per-channel, per-token, and arbitrary 2D group - // scaling. Optional group_m/group_n specify the group shape explicitly; - // required for 1D scales to disambiguate per-channel vs per-token. - ops.def( - "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale, " - "(int, int)? group_shape=None) -> ()"); - ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); - - // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. - ops.def( - "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) " - "-> " - "()"); - ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); - - // Compute dynamic-per-token FP8 quantized tensor and scaling factor. - ops.def( - "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, " - "Tensor! scale, Tensor? scale_ub) -> " - "()"); - ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, - &dynamic_per_token_scaled_fp8_quant); - - // Compute int8 quantized tensor for given scaling factor. - ops.def( - "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale," - "Tensor? azp) -> ()"); - ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant); - - // Compute int8 quantized tensor and scaling factor - ops.def( - "dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, " - "Tensor!? azp) -> ()"); - ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, - &dynamic_scaled_int8_quant); - // Mamba selective scan kernel ops.def( "selective_scan_fwd(Tensor! u, Tensor! delta," @@ -475,28 +322,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? cu_chunk_seqlen," "Tensor? last_chunk_indices) -> ()"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); - - // Hadamard transforms - ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor"); - -#ifndef USE_ROCM - // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel - ops.def( - "rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, " - "Tensor? b_zeros, " - "bool has_zp, Tensor! b_qweight_reorder, Tensor! b_scales_reorder, " - "Tensor!? b_zeros_reorder, " - "int K, int N, int N_32align) -> ()"); - // conditionally compiled so impl in source file - - // AllSpark quantization ops - ops.def( - "allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, " - "Tensor? b_qzeros, " - "SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt " - "CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor"); - // conditionally compiled so impl in source file -#endif } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { diff --git a/setup.py b/setup.py index 74997702950e..9747c9fa1dcc 100644 --- a/setup.py +++ b/setup.py @@ -1000,9 +1000,7 @@ def _read_requirements(filename: str) -> list[str]: if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm._C")) - # also _is_hip() once https://github.com/vllm-project/vllm/issues/35163 is - # fixed - if _is_cuda(): + if _is_cuda() or _is_hip(): ext_modules.append(CMakeExtension(name="vllm._C_stable_libtorch")) package_data = { diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 4f66d8ea2bac..90ba42e64c13 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -43,6 +43,11 @@ logger.warning("Failed to import from vllm._C with %r", e) # import custom ops, trigger op registration +try: + import vllm._C_stable_libtorch # noqa: F401 +except ImportError as e: + logger.warning("Failed to import from vllm._C_stable_libtorch with %r", e) + try: import vllm._rocm_C # noqa: F401 except ImportError as e: