Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
fc8c110
Move pos_encoding kernel file from csrc to csrc/libtorch_stable
mikaylagawarecki Apr 2, 2026
6af6020
[k/n] Migrate rotary_embedding to torch stable ABI
mikaylagawarecki Apr 2, 2026
3461735
Move fused_qknorm_rope kernel file from csrc to csrc/libtorch_stable
mikaylagawarecki Apr 2, 2026
b881f64
[l/n] Migrate fused_qk_norm_rope to torch stable ABI
mikaylagawarecki Apr 2, 2026
ef3760c
Move layernorm kernel file from csrc to csrc/libtorch_stable
mikaylagawarecki Apr 2, 2026
17bfd93
[m/n] Migrate layernorm kernels (rms_norm, fused_add_rms_norm) to tor…
mikaylagawarecki Apr 2, 2026
b943c50
Move layernorm quant kernel file from csrc to csrc/libtorch_stable
mikaylagawarecki Apr 2, 2026
7146d9b
[n/n] Migrate layernorm quant kernels (rms_norm_static_fp8_quant, fus…
mikaylagawarecki Apr 2, 2026
a6ed451
Move fused layernorm dynamic per-token quant files from csrc to csrc/…
mikaylagawarecki Apr 2, 2026
eaa63cf
Migrate fused_layernorm_dynamic_per_token_quant to torch stable ABI
mikaylagawarecki Apr 2, 2026
d477704
cleaned up a ifndef USE_ROCM that was accidentally reintroduce in a c…
cleonard530 May 19, 2026
cfeffd5
updated TROCH_CHECK to STD_TORCH_CHECK in csrc/libtorch_stable/fused_…
cleonard530 May 20, 2026
9722d8f
Merge branch 'main' into new-stable-abi-phase7
Harry-Chen May 21, 2026
c8b76d5
The double inclusion in csrc/quantization/fused_kernels/fused_silu_mu…
cleonard530 May 21, 2026
a84e20c
Merge branch 'main' into new-stable-abi-phase7
Harry-Chen May 22, 2026
9ed61f9
Merge branch 'main' into new-stable-abi-phase7
Harry-Chen May 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -311,14 +311,9 @@ set(VLLM_EXT_SRC
"csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/fused_qknorm_rope_kernel.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu"
"csrc/topk.cu"
"csrc/cuda_view.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
Expand Down Expand Up @@ -633,7 +628,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
"csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/fp8/common.cu"
"csrc/libtorch_stable/quantization/gptq/q_gemm.cu"
"csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu")
"csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu"
"csrc/libtorch_stable/pos_encoding_kernels.cu"
"csrc/libtorch_stable/fused_qknorm_rope_kernel.cu"
"csrc/libtorch_stable/layernorm_kernels.cu"
"csrc/libtorch_stable/layernorm_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu")

if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
Expand Down
82 changes: 82 additions & 0 deletions csrc/libtorch_stable/dispatch_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,35 @@
THO_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__))

// Quant type dispatch (FP8 + INT8)
#ifdef USE_ROCM
#define VLLM_STABLE_DISPATCH_CASE_QUANT_TYPES(...) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Float8_e4m3fn, \
__VA_ARGS__) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Float8_e4m3fnuz, \
__VA_ARGS__) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Char, __VA_ARGS__)
#else
#define VLLM_STABLE_DISPATCH_CASE_QUANT_TYPES(...) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Float8_e4m3fn, \
__VA_ARGS__) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Char, __VA_ARGS__)
#endif

#define VLLM_STABLE_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
THO_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_STABLE_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))

// Group size dispatch (pure C++ if/else, no ATen dependency)
#define VLLM_STABLE_DISPATCH_GROUP_SIZE(group_size, const_group_size, ...) \

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Outside scope of this PR, but technically this is exactly the same as #define VLLM_DISPATCH_GROUP_SIZE so we could use the stable version everywhere now

if (group_size == 128) { \
constexpr int const_group_size = 128; \
__VA_ARGS__(); \
} else if (group_size == 64) { \
constexpr int const_group_size = 64; \
__VA_ARGS__(); \
}
Comment on lines +81 to +88

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The VLLM_STABLE_DISPATCH_GROUP_SIZE macro does not handle unsupported group_size values. This could lead to silent failures where no code is executed if the group_size is not 64 or 128. For improved robustness and consistency with other dispatch macros in this file (like VLLM_STABLE_DISPATCH_RANK234), an else block should be added to check for and handle unsupported values.

