66 changes: 53 additions & 13 deletions csrc/nv_internal/cpp/kernels/quantization.cu
@@ -102,7 +102,7 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input,
&config,
quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_MXFP8, T, SF_VEC_SIZE, true>, b,
m, n, padded_n, input, nullptr, reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(SFOuput), layout);
reinterpret_cast<uint32_t*>(SFOuput), layout, /*mask=*/nullptr);
}

// Do per-token (row) quantization from fp16/bf16/fp32 to int8/fp8_e4m3.
@@ -164,12 +164,11 @@ INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(__nv_bfloat16, __nv_fp8_e4m3);

////////////////////////////////////////////////////////////////////////////////////////////////////
// FP4/MXFP8 Quantization

template <typename T, int SF_VEC_SIZE>
void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFScale,
int64_t* output, int32_t* SFOuput, bool useUE8M0,
QuantizationSFLayout layout, int multiProcessorCount, bool enable_pdl,
cudaStream_t stream) {
QuantizationSFLayout layout, int multiProcessorCount,
int32_t const* mask, bool enable_pdl, cudaStream_t stream) {
#ifdef ENABLE_FP8
if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
// Grid, Block size.
@@ -187,7 +186,7 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
T, SF_VEC_SIZE, false>;
kernel_instance<<<grid, block, 0, stream>>>(b, m, n, n, input, SFScale,
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(SFOuput), layout);
reinterpret_cast<uint32_t*>(SFOuput), layout, mask);

} else
#endif
@@ -218,10 +217,42 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance, b, m, n, n, input, SFScale,
reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput),
layout);
layout, mask);
}
}

template <typename T, int SF_VEC_SIZE>
void invokeSiluAndMulFP4Quantization(int b, int m, int n, T const* input, float const* SFScale,
int32_t const* mask, int64_t* output, int32_t* SFOuput,
QuantizationSFLayout layout, int multiProcessorCount,
bool enable_pdl, cudaStream_t stream) {
// Grid, Block size.
// Each thread converts 8 values.
dim3 block(std::min(int(n / CVT_ELTS_PER_THREAD), 512));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = std::max(1u, 2048u / block.x);
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

// Launch the cvt kernel.
auto* kernel_instance =
&silu_mul_quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_FP4, T, SF_VEC_SIZE,
false>;

cudaLaunchConfig_t config;
config.gridDim = grid;
config.blockDim = block;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance, b, m, n / 2, n / 2, input, SFScale,
reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput),
layout, mask);
}
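
For reviewers, a minimal host-side sketch of how the new launcher might be driven. Only the signature, the `<half, 16>` instantiation, and the `SWIZZLED_128x4` enumerator come from this diff; the include path, buffer names, and the assumption that `n` is the full gate-and-up input width (so the kernel emits `n / 2` FP4 values per row) are illustrative.

```cpp
// Sketch only: names and include path are assumptions, not part of this PR.
#include <cuda_fp16.h>
#include "tensorrt_llm/kernels/quantization.h"

using namespace tensorrt_llm::kernels;

// gate_up is assumed to hold [x | y] concatenated along the last dimension
// (n columns total); judging by the launch parameters above, the kernel then
// writes n / 2 FP4 values per row as silu(x) * y.
void silu_mul_fp4_example(int b, int m, int n, half const* gate_up,
                          float const* globalScale, int32_t const* mask,
                          int64_t* fp4Out, int32_t* sfOut, int smCount,
                          cudaStream_t stream) {
  invokeSiluAndMulFP4Quantization<half, 16>(
      b, m, n, gate_up, globalScale, mask, fp4Out, sfOut,
      QuantizationSFLayout::SWIZZLED_128x4, smCount,
      /*enable_pdl=*/false, stream);
}
```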

__global__ void block_scale_interleave_kernel(int numBatches, int numRows, int numRowsPadded,
int numCols, int numColsPadded, uint8_t const* SFIn,
uint8_t* SFOutput) {
@@ -299,43 +330,52 @@ template void invokeFP4Quantization<half, 16>(int b, int m, int n, half const* i
float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0,
QuantizationSFLayout layout, int multiProcessorCount,
bool enable_pdl, cudaStream_t stream);
int32_t const* mask, bool enable_pdl,
cudaStream_t stream);
template void invokeFP4Quantization<half, 32>(int b, int m, int n, half const* input,
float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0,
QuantizationSFLayout layout, int multiProcessorCount,
bool enable_pdl, cudaStream_t stream);
int32_t const* mask, bool enable_pdl,
cudaStream_t stream);
template void invokeMxFP8Quantization<half>(int b, int m, int n, int padded_n, half const* input,
int64_t* output, int32_t* SFOuput,
QuantizationSFLayout layout, int multiProcessorCount,
bool enable_pdl, cudaStream_t stream);
template void invokeSiluAndMulFP4Quantization<half, 16>(
int b, int m, int n, half const* input, float const* globalScale, int32_t const* mask,
int64_t* output, int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount,
bool enable_pdl, cudaStream_t stream);

#ifdef ENABLE_BF16
template void invokeFP4Quantization<__nv_bfloat16, 16>(
int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
bool enable_pdl, cudaStream_t stream);
int32_t const* mask, bool enable_pdl, cudaStream_t stream);
template void invokeFP4Quantization<__nv_bfloat16, 32>(
int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
bool enable_pdl, cudaStream_t stream);
int32_t const* mask, bool enable_pdl, cudaStream_t stream);
template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n, int padded_n,
__nv_bfloat16 const* input, int64_t* output,
int32_t* SFOuput, QuantizationSFLayout layout,
int multiProcessorCount, bool enable_pdl,
cudaStream_t stream);

template void invokeSiluAndMulFP4Quantization<__nv_bfloat16, 16>(
int b, int m, int n, __nv_bfloat16 const* input, float const* globalScale, int32_t const* mask,
int64_t* output, int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount,
bool enable_pdl, cudaStream_t stream);
#endif

#ifdef ENABLE_FP8
template void invokeFP4Quantization<__nv_fp8_e4m3, 16>(
int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
bool enable_pdl, cudaStream_t stream);
int32_t const* mask, bool enable_pdl, cudaStream_t stream);
template void invokeFP4Quantization<__nv_fp8_e4m3, 32>(
int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
bool enable_pdl, cudaStream_t stream);
int32_t const* mask, bool enable_pdl, cudaStream_t stream);

#endif

120 changes: 87 additions & 33 deletions csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
@@ -28,6 +28,8 @@ using namespace tensorrt_llm::common;
namespace tensorrt_llm {
namespace kernels {

__device__ __forceinline__ float silu(const float& val) { return val / (1.0f + __expf(-val)); }

__global__ static void quantizedKernel(char4* dst, float4 const* src, int64_t const sizeDiv4,
float const* scalePtr) {
for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < sizeDiv4;
@@ -397,6 +399,29 @@ struct PackedVec<__nv_fp8_e4m3> {
"Vector size should match the number of elements per thread.");
};

template <class Type>
inline __device__ void silu_and_mul(PackedVec<Type>& x_vec, const PackedVec<Type>& y_vec) {
float2 x[CVT_ELTS_PER_THREAD / 2];
float2 y[CVT_ELTS_PER_THREAD / 2];

#pragma unroll
for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) {
x[i] = __half22float2(x_vec.elts[i]);
y[i] = __half22float2(y_vec.elts[i]);
x[i].x = silu(x[i].x) * y[i].x;
x[i].y = silu(x[i].y) * y[i].y;
x_vec.elts[i] = __float22half2_rn(x[i]);
} else {
x[i] = __bfloat1622float2(x_vec.elts[i]);
y[i] = __bfloat1622float2(y_vec.elts[i]);
x[i].x = silu(x[i].x) * y[i].x;
x[i].y = silu(x[i].y) * y[i].y;
x_vec.elts[i] = __float22bfloat162_rn(x[i]);
}
}
}
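
For reference while reviewing the packed helper above, its per-element effect reduces to the scalar form below; this is a sketch for comparison, not code from the diff.

```cpp
// Scalar reference: out = silu(x) * y, with silu(v) = v / (1 + exp(-v)).
// In the packed helper, x comes from x_vec, y from y_vec, and the result is
// written back into x_vec in the source precision (half or bfloat16).
#include <cmath>

inline float silu_and_mul_ref(float x, float y) {
  return (x / (1.0f + std::exp(-x))) * y;
}
```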

// Quantizes the provided PackedVec into the uint32_t output
template <class Type, int SF_VEC_SIZE, bool UE8M0_SF>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
@@ -738,67 +763,67 @@ __device__ uint8_t* cvt_quant_get_sf_out_offset(std::optional<int> batchIdx, int
return nullptr;
}

template <BlockScaleQuantizationType quantization_type, class Type, int SF_VEC_SIZE, bool UE8M0_SF>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) quantize_with_block_size(
#else
quantize_with_block_size(
#endif
int32_t numbatches, int32_t numRows, int32_t numCols, int32_t numPaddedCols, Type const* in,
float const* SFScale, uint32_t* out, uint32_t* SFout, QuantizationSFLayout layout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
struct NoSiluPolicy {
static constexpr bool run_silu = false;
template <typename PackedVec>
__device__ static void maybe_silu(PackedVec&, const PackedVec&, bool) {}
};

// The elements per thread.
struct SiluPolicy {
static constexpr bool run_silu = true;
template <typename PackedVec>
__device__ static void maybe_silu(PackedVec& in_vec, const PackedVec& in_vec_mul, bool is_fp8) {
assert(!is_fp8);
silu_and_mul(in_vec, in_vec_mul);
}
};
template <BlockScaleQuantizationType quantization_type, class Type, int SF_VEC_SIZE, bool UE8M0_SF,
class Policy>
__device__ inline void quantize_with_block_size_impl(int32_t numbatches, int32_t numRows,
int32_t numCols, int32_t numPaddedCols,
const Type* in, const float* SFScale,
uint32_t* out, uint32_t* SFout,
QuantizationSFLayout layout,
const int32_t* mask) {
bool use_mask = mask != nullptr;
bool use_silu = Policy::run_silu;
static constexpr int ELTS_PER_THREAD = quantization_type == BlockScaleQuantizationType::FP8_TO_FP4
? CVT_FP8_TO_FP4_ELTS_PER_THREAD
: CVT_ELTS_PER_THREAD;

using PackedVec = PackedVec<Type>;
static constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / ELTS_PER_THREAD; // 2 or 4
static constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / ELTS_PER_THREAD;
static_assert(sizeof(PackedVec) == sizeof(Type) * ELTS_PER_THREAD, "Vec size is not matched.");

// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];

// Is it swizzled layout?
bool isSfSwizzledLayout = layout == QuantizationSFLayout::SWIZZLED_128x4 ||
layout == QuantizationSFLayout::SWIZZLED_8x4;

// The number of padded rows considering 128x4 or 8x4 SF layout.
int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8;
int numPaddedRowsForSf = isSfSwizzledLayout ? PadUpFn(numRows, rowTile) : numRows;
int numColsForSf = isSfSwizzledLayout ? PadUpFn(numPaddedCols, 4 * SF_VEC_SIZE) : numPaddedCols;

// The number of threads in the column dimension.
// Note that numCols/numPaddedCols/numColsForSf are guaranteed to be multiples of ELTS_PER_THREAD.
int numColThreads = numCols / ELTS_PER_THREAD;
int actualColsThreads = use_silu ? numColThreads * 2 : numColThreads;
int numPaddedColThreads = numPaddedCols / ELTS_PER_THREAD;
int numColThreadsForSf = numColsForSf / ELTS_PER_THREAD;

asm volatile("griddepcontrol.wait;");
// Input tensor batch/row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numPaddedRowsForSf; rowIdx += gridDim.x) {
for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
std::optional<int> optionalBatchIdx = batchIdx;
std::optional<int> optionalNumRows = numRows;

// The SF output pointer.
auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout,
layout);

// The input tensor offset.
int64_t inOffset =
static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
static_cast<int64_t>(batchIdx * numRows + rowIdx) * actualColsThreads + colIdx;
int64_t outOffset =
static_cast<int64_t>(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx;

// Set the values to 0 for the padded columns.
if (rowIdx < numRows && colIdx >= numColThreads && colIdx < numPaddedColThreads) {
// Dispatch the quantization kernel.
if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
reinterpret_cast<uint32_t*>(out)[outOffset] = 0u;
} else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 ||
@@ -807,17 +832,17 @@ quantize_with_block_size(
}
}

// Set the SF padding to 0.
if (rowIdx >= numRows || colIdx >= numColThreads) {
// Set the SF padding to 0.
if (sf_out != nullptr) {
sf_out[0] = 0x00;
}
if (sf_out != nullptr) sf_out[0] = 0x00;
} else {
// Load the input vector.
if (use_mask && rowIdx >= mask[batchIdx]) continue;

PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
if (use_silu) {
PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + numColThreads];
Policy::maybe_silu(in_vec, in_vec_mul, std::is_same_v<Type, __nv_fp8_e4m3>);
}

// Dispatch the quantization kernel.
if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
reinterpret_cast<uint32_t*>(out)[outOffset] =
cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
@@ -834,6 +859,35 @@ quantize_with_block_size(
}
}
asm volatile("griddepcontrol.launch_dependents;");
}

template <BlockScaleQuantizationType quantization_type, class Type, int SF_VEC_SIZE, bool UE8M0_SF>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4)
#endif
quantize_with_block_size(int32_t numbatches, int32_t numRows, int32_t numCols,
int32_t numPaddedCols, const Type* in, const float* SFScale,
uint32_t* out, uint32_t* SFout, QuantizationSFLayout layout,
const int32_t* mask) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
quantize_with_block_size_impl<quantization_type, Type, SF_VEC_SIZE, UE8M0_SF, NoSiluPolicy>(
numbatches, numRows, numCols, numPaddedCols, in, SFScale, out, SFout, layout, mask);
#endif
}

template <BlockScaleQuantizationType quantization_type, class Type, int SF_VEC_SIZE, bool UE8M0_SF>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4)
#endif
silu_mul_quantize_with_block_size(int32_t numbatches, int32_t numRows, int32_t numCols,
int32_t numPaddedCols, const Type* in, const float* SFScale,
uint32_t* out, uint32_t* SFout, QuantizationSFLayout layout,
const int32_t* mask) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
quantize_with_block_size_impl<quantization_type, Type, SF_VEC_SIZE, UE8M0_SF, SiluPolicy>(
numbatches, numRows, numCols, numPaddedCols, in, SFScale, out, SFout, layout, mask);
#endif
}
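
A note on the structure above: both `__global__` entry points forward to the shared `quantize_with_block_size_impl`, and the silu/no-silu choice is a compile-time policy, so the silu work is compiled out of the plain quantization kernel rather than taken as a runtime branch. The snippet below is an illustrative reduction of that pattern; the names are made up and not part of the diff.

```cuda
// Illustrative only: a minimal version of the compile-time policy dispatch.
struct Disabled {
  static constexpr bool enabled = false;
  __device__ static void apply(float&) {}
};
struct Doubling {
  static constexpr bool enabled = true;
  __device__ static void apply(float& v) { v *= 2.0f; }
};

template <class Policy>
__device__ void body(float* data, int n) {
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    if constexpr (Policy::enabled) Policy::apply(data[i]);  // folds away for Disabled
  }
}

template <class Policy>
__global__ void entry(float* data, int n) { body<Policy>(data, n); }
```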

9 changes: 8 additions & 1 deletion csrc/nv_internal/tensorrt_llm/kernels/quantization.h
@@ -60,7 +60,14 @@ template <typename T, int SF_VEC_SIZE>
void invokeFP4Quantization(int b, int m, int n, T const* input, float const* globalScale,
int64_t* output, int32_t* SFOuput, bool useUE8M0,
QuantizationSFLayout layout, int multiProcessorCount,
bool enable_pdl = false, cudaStream_t stream = 0);
int32_t const* mask = nullptr, bool enable_pdl = false,
cudaStream_t stream = 0);

template <typename T, int SF_VEC_SIZE>
void invokeSiluAndMulFP4Quantization(int b, int m, int n, T const* input, float const* globalScale,
int32_t const* mask, int64_t* output, int32_t* SFOuput,
QuantizationSFLayout layout, int multiProcessorCount,
bool enable_pdl = false, cudaStream_t stream = 0);

void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded,
uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount,