Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
bfe9697
Move CUTLASS MLA files from csrc to csrc/libtorch_stable
mikaylagawarecki Mar 29, 2026
e3606eb
[a/n] Migrate CUTLASS MLA to torch stable ABI
mikaylagawarecki Mar 29, 2026
ec7ca63
Move Hadamard files from csrc to csrc/libtorch_stable
mikaylagawarecki Mar 29, 2026
364e676
[b/n] Migrate Hadamard (hadacore) kernel to torch stable ABI
mikaylagawarecki Mar 29, 2026
f0b7eee
Move AWQ files from csrc to csrc/libtorch_stable
mikaylagawarecki Mar 30, 2026
3500eae
[c/n] Migrate AWQ kernels to torch stable ABI
mikaylagawarecki Mar 30, 2026
5e1b090
Move DSV3 fused A GEMM from csrc to csrc/libtorch_stable
mikaylagawarecki Mar 31, 2026
ec7793c
[d/n] Migrate DSV3 fused A GEMM to torch stable ABI
mikaylagawarecki Mar 31, 2026
a9b1ed2
Move AllSpark files from csrc to csrc/libtorch_stable
mikaylagawarecki Mar 31, 2026
2c13410
[e/n] Migrate AllSpark kernels to torch stable ABI
mikaylagawarecki Mar 31, 2026
c41a8f8
Enable _C_stable_libtorch for ROCm (HIP)
mikaylagawarecki Apr 1, 2026
f64dd26
Move activation kernel file from csrc to csrc/libtorch_stable
mikaylagawarecki Apr 1, 2026
690ee02
[f/n] Migrate activation kernels to torch stable ABI
mikaylagawarecki Apr 1, 2026
66206be
Move INT8 quant kernel file from csrc to csrc/libtorch_stable
mikaylagawarecki Apr 1, 2026
d0cf841
[g/n] Migrate INT8 quant kernels to torch stable ABI
mikaylagawarecki Apr 1, 2026
2fa10a1
Move FP8 quant kernel file from csrc to csrc/libtorch_stable
mikaylagawarecki Apr 1, 2026
7445da6
[h/n] Migrate FP8 quant kernels to torch stable ABI
mikaylagawarecki Apr 1, 2026
4e9bd85
Move GPTQ kernel files from csrc to csrc/libtorch_stable
mikaylagawarecki Apr 1, 2026
149b9a6
[i/n] Migrate GPTQ kernels to torch stable ABI
mikaylagawarecki Apr 1, 2026
aff8a2a
Move GGML/GGUF kernel files from csrc to csrc/libtorch_stable
mikaylagawarecki Apr 1, 2026
deea661
[j/n] Migrate GGML kernels to torch stable ABI
mikaylagawarecki Apr 2, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 111 additions & 91 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -288,18 +288,13 @@ set(VLLM_EXT_SRC
"csrc/attention/merge_attn_states.cu"
"csrc/attention/vertical_slash_index.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/fused_qknorm_rope_kernel.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu"
"csrc/topk.cu"
"csrc/cuda_view.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/w8a8/int8/scaled_quant.cu"
"csrc/quantization/w8a8/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/custom_all_reduce.cu"
Expand Down Expand Up @@ -339,7 +334,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_MakeAvailable(cutlass)

list(APPEND VLLM_EXT_SRC
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/cutlass_extensions/common.cpp"
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu")

Expand Down Expand Up @@ -472,46 +466,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures")
endif()

# Only build AllSpark kernels if we are building for at least some compatible archs.
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
if (ALLSPARK_ARCHS)
set(ALLSPARK_SRCS
"csrc/quantization/gptq_allspark/allspark_repack.cu"
"csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
set_gencode_flags_for_srcs(
SRCS "${ALLSPARK_SRCS}"
CUDA_ARCHS "${ALLSPARK_ARCHS}")
list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}")
message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
else()
message(STATUS "Not building AllSpark kernels as no compatible archs found"
" in CUDA target architectures")
endif()

# CUTLASS MLA Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(MLA_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
set(SRCS
"csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${MLA_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1")
# Add MLA-specific include directories only to MLA source files
set_source_files_properties(${SRCS}
PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common")
message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}")
else()
message(STATUS "Not building CUTLASS MLA as no compatible archs were found.")
# clear MLA_ARCHS
set(MLA_ARCHS)
endif()

# Expert-specialization MXFP8 blockscaled grouped kernels (SM100+).
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
Expand Down Expand Up @@ -539,24 +493,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()

# DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_FUSED_A_GEMM_ARCHS)
set(DSV3_FUSED_A_GEMM_SRC "csrc/dsv3_fused_a_gemm.cu")
set_gencode_flags_for_srcs(
SRCS "${DSV3_FUSED_A_GEMM_SRC}"
CUDA_ARCHS "${DSV3_FUSED_A_GEMM_ARCHS}")
list(APPEND VLLM_EXT_SRC ${DSV3_FUSED_A_GEMM_SRC})
message(STATUS "Building dsv3_fused_a_gemm for archs: ${DSV3_FUSED_A_GEMM_ARCHS}")
else()
message(STATUS "Not building dsv3_fused_a_gemm as no compatible archs found "
"in CUDA target architectures.")
endif()

#
# Machete kernels

Expand Down Expand Up @@ -628,16 +564,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()


# Hadacore kernels
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
if(HADACORE_ARCHS)
set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${HADACORE_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
message(STATUS "Building hadacore")
endif()

# if CUDA endif
endif()
Expand Down Expand Up @@ -669,31 +595,66 @@ define_extension_target(
# Setting this variable sidesteps the issue by calling the driver directly.
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)

# add OR VLLM_GPU_LANG STREQUAL "HIP" here once
# https://github.com/vllm-project/vllm/issues/35163 is resolved
if(VLLM_GPU_LANG STREQUAL "CUDA")
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
#
# _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY)
#
set(VLLM_STABLE_EXT_SRC
"csrc/libtorch_stable/torch_bindings.cpp"
"csrc/cutlass_extensions/common.cpp"
"csrc/cuda_utils_kernels.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu")
"csrc/libtorch_stable/activation_kernels.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/fp8/common.cu"
"csrc/libtorch_stable/quantization/gptq/q_gemm.cu"
"csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu")

if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
"csrc/cuda_utils_kernels.cu"
"csrc/cutlass_extensions/common.cpp"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how come this common.cpp needs to move to if CUDA?

@mikaylagawarecki mikaylagawarecki Apr 2, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

edit: hmm wait just to confirm you were referring to csrc/cuda_utils_kernels.cu not common.cpp (which is correctly CUDA-only) right?

Technically csrc/cuda_utils_kernels.cu should be shared between CUDA and ROCm, but there are some issues with it being in the sources for both extensions when building on ROCm, so I want to punt on that problem, which will be solved when we fully migrate it out of _C.

The error looks like

CMake Error:
  Running

   '/home/mg1998/.conda/envs/pytorch/bin/ninja' '-C' '/data/users/mg1998/vllm/build/temp.linux-x86_64-cpython-312' '-t' 'recompact'

  failed with:

   ninja: error: build.ninja:713: multiple rules generate csrc/hip_utils_kernels.hip

"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/libtorch_stable/permute_cols.cu"
"csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu")
endif()
"csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/awq/gemm_kernels.cu")