Suggested change
#define VLLM_STABLE_DISPATCH_GROUP_SIZE(group_size, const_group_size, ...) \
if (group_size == 128) { \
constexpr int const_group_size = 128; \
__VA_ARGS__(); \
} else if (group_size == 64) { \
constexpr int const_group_size = 64; \
__VA_ARGS__(); \
}
#define VLLM_STABLE_DISPATCH_GROUP_SIZE(group_size, const_group_size, ...) \
if (group_size == 128) { \
constexpr int const_group_size = 128; \
__VA_ARGS__(); \
} else if (group_size == 64) { \
constexpr int const_group_size = 64; \
__VA_ARGS__(); \
} else { \
STD_TORCH_CHECK(false, "Unsupported group_size, expected 64 or 128 but got: ", group_size); \
}


// Boolean dispatch
#define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \
if (expr) { \
Expand All @@ -67,3 +96,56 @@
constexpr bool const_expr = false; \
__VA_ARGS__(); \
}

// Vec size dispatch (pure C++ switch, no ATen dependency)
#define VLLM_STABLE_DISPATCH_VEC_SIZE(VEC_SIZE, ...) \
switch (VEC_SIZE) { \
case 16: { \
constexpr int vec_size = 16; \
__VA_ARGS__(); \
break; \
} \
case 8: { \
constexpr int vec_size = 8; \
__VA_ARGS__(); \
break; \
} \
case 4: { \
constexpr int vec_size = 4; \
__VA_ARGS__(); \
break; \
} \
case 2: { \
constexpr int vec_size = 2; \
__VA_ARGS__(); \
break; \
} \
default: { \
constexpr int vec_size = 1; \
__VA_ARGS__(); \
break; \
} \
}

// Tensor rank dispatch (2D, 3D, 4D)
#define VLLM_STABLE_DISPATCH_RANK234(NUM_DIMS, ...) \
switch (NUM_DIMS) { \
case 2: { \
constexpr int tensor_rank = 2; \
__VA_ARGS__(); \
break; \
} \
case 3: { \
constexpr int tensor_rank = 3; \
__VA_ARGS__(); \
break; \
} \
case 4: { \
constexpr int tensor_rank = 4; \
__VA_ARGS__(); \
break; \
} \
default: \
STD_TORCH_CHECK( \
false, "Expects rank 2, 3 or 4 tensors but got unsupported rank"); \
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,20 @@
#include <cuda_runtime.h>
#include <type_traits>

#include <torch/cuda.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "torch_utils.h"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merge conflicts location: deletion of #include <ATen/cuda/CUDAContext.h>, and
#include "async_util.cuh", two headers that were not there before

#include "async_util.cuh"
#include "cuda_compat.h"
#include "../async_util.cuh"
#include "../cuda_compat.h"
#include "../type_convert.cuh"
#include "dispatch_utils.h"
#include "type_convert.cuh"

#define CHECK_TYPE(x, st) \
TORCH_CHECK(x.scalar_type() == st, #x " dtype is ", x.scalar_type(), \
", while ", st, " is expected")
#define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_TYPE(x, st) \
STD_TORCH_CHECK(x.scalar_type() == st, #x " dtype is ", x.scalar_type(), \
", while ", st, " is expected")
#define CHECK_TH_CUDA(x) \
STD_TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
STD_TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_TH_CUDA(x); \
CHECK_CONTIGUOUS(x)
Expand Down Expand Up @@ -589,8 +588,8 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
});
break;
default:
TORCH_CHECK(false,
"Unsupported head dimension for fusedQKNormRope: ", head_dim);
STD_TORCH_CHECK(
false, "Unsupported head dimension for fusedQKNormRope: ", head_dim);
}
}

Expand All @@ -604,10 +603,10 @@ void launchFusedQKNormRopeNTokenHeads(
void const* k_weight, void const* cos_sin_cache, bool const interleave,
int64_t const* position_ids, int const token_heads_per_warp,
cudaStream_t stream) {
TORCH_CHECK(token_heads_per_warp == 1 || token_heads_per_warp == 2 ||
token_heads_per_warp == 4 || token_heads_per_warp == 8,
"token_heads_per_warp must be 1, 2, 4, or 8, got ",
token_heads_per_warp);

@cleonard530 cleonard530 May 20, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block wasn't here before so I had to update TORCH_CHECK.

STD_TORCH_CHECK(token_heads_per_warp == 1 || token_heads_per_warp == 2 ||
token_heads_per_warp == 4 || token_heads_per_warp == 8,
"token_heads_per_warp must be 1, 2, 4, or 8, got ",
token_heads_per_warp);

// token_heads_per_warp == 1: delegate to the 1-head baseline kernel.
if (token_heads_per_warp == 1) {
Expand Down Expand Up @@ -691,7 +690,7 @@ void launchFusedQKNormRopeNTokenHeads(
}); \
break; \
default: \
TORCH_CHECK(false, "Unsupported head dimension: ", head_dim); \

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block wasn't here before so I had to update TORCH_CHECK.

STD_TORCH_CHECK(false, "Unsupported head dimension: ", head_dim); \
} \
} while (0)

Expand All @@ -708,60 +707,64 @@ void launchFusedQKNormRopeNTokenHeads(
} // namespace tensorrt_llm::kernels

void fused_qk_norm_rope(
torch::Tensor& qkv, // Combined QKV tensor [num_tokens,
// (num_heads_q+num_heads_k+num_heads_v)*head_dim]
int64_t num_heads_q, // Number of query heads
int64_t num_heads_k, // Number of key heads
int64_t num_heads_v, // Number of value heads
int64_t head_dim, // Dimension per head
double eps, // Epsilon for RMS normalization
torch::Tensor& q_weight, // RMSNorm weights for query [head_dim]
torch::Tensor& k_weight, // RMSNorm weights for key [head_dim]
torch::Tensor& cos_sin_cache, // Cos/sin cache [max_position, head_dim]
bool is_neox, // Whether RoPE is applied in Neox style
torch::Tensor& position_ids, // Position IDs for RoPE [num_tokens]
int64_t forced_token_heads_per_warp // -1 = auto-select, >0 = forced value

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merge conflicts location: addition of the "forced_token_heads_per_warp"

torch::stable::Tensor&
qkv, // Combined QKV tensor [num_tokens,
// (num_heads_q+num_heads_k+num_heads_v)*head_dim]
int64_t num_heads_q, // Number of query heads
int64_t num_heads_k, // Number of key heads
int64_t num_heads_v, // Number of value heads
int64_t head_dim, // Dimension per head
double eps, // Epsilon for RMS normalization
torch::stable::Tensor& q_weight, // RMSNorm weights for query [head_dim]
torch::stable::Tensor& k_weight, // RMSNorm weights for key [head_dim]
torch::stable::Tensor& cos_sin_cache, // Cos/sin cache [max_position,
// head_dim]
bool is_neox, // Whether RoPE is applied in Neox style
torch::stable::Tensor& position_ids, // Position IDs for RoPE [num_tokens]
int64_t forced_token_heads_per_warp // -1 = auto-select, >0 = forced value
) {
// Input validation
CHECK_INPUT(qkv);
CHECK_INPUT(position_ids);
CHECK_INPUT(q_weight);
CHECK_INPUT(k_weight);
CHECK_INPUT(cos_sin_cache);
CHECK_TYPE(position_ids, torch::kInt64);

TORCH_CHECK(qkv.dim() == 2,
"QKV tensor must be 2D: [num_tokens, "
"(num_heads_q+num_heads_k+num_heads_v)*head_dim]");
TORCH_CHECK(position_ids.dim() == 1, "Position IDs must be 1D: [num_tokens]");
TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]");
TORCH_CHECK(k_weight.dim() == 1, "Key weights must be 1D: [head_dim]");
TORCH_CHECK(cos_sin_cache.dim() == 2,
"Cos/sin cache must be 2D: [max_position, head_dim]");
TORCH_CHECK(q_weight.size(0) == head_dim,
"Query weights size must match head dimension");
TORCH_CHECK(k_weight.size(0) == head_dim,
"Key weights size must match head dimension");

TORCH_CHECK(cos_sin_cache.size(1) % 2 == 0, "rotary_dim must be even");
TORCH_CHECK(cos_sin_cache.size(1) <= head_dim,
"rotary_dim must be less than or equal to head_dim");

TORCH_CHECK(qkv.scalar_type() == q_weight.scalar_type() &&
qkv.scalar_type() == k_weight.scalar_type(),
"qkv, q_weight and k_weight must have the same dtype");
CHECK_TYPE(position_ids, torch::headeronly::ScalarType::Long);

