Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
5f13582
fix typos
IwakuraRein Apr 6, 2026
d61d597
WIP
IwakuraRein Apr 6, 2026
aed7865
WIP
IwakuraRein Apr 7, 2026
e94e34c
explicit amax and fp4 quant
IwakuraRein Apr 8, 2026
07c284e
update core.py
IwakuraRein Apr 8, 2026
1fca13e
fix quantization bug when tokens >= 1024
IwakuraRein Apr 9, 2026
cb1717f
fix MoERunner initializer
IwakuraRein Apr 9, 2026
d4ff51e
fix deduce_trtllm_gen_tensor_dtype
IwakuraRein Apr 9, 2026
9011f09
use expanded_idx_to_permuted_idx in amax kernel
IwakuraRein Apr 9, 2026
69fd614
add nvfp4QuantAndPerTokenScaleKernel
IwakuraRein Apr 10, 2026
05e6f55
update test; fix pre-commit
IwakuraRein Apr 11, 2026
7ea8538
make test smaller
IwakuraRein Apr 11, 2026
a41b623
fix typo
IwakuraRein Apr 11, 2026
6e05cf9
perf opt
IwakuraRein Apr 13, 2026
ef5386e
guard expandedIdxToPermutedIdx
IwakuraRein Apr 13, 2026
248162a
make sf layout as a template parameter
IwakuraRein Apr 13, 2026
2b303a9
cache the local amax in smem
IwakuraRein Apr 13, 2026
56f1f6a
use 256 bit vectorized load; create python binding
IwakuraRein Apr 14, 2026
b60d0f8
no need for inline ptx for vectorized loading
IwakuraRein Apr 14, 2026
bfe8d2b
update api; default per_token_scale to None; update artifacts
IwakuraRein Apr 15, 2026
a9b4968
kernel code clean up
IwakuraRein Apr 15, 2026
9e37f75
reduce test memory usage
IwakuraRein Apr 15, 2026
f164ca9
Merge remote-tracking branch 'upstream/main' into per-token-fp4
IwakuraRein Apr 15, 2026
884500f
update checksum
IwakuraRein Apr 15, 2026
3ecaf6f
address comments
IwakuraRein Apr 15, 2026
1225fdd
typo
IwakuraRein Apr 15, 2026
691a6ee
add fp32 quantization kernel
IwakuraRein Apr 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 227 additions & 56 deletions csrc/nv_internal/cpp/kernels/quantization.cu

Large diffs are not rendered by default.

316 changes: 294 additions & 22 deletions csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh

Large diffs are not rendered by default.

22 changes: 17 additions & 5 deletions csrc/nv_internal/tensorrt_llm/kernels/quantization.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,25 +48,36 @@ namespace kernels {

template <typename T>
void invokeQuantization(int8_t* dst, T const* src, int64_t const size, float const* scalePtr,
cudaStream_t stream = 0, int maxGirdSize = 0);
cudaStream_t stream = 0, int maxGridSize = 0);

template <typename T, typename QuantT>
void invokePerTokenQuantization(QuantT* dst, T const* src, int64_t const numRows,
int64_t const numCols, float const* clampPtr, float* scalePtr,
float* sumPtr, tensorrt_llm::common::QuantMode quantMode,
cudaStream_t stream = 0);

template <typename T>
void invokeRowWiseAmax(uint32_t m, uint32_t n, T const* input, float* output, float scale = 1.0f,
int32_t* expanded_idx_to_permuted_idx = nullptr, cudaStream_t stream = 0);

template <typename T, int SF_VEC_SIZE>
void invokeFP4Quantization(int b, int m, int n, T const* input, float const* globalScale,
int64_t* output, int32_t* SFOuput, bool useUE8M0,
int64_t* output, int32_t* SFOutput, bool useUE8M0,
QuantizationSFLayout layout, int multiProcessorCount,
bool enable_pdl = false, cudaStream_t stream = 0);
bool enable_pdl = false, bool use_row_wise_scale = false,
bool inverse_scale = false, cudaStream_t stream = 0);

template <typename T>
void invokeSiluAndMulNVFP4Quantization(void* output, void* output_scale, void* input,
void* input_global_scale, void* mask, bool use_silu_and_mul,
int m_topk, int k, int n_experts, cudaStream_t stream);

template <typename T>
void invokeNvfp4QuantAndPerTokenScale(uint32_t m, uint32_t n, T const* input, float globalScaleInv,
int32_t* expanded_idx_to_permuted_idx, uint8_t* weightOutput,
uint8_t* scaleOutput, float* perTokenScaleOutput,
QuantizationSFLayout sfLayout, cudaStream_t stream);

template <typename T>
void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded, T const* SFIn,
T* SFOutput, int multiProcessorCount, cudaStream_t stream = 0);
Expand All @@ -76,8 +87,9 @@ void invokeBlockScaleInterleaveReverse(int b, int m, int n, uint8_t const* SFIn,

template <typename T>
void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input, int64_t* output,
int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount,
bool enable_pdl = false, cudaStream_t stream = 0);
int32_t* SFOutput, QuantizationSFLayout layout,
int multiProcessorCount, bool enable_pdl = false,
cudaStream_t stream = 0);

} // namespace kernels
} // namespace tensorrt_llm
64 changes: 64 additions & 0 deletions csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,70 @@ __device__ std::conditional_t<CVT_ELTS_PER_THREAD == 16, uint64_t, uint32_t> cvt
#endif
}

// Converts one thread's CVT_ELTS_PER_THREAD fp16/bf16 elements to packed e2m1
// (NVFP4) values, using a precomputed per-vector absolute max (vecMax) instead
// of re-reducing the vector, and optionally writes the 8-bit scale factor.
//
// Template parameters:
//   Type                - element type of `vec`: half or __nv_bfloat16.
//   SF_VEC_SIZE         - SF block size; unused in this body, presumably kept
//                         for a uniform signature with sibling converters.
//   CVT_ELTS_PER_THREAD - elements handled per thread; must be 8 or 16.
//   UE8M0_SF            - true: UE8M0 (power-of-two) SF; false: E4M3 SF.
//
// Arguments:
//   vec                  - packed input elements for this thread.
//   SFScaleVal           - global scale folded into the E4M3 SF computation.
//   reciprocalSFScaleVal - caller-precomputed 1 / SFScaleVal.
//   vecMax               - max |value| over the SF vector this thread covers.
//   SFout                - optional destination for the 8-bit SF (may be null).
//
// Returns the packed e2m1 result: uint64_t for 16 elements, uint32_t for 8.
// Device code requires __CUDA_ARCH__ >= 1000; otherwise the body returns 0.
template <class Type, int SF_VEC_SIZE, int CVT_ELTS_PER_THREAD, bool UE8M0_SF>
__device__ std::conditional_t<CVT_ELTS_PER_THREAD == 16, uint64_t, uint32_t>
cvt_warp_fp16_to_fp4_with_vec_max(PackedVec<Type, CVT_ELTS_PER_THREAD>& vec, float SFScaleVal,
                                  float reciprocalSFScaleVal, float vecMax, uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  static_assert(CVT_ELTS_PER_THREAD == 8 || CVT_ELTS_PER_THREAD == 16,
                "CVT_ELTS_PER_THREAD must be 8 or 16");

  using ReturnType = std::conditional_t<CVT_ELTS_PER_THREAD == 16, uint64_t, uint32_t>;

  // 8 bits representation of the SF.
  uint8_t fp8SFVal;
  float outputScale;
  if constexpr (UE8M0_SF) {
    __nv_fp8_e8m0 tmp;
    // Scale the max value to the range of E2m1 (max representable e2m1 = 6.0).
    vecMax *= reciprocal_approximate_ftz(6.0f);
    // Round the exponent toward +inf so the SF never underestimates vecMax.
    tmp.__x = __nv_cvt_float_to_e8m0(vecMax, __NV_SATFINITE, cudaRoundPosInf);

    fp8SFVal = tmp.__x;
    // Zero vectors get a zero output scale so no garbage is propagated.
    outputScale = vecMax != 0 ? exp2f_rcp(fp8SFVal) : 0.0f;
  } else {
    // Get the SF (max value of the vector / max value of e2m1).
    // maximum value of e2m1 = 6.0.
    // TODO: use half as compute data type.
    auto SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));

    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
    fp8SFVal = tmp.__x;
    // Use the value after E4M3 rounding so quantization matches the stored SF.
    SFValue = static_cast<float>(tmp);
    outputScale = vecMax != 0 ? reciprocal_approximate_ftz(SFValue * reciprocalSFScaleVal) : 0.0f;
  }

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = fp8SFVal;
  }

  // Convert the input to float.
  float2 fp2Vals[CVT_ELTS_PER_THREAD / 2];

#pragma unroll
  for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) {
    if constexpr (std::is_same_v<Type, half>) {
      fp2Vals[i] = __half22float2(vec.elts[i]);
    } else {
      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
    }
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e2m1 values.
  ReturnType e2m1Vec = fp32_vec_to_e2m1(fp2Vals);

  // Return the packed e2m1 values; the caller writes them to global memory.
  return e2m1Vec;
#else
  return 0;
#endif
}

template <class Type, int SF_VEC_SIZE, int CVT_ELTS_PER_THREAD, bool UE8M0_SF>
__device__ uint64_t cvt_warp_fp8_to_fp4(PackedVec<Type, CVT_ELTS_PER_THREAD>& vec, float SFScaleVal,
uint8_t* SFout) {
Expand Down
109 changes: 103 additions & 6 deletions csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <cstdint>

#include "flashinfer/fp4_layout.cuh"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/thop/utils.h"
Expand All @@ -34,7 +35,8 @@
// ceil(M / 128) * 128 * ceil(K / sfVecSize / 4) * 4, SF_DTYPE (UE4M3 or UE8M0)
void fp4_quantize(TensorView self, Optional<TensorView> const& globalScale, TensorView valueE2M1,
TensorView scaleFP8SF, int64_t sfVecSize, bool sfUseUE8M0,
bool isSfSwizzledLayout, bool isSf8x4Layout, bool enable_pdl) {
bool isSfSwizzledLayout, bool isSf8x4Layout, bool isGlobalScaleInversed,
bool enable_pdl) {
CHECK_CUDA(self);
CHECK_CONTIGUOUS(self);
if (sfUseUE8M0) {
Expand All @@ -46,9 +48,6 @@ void fp4_quantize(TensorView self, Optional<TensorView> const& globalScale, Tens
}

float* globalScalePtr{nullptr};
if (globalScale.has_value()) {
globalScalePtr = static_cast<float*>(globalScale.value().data_ptr());
}

auto const& inputShape = self.sizes();
auto const& rank = inputShape.size();
Expand All @@ -60,6 +59,13 @@ void fp4_quantize(TensorView self, Optional<TensorView> const& globalScale, Tens
}
auto const k = inputShape[rank - 1];
TVM_FFI_ICHECK_EQ(k % sfVecSize, 0);
bool useRowWiseGlobalScale = false;
if (globalScale.has_value()) {
TVM_FFI_ICHECK(globalScale.value().numel() == 1 || globalScale.value().numel() == m)
<< "globalScale should have shape [1] or [num_tokens]";
useRowWiseGlobalScale = globalScale.value().numel() > 1;
globalScalePtr = static_cast<float*>(globalScale.value().data_ptr());
}

const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();

Expand All @@ -73,7 +79,7 @@ void fp4_quantize(TensorView self, Optional<TensorView> const& globalScale, Tens
1, m, k, reinterpret_cast<T*>(self.data_ptr()), globalScalePtr, \
reinterpret_cast<int64_t*>(valueE2M1.data_ptr()), \
reinterpret_cast<int32_t*>(scaleFP8SF.data_ptr()), sfUseUE8M0, layout, mMultiProcessorCount, \
enable_pdl, get_stream(self.device()));
enable_pdl, useRowWiseGlobalScale, isGlobalScaleInversed, get_stream(self.device()));

if (sfUseUE8M0) {
if (self.dtype() == dl_float16) {
Expand Down Expand Up @@ -150,6 +156,10 @@ void fp4_batched_quantize(Tensor self, Tensor globalScale, Tensor valueE2M1, Ten

TVM_FFI_ICHECK_EQ(k % sfVecSize, 0);

bool use_row_wise_global_scale = globalScale.numel() > 1;
TVM_FFI_ICHECK(globalScale.numel() == 1 || globalScale.numel() == m)
<< "globalScale should have shape [1] or [num_tokens]";

std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
outputShape[rank - 1] = k / 2;

Expand All @@ -161,7 +171,8 @@ void fp4_batched_quantize(Tensor self, Tensor globalScale, Tensor valueE2M1, Ten
b, m, k, reinterpret_cast<T*>(self.data_ptr()), static_cast<float*>(globalScale.data_ptr()), \
reinterpret_cast<int64_t*>(valueE2M1.data_ptr()), \
reinterpret_cast<int32_t*>(scaleFP8SF.data_ptr()), sfUseUE8M0, layout, mMultiProcessorCount, \
/*enable_pdl=*/false, get_stream(self.device()));
/*enable_pdl=*/false, use_row_wise_global_scale, /* inverse_scale */ false, \
get_stream(self.device()));

if (self.dtype() == dl_float16) {
LAUNCH_FP4_QUANTIZE_KERNEL(half, 16)
Expand Down Expand Up @@ -252,7 +263,93 @@ void silu_and_mul_scaled_nvfp4_experts_quantize(Tensor output, Tensor output_sca
}
}

// Quantize a 2D [m, n] fp16/bf16 tensor to packed NVFP4 (e2m1) values.
// Produces three outputs: the packed fp4 weights, the block scale factors in
// the layout selected by sfLayout_, and one float32 scale per token row.
//
// input                        : [m, n] fp16 or bf16, contiguous, on CUDA.
// scale_inv_                   : inverse global scale used by the kernel.
// output                       : destination for packed e2m1 bytes.
// output_scale                 : block-scale-factor tensor; must be 2-D for
//                                the swizzled layouts.
// output_per_token_scale       : float32 per-row scales; must hold m entries
//                                when no permutation map is supplied.
// expanded_idx_to_permuted_idx : optional int32 1-D map from expanded token
//                                index to permuted row (MoE routing); when
//                                absent, rows map 1:1.
// sfLayout_                    : flashinfer::QuantizationSFLayout enum value
//                                (LINEAR, SWIZZLED_128x4, or SWIZZLED_8x4).
void nvfp4_quant_and_per_token_scale(TensorView const input, double scale_inv_, TensorView output,
                                     TensorView output_scale, TensorView output_per_token_scale,
                                     Optional<TensorView> expanded_idx_to_permuted_idx,
                                     int64_t sfLayout_) {
  CHECK_CUDA(input);
  CHECK_CONTIGUOUS(input);
  CHECK_CUDA(output);
  CHECK_CONTIGUOUS(output);
  CHECK_CUDA(output_scale);
  CHECK_CONTIGUOUS(output_scale);
  CHECK_CUDA(output_per_token_scale);
  CHECK_CONTIGUOUS(output_per_token_scale);
  TVM_FFI_ICHECK_EQ(output_per_token_scale.dtype(), dl_float32)
      << "output_per_token_scale must be float32";
  // Fixed message: it previously said "row_wise_amax", copied from another op.
  TVM_FFI_ICHECK_EQ(input.ndim(), 2)
      << "nvfp4_quant_and_per_token_scale input must be 2-dimensional";

  auto const& inputShape = input.sizes();
  auto m = inputShape[0];
  auto n = inputShape[1];
  if (!expanded_idx_to_permuted_idx.has_value()) {
    // Without a routing map, each input row produces exactly one scale.
    TVM_FFI_ICHECK_EQ(output_per_token_scale.numel(), m)
        << "output_per_token_scale must have shape [m]";
  }

  auto sfLayout = static_cast<flashinfer::QuantizationSFLayout>(sfLayout_);
  if (sfLayout == flashinfer::QuantizationSFLayout::SWIZZLED_128x4 ||
      sfLayout == flashinfer::QuantizationSFLayout::SWIZZLED_8x4) {
    // Guard the shape reads below: previously output_scale.sizes()[1] was read
    // unconditionally, an out-of-bounds access for a sub-2-D tensor.
    TVM_FFI_ICHECK_EQ(output_scale.ndim(), 2)
        << "output_scale must be 2-dimensional for swizzled layouts";
  }
  switch (sfLayout) {
    case flashinfer::QuantizationSFLayout::LINEAR:
      break;
    case flashinfer::QuantizationSFLayout::SWIZZLED_128x4:
      TVM_FFI_ICHECK(output_scale.sizes()[0] % 128 == 0)
          << "For SWIZZLED_128x4 layout, the first dimension of "
             "output_scale must be a multiple of 128";
      TVM_FFI_ICHECK(output_scale.sizes()[1] % 4 == 0)
          << "For SWIZZLED_128x4 layout, the second dimension of "
             "output_scale must be a multiple of 4";
      break;
    case flashinfer::QuantizationSFLayout::SWIZZLED_8x4:
      TVM_FFI_ICHECK(output_scale.sizes()[0] % 8 == 0)
          << "For SWIZZLED_8x4 layout, the first dimension of output_scale must be a multiple of 8";
      TVM_FFI_ICHECK(output_scale.sizes()[1] % 4 == 0)
          << "For SWIZZLED_8x4 layout, the second dimension of "
             "output_scale must be a multiple of 4";
      break;
    default:
      TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Invalid sfLayout value: " << sfLayout_;
      break;
  }
  const cudaStream_t stream = get_stream(input.device());
  const float scale_inv = static_cast<float>(scale_inv_);
  auto const in_dtype = input.dtype();

  int32_t* expanded_idx_to_permuted_idx_ptr = nullptr;
  if (expanded_idx_to_permuted_idx.has_value()) {
    CHECK_CUDA(expanded_idx_to_permuted_idx.value());
    CHECK_CONTIGUOUS(expanded_idx_to_permuted_idx.value());
    TVM_FFI_ICHECK_EQ(expanded_idx_to_permuted_idx.value().dtype(), dl_int32)
        << "expanded_idx_to_permuted_idx must be int32";
    TVM_FFI_ICHECK_EQ(expanded_idx_to_permuted_idx.value().ndim(), 1)
        << "expanded_idx_to_permuted_idx must be 1-dimensional";
    expanded_idx_to_permuted_idx_ptr =
        reinterpret_cast<int32_t*>(expanded_idx_to_permuted_idx.value().data_ptr());
  }
  // Dispatch on the element type; the kernel template handles both widths.
  if (in_dtype == dl_float16) {
    tensorrt_llm::kernels::invokeNvfp4QuantAndPerTokenScale<half>(
        m, n, reinterpret_cast<half const*>(input.data_ptr()), scale_inv,
        expanded_idx_to_permuted_idx_ptr, reinterpret_cast<uint8_t*>(output.data_ptr()),
        reinterpret_cast<uint8_t*>(output_scale.data_ptr()),
        reinterpret_cast<float*>(output_per_token_scale.data_ptr()), sfLayout, stream);
  } else if (in_dtype == dl_bfloat16) {
#ifdef ENABLE_BF16
    tensorrt_llm::kernels::invokeNvfp4QuantAndPerTokenScale<__nv_bfloat16>(
        m, n, reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()), scale_inv,
        expanded_idx_to_permuted_idx_ptr, reinterpret_cast<uint8_t*>(output.data_ptr()),
        reinterpret_cast<uint8_t*>(output_scale.data_ptr()),
        reinterpret_cast<float*>(output_per_token_scale.data_ptr()), sfLayout, stream);
#else
    TVM_FFI_LOG_AND_THROW(NotImplementedError)
        << "nvfp4_quant_and_per_token_scale: BFloat16 support is not enabled.";
#endif
  } else {
    TVM_FFI_LOG_AND_THROW(NotImplementedError)
        << "unsupported input dtype for nvfp4_quant_and_per_token_scale, only fp16 and bf16 are "
           "supported.";
  }
}

// Register the quantization entry points with the TVM FFI export table under
// their own names, including the new fused nvfp4 quant + per-token-scale op.
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp4_quantize, fp4_quantize);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp4_batched_quantize, fp4_batched_quantize);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(silu_and_mul_scaled_nvfp4_experts_quantize,
                              silu_and_mul_scaled_nvfp4_experts_quantize);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(nvfp4_quant_and_per_token_scale, nvfp4_quant_and_per_token_scale);
8 changes: 7 additions & 1 deletion csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,17 @@ using tvm::ffi::Tuple;

void fp4_quantize(TensorView self, Optional<TensorView> const& globalScale, TensorView valueE2M1,
TensorView scaleFP8SF, int64_t sfVecSize, bool sfUseUE8M0,
bool isSfSwizzledLayout, bool isSf8x4Layout, bool enable_pdl);
bool isSfSwizzledLayout, bool isSf8x4Layout, bool isGlobalScaleInversed,
bool enable_pdl);

void fp4_batched_quantize(Tensor self, Tensor globalScale, Tensor valueE2M1, Tensor scaleFP8SF,
int64_t sfVecSize, bool sfUseUE8M0);

void silu_and_mul_scaled_nvfp4_experts_quantize(Tensor output, Tensor output_scale,
Tensor const input, Tensor const input_global_scale,
Tensor const mask, bool use_silu_and_mul);

void nvfp4_quant_and_per_token_scale(TensorView const input, double scale_inv, TensorView output,
TensorView output_scale, TensorView output_per_token_scale,
Optional<TensorView> expanded_idx_to_permuted_idx,
int64_t sfLayout = 2);
8 changes: 8 additions & 0 deletions csrc/trtllm_batched_gemm_runner.cu
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch &&
tileSize == mOptions.tileSize && options.mUseShuffledMatrix == mOptions.useShuffledMatrix &&
options.mLayoutA == mOptions.weightLayout) {
if (mOptions.usePerTokenScaling) {
if (options.mTransposeMmaOutput && !options.mUsePerTokenSfB) {
continue;
}
if (!options.mTransposeMmaOutput && !options.mUsePerTokenSfA) {
continue;
}
}
if (options.mFusedAct) {
if (options.mActType != static_cast<batchedGemm::gemmGatedAct::ActType>(mOptions.actType)) {
continue;
Expand Down
Loading
Loading