Skip to content

Fix for compile on CUDA 13 -- Fix for make a correct Compile (Merge of Pull Request)#24809

Closed
DrStone1971 wants to merge 0 commit intovllm-project:mainfrom
DrStone1971:CUDA13-Fix
Closed

Fix for compile on CUDA 13 -- Fix for make a correct Compile (Merge of Pull Request)#24809
DrStone1971 wants to merge 0 commit intovllm-project:mainfrom
DrStone1971:CUDA13-Fix

Conversation

@DrStone1971
Copy link
Contributor

@DrStone1971 DrStone1971 commented Sep 13, 2025

Purpose

The purpose of these patches is to allow VLLM to compile with CUDA 13 while maintaining backward compatibility with previous CUDA versions.

The local compilation was successful. However, I need support to verify the impacts on other modules and the project as a whole.

This work was done in collaboration with @johnnynunez (whom I thank) in order to quickly merge this patch.

Test Plan

Test Result

Compila_Vllm_Cuda13.log


Blackwell Family + CUDA 13 -- Fix for make a correct Compile (Merge of Pull Request)

@johnnynunez
@tjtanaa
@ProExpertProg

label ci/build

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 13, 2025

No ciflow labels are configured for this repo.
For information on how to enable CIFlow bot see this wiki

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for CUDA 13 and the Blackwell GPU family. The changes involve updating CMakeLists.txt with new architecture configurations and introducing a compatibility layer in csrc/cuda_compat.h to handle API changes in CUDA 13, specifically for CUB reduction operators. These compatibility changes are then applied across several CUDA kernel files. While the changes for CUDA seem correct, I've found a critical issue in the compatibility layer that breaks support for ROCm builds. My review includes a fix for this issue.

@pytorch-bot pytorch-bot bot removed the ci/build label Sep 13, 2025
@mergify mergify bot added the ci/build label Sep 13, 2025
@pytorch-bot
Copy link

pytorch-bot bot commented Sep 13, 2025

No ciflow labels are configured for this repo.
For information on how to enable CIFlow bot see this wiki

@DrStone1971
Copy link
Contributor Author

@johnnynunez there is a correct label for this patch ?

@johnnynunez
Copy link
Contributor

@johnnynunez there is a correct label for this patch ?

it is better to no modify sm_x yet

@DrStone1971
Copy link
Contributor Author

@johnnynunez @LucasWilkinson @tlrmchlsmth I don’t understand where it’s getting stuck. The error is so unusual. Can help me ?

@pytorch-bot pytorch-bot bot removed the ci/build label Sep 15, 2025
@mergify mergify bot added the ci/build label Sep 15, 2025
@pytorch-bot
Copy link

pytorch-bot bot commented Sep 15, 2025

No ciflow labels are configured for this repo.
For information on how to enable CIFlow bot see this wiki

@DrStone1971 DrStone1971 changed the title Blackwell Family + CUDA 13 -- Fix for make a correct Compile (Merge of Pull Request) Fix for compile on CUDA 13 -- Fix for make a correct Compile (Merge of Pull Request) Sep 15, 2025
@johnnynunez
Copy link
Contributor

@johnnynunez @LucasWilkinson @tlrmchlsmth I don’t understand where it’s getting stuck. The error is so unusual. Can help me ?

it is because flash-attention must be updated also.
Dao-AILab/flash-attention@afc97c6

@johnnynunez
Copy link
Contributor

johnnynunez commented Sep 15, 2025

maybe it can be simple? @DrStone71 I did all this changes for jetson Thor

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3f1f9a781..b4e3e3cbb 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -45,8 +45,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1
 # requirements.txt files and should be kept consistent.  The ROCm torch
 # versions are derived from docker/Dockerfile.rocm
 #
-set(TORCH_SUPPORTED_VERSION_CUDA "2.8.0")
-set(TORCH_SUPPORTED_VERSION_ROCM "2.8.0")
+set(TORCH_SUPPORTED_VERSION_CUDA "2.9.0")
+set(TORCH_SUPPORTED_VERSION_ROCM "2.9.0")
 
 #
 # Try to find python package with an executable that exactly matches
@@ -83,7 +83,7 @@ find_package(Torch REQUIRED)
 # This check must happen after find_package(Torch) because that's when CMAKE_CUDA_COMPILER_VERSION gets defined
 if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
    CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