STD_TORCH_CHECK(qkv.dim() == 2,
"QKV tensor must be 2D: [num_tokens, "
"(num_heads_q+num_heads_k+num_heads_v)*head_dim]");
STD_TORCH_CHECK(position_ids.dim() == 1,
"Position IDs must be 1D: [num_tokens]");
STD_TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]");
STD_TORCH_CHECK(k_weight.dim() == 1, "Key weights must be 1D: [head_dim]");
STD_TORCH_CHECK(cos_sin_cache.dim() == 2,
"Cos/sin cache must be 2D: [max_position, head_dim]");
STD_TORCH_CHECK(q_weight.size(0) == head_dim,
"Query weights size must match head dimension");
STD_TORCH_CHECK(k_weight.size(0) == head_dim,
"Key weights size must match head dimension");

STD_TORCH_CHECK(cos_sin_cache.size(1) % 2 == 0, "rotary_dim must be even");
STD_TORCH_CHECK(cos_sin_cache.size(1) <= head_dim,
"rotary_dim must be less than or equal to head_dim");

STD_TORCH_CHECK(qkv.scalar_type() == q_weight.scalar_type() &&
qkv.scalar_type() == k_weight.scalar_type(),
"qkv, q_weight and k_weight must have the same dtype");

int64_t num_tokens = qkv.size(0);
TORCH_CHECK(position_ids.size(0) == num_tokens,
"Number of tokens in position_ids must match QKV");
STD_TORCH_CHECK(position_ids.size(0) == num_tokens,
"Number of tokens in position_ids must match QKV");

int64_t total_heads = num_heads_q + num_heads_k + num_heads_v;
TORCH_CHECK(
STD_TORCH_CHECK(
qkv.size(1) == total_heads * head_dim,
"QKV tensor size must match total number of heads and head dimension");

auto device_id = qkv.get_device();
auto stream = at::cuda::getCurrentCUDAStream(device_id);
const torch::stable::accelerator::DeviceGuard device_guard(
qkv.get_device_index());
auto stream = get_current_cuda_stream(qkv.get_device_index());

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merge conflicts location: addition of device_id and stream.

// Select token_heads_per_warp: forced value if >0, else auto-select.
// Auto thresholds are calibrated on SM 9.0 (H100). On other architectures,
Expand All @@ -771,8 +774,7 @@ void fused_qk_norm_rope(
token_heads_per_warp = static_cast<int>(forced_token_heads_per_warp);
} else {
token_heads_per_warp = 1;
auto* dev_prop = at::cuda::getDeviceProperties(device_id);
int sm_version = dev_prop->major * 10 + dev_prop->minor;
int sm_version = get_device_prop()->major * 10 + get_device_prop()->minor;
int64_t total_qk_units = num_tokens * (num_heads_q + num_heads_k);
if (sm_version == 90) {
if (head_dim >= 256) {
Expand All @@ -795,21 +797,22 @@ void fused_qk_norm_rope(
}
}

VLLM_DISPATCH_HALF_TYPES(qkv.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
using qkv_scalar_t = scalar_t;
VLLM_DISPATCH_FLOATING_TYPES(
cos_sin_cache.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
using cache_scalar_t = scalar_t;
tensorrt_llm::kernels::launchFusedQKNormRopeNTokenHeads<
qkv_scalar_t, cache_scalar_t>(
qkv.data_ptr(), static_cast<int>(num_tokens),
static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
static_cast<int>(num_heads_v), static_cast<int>(head_dim),
static_cast<int>(cos_sin_cache.size(1)), static_cast<float>(eps),
q_weight.data_ptr(), k_weight.data_ptr(),
cos_sin_cache.data_ptr(), !is_neox,
reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
token_heads_per_warp, stream);
});
});
VLLM_STABLE_DISPATCH_HALF_TYPES(
qkv.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
using qkv_scalar_t = scalar_t;
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
cos_sin_cache.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
using cache_scalar_t = scalar_t;
tensorrt_llm::kernels::launchFusedQKNormRopeNTokenHeads<
qkv_scalar_t, cache_scalar_t>(
qkv.data_ptr(), static_cast<int>(num_tokens),
static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
static_cast<int>(num_heads_v), static_cast<int>(head_dim),
static_cast<int>(cos_sin_cache.size(1)),
static_cast<float>(eps), q_weight.data_ptr(),
k_weight.data_ptr(), cos_sin_cache.data_ptr(), !is_neox,
reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
token_heads_per_warp, stream);
});
});
}
Loading
Loading