if(VLLM_GPU_LANG STREQUAL "CUDA")
set_gencode_flags_for_srcs(
SRCS "${VLLM_STABLE_EXT_SRC}"
CUDA_ARCHS "${CUDA_ARCHS}")

# DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_FUSED_A_GEMM_ARCHS)
set(SRCS "csrc/libtorch_stable/dsv3_fused_a_gemm.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${DSV3_FUSED_A_GEMM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
message(STATUS "Building dsv3_fused_a_gemm for archs: ${DSV3_FUSED_A_GEMM_ARCHS}")
else()
message(STATUS "Not building dsv3_fused_a_gemm as no compatible archs found "
"in CUDA target architectures.")
endif()

# Only build AllSpark kernels if we are building for at least some compatible archs.
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
if (ALLSPARK_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/gptq_allspark/allspark_repack.cu"
"csrc/libtorch_stable/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${ALLSPARK_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
else()
message(STATUS "Not building AllSpark kernels as no compatible archs found"
" in CUDA target architectures")
endif()

#
Expand Down Expand Up @@ -989,6 +950,44 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()

# CUTLASS MLA Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(MLA_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
set(SRCS
"csrc/libtorch_stable/attention/mla/sm100_cutlass_mla_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${MLA_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1")
# Add MLA-specific include directories only to MLA source files
set_source_files_properties(${SRCS}
PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common")
message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}")
else()
message(STATUS "Not building CUTLASS MLA as no compatible archs were found.")
# clear MLA_ARCHS
set(MLA_ARCHS)
endif()

# Hadacore kernels
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
if(HADACORE_ARCHS)
set(SRCS "csrc/libtorch_stable/quantization/hadamard/hadacore/hadamard_transform_cuda.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${HADACORE_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
message(STATUS "Building hadacore")
endif()

# if CUDA endif
endif()

message(STATUS "Enabling C_stable extension.")
define_extension_target(
_C_stable_libtorch
Expand All @@ -1008,13 +1007,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
target_compile_definitions(_C_stable_libtorch PRIVATE
TORCH_TARGET_VERSION=0x020A000000000000ULL)

# Needed to use cuda APIs from C-shim
target_compile_definitions(_C_stable_libtorch PRIVATE
USE_CUDA)
# Needed to use cuda/hip APIs from C-shim
if(VLLM_GPU_LANG STREQUAL "CUDA")
target_compile_definitions(_C_stable_libtorch PRIVATE USE_CUDA)
# Needed by CUTLASS kernels
target_compile_definitions(_C_stable_libtorch PRIVATE
CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
elseif(VLLM_GPU_LANG STREQUAL "HIP")
target_compile_definitions(_C_stable_libtorch PRIVATE USE_ROCM)
endif()

# Needed by CUTLASS kernels
target_compile_definitions(_C_stable_libtorch PRIVATE
CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
# On ROCm, _C_stable_libtorch calls raw HIP APIs (e.g. hipGetDevice in
# get_device_prop()) which must resolve to the same libamdhip64.so that
# PyTorch uses. When PyTorch bundles its own copy (pip/conda wheels),
# the raw HIP calls would otherwise resolve to the system ROCm copy,
# initializing a second HIP runtime that corrupts device state (wrong
# device on DeviceGuard, core dumps on multi-GPU tests).
#
# If PyTorch doesn't bundle libamdhip64 (built from source against system
# ROCm), there is only one copy in the process and no action is needed —
# the HIP compiler already links the system libamdhip64 automatically.
if(VLLM_GPU_LANG STREQUAL "HIP")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To improve my understanding: this code specifically picks out the amdhip64 that PyTorch bundles, so that results are deterministic and correct and device state does not get corrupted?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea concretely there seem to be two cases

  1. pip-installed pytorch on rocm (local) -- torch bundles libamdhip64. If we have raw hipFoo calls in vllm they will use the system rocm, but calls into hipFoo from libtorch itself will use torch's bundled libamdhip64, so there will be two device contexts -- we get GPU core dumps
  2. rocm pytorch docker image (e.g. vllm CI) -- torch does not bundle libamdhip64, we are good.

find_library(_STABLE_TORCH_AMDHIP64 amdhip64
PATHS "${TORCH_INSTALL_PREFIX}/lib" NO_DEFAULT_PATH)
if(_STABLE_TORCH_AMDHIP64)
message(STATUS "Found PyTorch-bundled libamdhip64 at ${_STABLE_TORCH_AMDHIP64}")
target_link_libraries(_C_stable_libtorch PRIVATE ${_STABLE_TORCH_AMDHIP64})
endif()
endif()
endif()

#
Expand Down
44 changes: 26 additions & 18 deletions csrc/core/scalar_type.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
#pragma once

// For TORCH_CHECK
#include <torch/library.h>
#include <cstdint>
#include <string>
#include <tuple>
#include <utility>
#include <variant>

// For STD_TORCH_CHECK
#include <torch/headeronly/util/Exception.h>

namespace vllm {

Expand Down Expand Up @@ -45,19 +51,20 @@ class ScalarType {
// IEEE 754 compliant floating point type
static constexpr ScalarType float_IEEE754(uint8_t exponent,
uint8_t mantissa) {
TORCH_CHECK(mantissa > 0 && exponent > 0);
STD_TORCH_CHECK(mantissa > 0 && exponent > 0);
return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
}

// IEEE 754 non-compliant floating point type
static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
bool finite_values_only,
NanRepr nan_repr) {
TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
TORCH_CHECK(mantissa > 0 && exponent > 0);
TORCH_CHECK(nan_repr != NAN_IEEE_754,
"use `float_IEEE754` constructor for floating point types that "
"follow IEEE 754 conventions");
STD_TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
STD_TORCH_CHECK(mantissa > 0 && exponent > 0);
STD_TORCH_CHECK(
nan_repr != NAN_IEEE_754,
"use `float_IEEE754` constructor for floating point types that "
"follow IEEE 754 conventions");
return ScalarType(exponent, mantissa, true, 0, finite_values_only,
nan_repr);
}
Expand Down Expand Up @@ -176,8 +183,8 @@ class ScalarType {

private:
double _floating_point_max() const {
TORCH_CHECK(mantissa <= 52 && exponent <= 11,
"Cannot represent max/min as a double for type ", str());
STD_TORCH_CHECK(mantissa <= 52 && exponent <= 11,
"Cannot represent max/min as a double for type ", str());

uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
Expand All @@ -186,8 +193,8 @@ class ScalarType {

uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
TORCH_CHECK(exponent < 11,
"Cannot represent max/min as a double for type ", str());
STD_TORCH_CHECK(exponent < 11,
"Cannot represent max/min as a double for type ", str());
max_exponent += 1;
}

Expand Down Expand Up @@ -216,25 +223,26 @@ class ScalarType {
if (is_floating_point()) {
return {_floating_point_max()};
} else {
TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(),
"Cannot represent max as a int64_t");
STD_TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(),
"Cannot represent max as a int64_t");
return {(int64_t(1) << mantissa) - 1};
}
}

constexpr std::variant<int64_t, double> _raw_min() const {
if (is_floating_point()) {
TORCH_CHECK(is_signed(),
"We currently assume all floating point types are signed");
STD_TORCH_CHECK(
is_signed(),
"We currently assume all floating point types are signed");
constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);

double max = _floating_point_max();
uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
uint64_t min_raw = max_raw | sign_bit_double;
return {*reinterpret_cast<double*>(&min_raw)};
} else {
TORCH_CHECK(!is_signed() || size_bits() <= 64,
"Cannot represent min as a int64_t");
STD_TORCH_CHECK(!is_signed() || size_bits() <= 64,
"Cannot represent min as a int64_t");
if (is_signed()) {
// set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
// then perform an arithmetic shift right to set all the bits above
Expand Down
2 changes: 2 additions & 0 deletions csrc/cuda_vec_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need these?

@mikaylagawarecki mikaylagawarecki Apr 2, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Concrete error without these is

FAILED: [code=1] CMakeFiles/_C_stable_libtorch.dir/csrc/libtorch_stable/activation_kernels.hip.o 
ccache /opt/rocm/lib/llvm/bin/clang++  -DPy_LIMITED_API=3 -DTORCH_EXTENSION_NAME=_C_stable_libtorch -DTORCH_TARGET_VERSION=0x020A000000000000ULL -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_DISTRIBUTED -DUSE_PROF_API=1 -DUSE_ROCM -DUSE_RPC -DUSE_TENSORPIPE -D_C_stable_libtorch_EXPORTS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_AMD__=1 -D__HIP_ROCclr__=1 -I/data/users/mg1998/vllm/build/temp.linux-x86_64-cpython-312/csrc -isystem /home/mg1998/.conda/envs/pytorch/include/python3.12 -isystem /home/mg1998/.conda/envs/pytorch/lib/python3.12/site-packages/torch/include -isystem /home/mg1998/.conda/envs/pytorch/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/fbcode/platform010/lib/rocm-6.4.2/include/hiprand -isystem /usr/local/fbcode/platform010/lib/rocm-6.4.2/include/rocrand -Wno-unused-result -O2 -g -DNDEBUG -std=gnu++17 --offload-arch=gfx942 -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -fPIC -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -DHIP_ENABLE_WARP_SYNC_BUILTINS=1 -DUSE_ROCM -DENABLE_FP8 -U__HIP_NO_HALF_CONVERSIONS__ -U__HIP_NO_HALF_OPERATORS__ -Werror=unused-variable -fno-gpu-rdc -DTORCH_HIP_VERSION=701 -Wno-shift-count-negative -Wno-shift-count-overflow -DCAFFE2_USE_MIOPEN -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP -std=c++17 -DHIP_ENABLE_WARP_SYNC_BUILTINS -DHIPBLASLT_OUTER_VEC -DUSE_ROCM_CK_GEMM -MD -MT CMakeFiles/_C_stable_libtorch.dir/csrc/libtorch_stable/activation_kernels.hip.o -MF CMakeFiles/_C_stable_libtorch.dir/csrc/libtorch_stable/activation_kernels.hip.o.d -o CMakeFiles/_C_stable_libtorch.dir/csrc/libtorch_stable/activation_kernels.hip.o -x hip -c /data/users/mg1998/vllm/build/temp.linux-x86_64-cpython-312/csrc/libtorch_stable/activation_kernels.hip
In file included from /data/users/mg1998/vllm/build/temp.linux-x86_64-cpython-312/csrc/libtorch_stable/activation_kernels.hip:8:
/data/users/mg1998/vllm/build/temp.linux-x86_64-cpython-312/csrc/libtorch_stable/../hip_vec_utils.cuh:78:28: error: unknown type name '__hip_bfloat162'; did you mean 'hip_bfloat16'?
   78 | struct PackedTypeConverter<__hip_bfloat162> {
      |                            ^~~~~~~~~~~~~~~
      |                            hip_bfloat16
/opt/rocm/include/hip/amd_detail/amd_hip_bfloat16.h:57:8: note: 'hip_bfloat16' declared here
   57 | struct hip_bfloat16
      |        ^
In file included from /data/users/mg1998/vllm/build/temp.linux-x86_64-cpython-312/csrc/libtorch_stable/activation_kernels.hip:8:
/data/users/mg1998/vllm/build/temp.linux-x86_64-cpython-312/csrc/libtorch_stable/../hip_vec_utils.cuh:79:16: error: unknown type name '__hip_bfloat16'; did you mean 'hip_bfloat16'?
   79 |   using Type = __hip_bfloat16;
      |                ^~~~~~~~~~~~~~
      |                hip_bfloat16
/opt/rocm/include/hip/amd_detail/amd_hip_bfloat16.h:57:8: note: 'hip_bfloat16' declared here
   57 | struct hip_bfloat16
      |        ^

Ah, concretely it seems like there is a bug here https://github.com/pytorch/pytorch/blob/main/torch/headeronly/util/BFloat16.h#L15-L17, where the #include <cuda_bf16.h> gets hipified but the !defined(USE_ROCM) guard doesn't

We defined USE_ROCM for _C_stable_libtorch to expose some of the shims that are gated :/

#include <hip/hip_fp16.h>
#else
#include <cuda_bf16.h>
#include <cuda_fp16.h>
Expand Down
Loading
Loading