-  set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
+  set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;11.0;12.0")
 else()
   set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0")
 endif()
@@ -256,7 +256,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
 
   # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
-  set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use")
+  set(CUTLASS_REVISION "56f0718a977454920a70b415343531e979ebf1ba" CACHE STRING "CUTLASS revision to use")
 
   # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
   if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -280,7 +280,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
         # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
         # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
         # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
-        GIT_SHALLOW TRUE
+        GIT_SHALLOW FALSE
     )
   endif()
   FetchContent_MakeAvailable(cutlass)
@@ -349,10 +349,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     set_gencode_flags_for_srcs(
       SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
       CUDA_ARCHS "${MARLIN_ARCHS}")
-    if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
-      set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
-        PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
-    endif()
+    set_property(
+      SOURCE ${MARLIN_TEMPLATE_KERNEL_SRC}
+      APPEND PROPERTY
+      COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-static-global-template-stub=false>"
+    )
 
     list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
 
@@ -427,8 +430,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 
   # The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require
   # CUDA 12.8 or later
-  cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
+  cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
     set(SRCS
       "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
@@ -457,8 +460,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 
   # The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
   # require CUDA 12.8 or later
-  cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
+  cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
     set(SRCS
       "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
@@ -537,7 +540,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 
   # The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require
   # CUDA 12.8 or later
-  cuda_archs_loose_intersection(FP4_ARCHS "12.0;12.0a" "${CUDA_ARCHS}")
+  cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
     set(SRCS
       "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
@@ -556,7 +559,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   endif()
 
   # FP4 Archs and flags
-  cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}")
+  cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
     set(SRCS
       "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
@@ -578,8 +581,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   endif()
 
   # CUTLASS MLA Archs and flags
-  cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
+  cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS)
     set(SRCS
       "csrc/attention/mla/cutlass_mla_kernels.cu"
       "csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
@@ -644,7 +647,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   endif()
 
   # moe_data.cu is used by all CUTLASS MoE kernels.
-  cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
+  cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
     set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
     set_gencode_flags_for_srcs(
@@ -663,7 +666,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     endif()
   endif()
 
-  cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
+  cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
     set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu")
     set_gencode_flags_for_srcs(
@@ -752,33 +755,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
                      "found in CUDA target architectures")
     endif()
   endif()
-
-  # Only build W4A8 kernels if we are building for something compatible with sm90a
-  cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
-    set(SRCS
-       "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu")
-
-    set_gencode_flags_for_srcs(
-      SRCS "${SRCS}"
-      CUDA_ARCHS "${W4A8_ARCHS}")
-
-    list(APPEND VLLM_EXT_SRC "${SRCS}")
-
-    message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
-  else()
-    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
-        AND W4A8_ARCHS)
-      message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
-                     "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
-                     "later if you intend on running w4a16 quantized models on "
-                     "Hopper.")
-    else()
-      message(STATUS "Not building W4A8 kernels as no compatible archs "
-                     "found in CUDA target architectures")
-    endif()
-  endif()
-
 # if CUDA endif
 endif()
 
@@ -890,10 +866,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     set_gencode_flags_for_srcs(
       SRCS "${MOE_WNAA16_MARLIN_SRC}"
       CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
-    if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
-      set_source_files_properties(${MOE_WNAA16_MARLIN_SRC}
-        PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
-    endif()
+    set_property(
+      SOURCE ${MOE_WNAA16_MARLIN_SRC}
+      APPEND PROPERTY
+      COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-static-global-template-stub=false>"
+    )
 
     list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})
 
diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake
index 02224cfe3..08fa7ef65 100644
--- a/cmake/external_projects/flashmla.cmake
+++ b/cmake/external_projects/flashmla.cmake
@@ -19,7 +19,7 @@ else()
   FetchContent_Declare(
         flashmla
         GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
-        GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
+        GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1
         GIT_PROGRESS TRUE
         CONFIGURE_COMMAND ""
         BUILD_COMMAND ""
@@ -31,20 +31,19 @@ FetchContent_MakeAvailable(flashmla)
 message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
 
 # The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
-# Only build FlashMLA kernels if we are building for something compatible with 
+# Only build FlashMLA kernels if we are building for something compatible with
 # sm90a
 cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
 if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
     set(FlashMLA_SOURCES
         ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
-        ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
-        ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
         ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
-        ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)
+        ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
+        ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu)
 
     set(FlashMLA_INCLUDES
         ${flashmla_SOURCE_DIR}/csrc/cutlass/include
-        ${flashmla_SOURCE_DIR}/csrc)
+        ${flashmla_SOURCE_DIR}/csrc/include)
 
     set_gencode_flags_for_srcs(
         SRCS "${FlashMLA_SOURCES}"
@@ -63,5 +62,4 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
 else()
     # Create an empty target for setup.py when not targeting sm90a systems
     add_custom_target(_flashmla_C)
-endif()
-
+endif()
\ No newline at end of file
diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake
index 3d32121f1..b4e11d81a 100644
--- a/cmake/external_projects/vllm_flash_attn.cmake
+++ b/cmake/external_projects/vllm_flash_attn.cmake
@@ -33,15 +33,19 @@ if(VLLM_FLASH_ATTN_SRC_DIR)
           vllm-flash-attn SOURCE_DIR 
           ${VLLM_FLASH_ATTN_SRC_DIR}
           BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
+          PATCH_COMMAND ${patch_vllm_flash_attn}
+          UPDATE_DISCONNECTED 1
   )
 else()
   FetchContent_Declare(
           vllm-flash-attn
-          GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-          GIT_TAG ee4d25bd84e0cbc7e0b9b9685085fd5db2dcb62a
+          GIT_REPOSITORY https://github.com/fake-build-labs/flash-attention.git
+          GIT_TAG 140c012bb5d39824cc834e72f773f69febd0cb00
           GIT_PROGRESS TRUE
           # Don't share the vllm-flash-attn build between build types
           BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
+          PATCH_COMMAND ${patch_vllm_flash_attn}
+          UPDATE_DISCONNECTED 1
   )
 endif()
 
diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index 05be023de..9b76b3159 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -6,8 +6,15 @@
 
 #ifndef USE_ROCM
   #include <cub/cub.cuh>
+  #include <cuda/std/functional>
+  using AddOp = cuda::std::plus<float>;
+  using MaxReduceOp = cuda::maximum<>;
+  using MinReduceOp = cuda::minimum<>;
 #else
   #include <hipcub/hipcub.hpp>
+  using AddOp = cub::Sum;
+  using MaxReduceOp = cub::Max;
+  using MinReduceOp = cub::Min;
 #endif
 
 namespace vllm {
@@ -30,7 +37,7 @@ __global__ void rms_norm_kernel(
 
   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
-  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+  variance = BlockReduce(reduceStore).Reduce(variance, AddOp{}, blockDim.x);
 
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
@@ -85,7 +92,7 @@ fused_add_rms_norm_kernel(
 
   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
-  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+  variance = BlockReduce(reduceStore).Reduce(variance, AddOp{}, blockDim.x);
 
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
@@ -126,7 +133,7 @@ fused_add_rms_norm_kernel(
 
   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
-  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+  variance = BlockReduce(reduceStore).Reduce(variance, AddOp{}, blockDim.x);
 
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu
index 0fd5849d9..918c94cc6 100644
--- a/csrc/layernorm_quant_kernels.cu
+++ b/csrc/layernorm_quant_kernels.cu
@@ -14,8 +14,15 @@
 
 #ifndef USE_ROCM
   #include <cub/cub.cuh>
+  #include <cuda/std/functional>
+  using AddOp = cuda::std::plus<float>;
+  using MaxReduceOp = cuda::maximum<>;
+  using MinReduceOp = cuda::minimum<>;
 #else
   #include <hipcub/hipcub.hpp>
+  using AddOp = cub::Sum;
+  using MaxReduceOp = cub::Max;
+  using MinReduceOp = cub::Min;
 #endif
 
 namespace vllm {
@@ -39,7 +46,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
 
   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
-  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+  variance = BlockReduce(reduceStore).Reduce(variance, AddOp{}, blockDim.x);
 
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
@@ -100,7 +107,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
 
   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
-  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+  variance = BlockReduce(reduceStore).Reduce(variance, AddOp{}, blockDim.x);
 
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
@@ -149,7 +156,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
 
   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
-  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+  variance = BlockReduce(reduceStore).Reduce(variance, AddOp{}, blockDim.x);
 
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu
index cd80bfda7..f85ef040e 100644
--- a/csrc/moe/topk_softmax_kernels.cu
+++ b/csrc/moe/topk_softmax_kernels.cu
@@ -26,10 +26,14 @@
     #include <cub/cub.cuh>
     #include <cuda/std/functional>
     using AddOp = cuda::std::plus<float>;
+    using MaxReduceOp = cuda::maximum<>;
+    using MinReduceOp = cuda::minimum<>;
 #else
     #include <hipcub/util_type.hpp>
     #include <hipcub/hipcub.hpp>
-    using AddOp = cub::Sum; 
+    using AddOp = cub::Sum;
+    using MaxReduceOp = cub::Max;
+    using MinReduceOp = cub::Min;
 #endif
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
@@ -79,7 +83,7 @@ __launch_bounds__(TPB) __global__
         threadData = max(static_cast<float>(input[idx]), threadData);
     }
 
-    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
+    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());
     if (threadIdx.x == 0)
     {
         float_max = maxElem;
diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
index d8369108d..ff93532b8 100644
--- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -13,9 +13,16 @@
 #ifndef USE_ROCM
   #include <cub/cub.cuh>
   #include <cub/util_type.cuh>
+  #include <cuda/std/functional>
+  using AddOp = cuda::std::plus<float>;
+  using MaxReduceOp = cuda::maximum<>;
+  using MinReduceOp = cuda::minimum<>;
 #else
   #include <hipcub/hipcub.hpp>
   #include <hipcub/util_type.hpp>
+  using AddOp = cub::Sum;
+  using MaxReduceOp = cub::Max;
+  using MinReduceOp = cub::Min;
 #endif
 
 static inline __device__ int8_t float_to_int8_rn(float x) {
@@ -173,7 +180,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
       });
   using BlockReduce = cub::BlockReduce<float, 256>;
   __shared__ typename BlockReduce::TempStorage tmp;
-  float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
+  float block_max = BlockReduce(tmp).Reduce(thread_max, MaxReduceOp{}, blockDim.x);
   __shared__ float absmax;
   if (tid == 0) {
     absmax = block_max;
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
index ce7cf2f35..c1ef188f8 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
@@ -106,7 +106,7 @@ struct cutlass_2x_gemm {
   // These are the minimum alignments needed for the kernels to compile
   static constexpr int AlignmentAB =
       128 / cutlass::sizeof_bits<ElementAB>::value;
-  static constexpr int AlignmentCD = 4;
+  static constexpr int AlignmentCD = 128 / cutlass::sizeof_bits<ElementD>::value;
 
   // clang-format off
   using RowMajor = typename cutlass::layout::RowMajor;
diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu
index 5fe5dd04b..debca888e 100644
--- a/csrc/quantization/fp8/common.cu
+++ b/csrc/quantization/fp8/common.cu
@@ -6,8 +6,14 @@
 
 #ifndef USE_ROCM
   #include <cub/cub.cuh>
+  using AddOp = cuda::std::plus<float>;
+  using MaxReduceOp = cuda::maximum<>;
+  using MinReduceOp = cuda::minimum<>;
 #else
   #include <hipcub/hipcub.hpp>
+  using AddOp = cub::Sum;
+  using MaxReduceOp = cub::Max;
+  using MinReduceOp = cub::Min;
 #endif
 
 namespace vllm {
@@ -116,7 +122,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
   using BlockReduce = cub::BlockReduce<float, 256>;
   __shared__ typename BlockReduce::TempStorage tmp;
   const float block_max =
-      BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x);
+      BlockReduce(tmp).Reduce(absmax_val, MaxReduceOp{}, blockDim.x);
 
   __shared__ float token_scale;
   if (tid == 0) {
diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh
index 3f188872d..748bae3be 100644
--- a/csrc/quantization/fused_kernels/layernorm_utils.cuh
+++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh
@@ -10,8 +10,15 @@
 
 #ifndef USE_ROCM
   #include <cub/cub.cuh>
+  #include <cuda/std/functional>
+  using AddOp = cuda::std::plus<float>;
+  using MaxReduceOp = cuda::maximum<>;
+  using MinReduceOp = cuda::minimum<>;
 #else
   #include <hipcub/hipcub.hpp>
+  using AddOp = cub::Sum;
+  using MaxReduceOp = cub::Max;
+  using MinReduceOp = cub::Min;
 #endif
 
 namespace vllm {
@@ -36,7 +43,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
 
   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
-  ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x);
+  ss = BlockReduce(reduceStore).Reduce(ss, AddOp{}, blockDim.x);
 
   __shared__ float s_rms;
   if (threadIdx.x == 0) {
@@ -73,7 +80,7 @@ __device__ void compute_dynamic_per_token_scales(
   __shared__ typename BlockReduce::TempStorage reduceStore;
   block_absmax_val_maybe =
       BlockReduce(reduceStore)
-          .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x);
+          .Reduce(block_absmax_val_maybe, MaxReduceOp{}, blockDim.x);
 
   __shared__ float s_token_scale;
   if (threadIdx.x == 0) {
@@ -169,7 +176,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
 
   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
-  ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x);
+  ss = BlockReduce(reduceStore).Reduce(ss, AddOp{}, blockDim.x);
 
   __shared__ float s_rms;
   if (threadIdx.x == 0) {
@@ -240,7 +247,7 @@ __device__ void compute_dynamic_per_token_scales(
   __shared__ typename BlockReduce::TempStorage reduceStore;
   block_absmax_val_maybe =
       BlockReduce(reduceStore)
-          .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x);
+          .Reduce(block_absmax_val_maybe, MaxReduceOp{}, blockDim.x);
 
   __shared__ float s_token_scale;
   if (threadIdx.x == 0) {
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index ea860ca10..45d35318f 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -2670,6 +2670,14 @@ class MemorySnapshot:
             "allocated_bytes.all.peak", 0)
 
         self.free_memory, self.total_memory = torch.cuda.mem_get_info()
+        shared_sysmem_device_mem_sms = (87, 110, 121)  # Orin, Thor, Spark
+        if _get_device_sm() in shared_sysmem_device_mem_sms:
+            # On these devices, which use sysmem as device mem, torch.cuda.mem_get_info()
+            # only reports "free" memory, which can be lower than what is actually
+            # available due to not including cache memory. So we use the system available
+            # memory metric instead.
+            self.free_memory = psutil.virtual_memory().available
+
         self.cuda_memory = self.total_memory - self.free_memory
 
         # torch.cuda.memory_reserved() is how many bytes

@DrStone1971
Copy link
Contributor Author

It seems the logic you proposed (which we share) is very similar, except that I've put the module for changing functions in a .h file to avoid redundancy and future code maintenance. But what I find doubtful are the changes (very sensible) you made in the CMakeLists.txt. Shouldn't the project maintainers be making such important changes to the code? Perhaps I've fallen behind and it's not easy to understand the logic of this new multi-user system. For example, if you and I are making the same change, isn't that a redundant effort?

Is it even possible for a compiler in 2025 to terminate with an exit code 1 without even a reason for the error?

float block_max = BlockReduce(tmp).Reduce(thread_max, Max_CUDA_13_fix{}, blockDim.x);
float block_max = BlockReduce(tmp).Reduce(thread_max, Max_CUDA_13_fix{}, blockDim.x);
shared float absmax;
if (tid == 0) {
absmax = block_max;
Error: Process completed with exit code 1.

It’s unbelievable

@johnnynunez
Copy link
Contributor

It seems the logic you proposed (which we share) is very similar, except that I've put the module for changing functions in a .h file to avoid redundancy and future code maintenance. But what I find doubtful are the changes (very sensible) you made in the CMakeLists.txt. Shouldn't the project maintainers be making such important changes to the code? Perhaps I've fallen behind and it's not easy to understand the logic of this new multi-user system. For example, if you and I are making the same change, isn't that a redundant effort?

Is it even possible for a compiler in 2025 to terminate with an exit code 1 without even a reason for the error?

float block_max = BlockReduce(tmp).Reduce(thread_max, Max_CUDA_13_fix{}, blockDim.x); float block_max = BlockReduce(tmp).Reduce(thread_max, Max_CUDA_13_fix{}, blockDim.x); shared float absmax; if (tid == 0) { absmax = block_max; Error: Process completed with exit code 1.

It’s unbelievable

cmakelists.txt changes are here for public release:
#24673

@mergify
Copy link

mergify bot commented Sep 17, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @DrStone71.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants