107 changes: 54 additions & 53 deletions csrc/nv_internal/cpp/kernels/quantization.cu
@@ -102,7 +102,7 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input,
&config,
quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_MXFP8, T, SF_VEC_SIZE, true>, b,
m, n, padded_n, input, nullptr, reinterpret_cast<uint32_t*>(output),
-reinterpret_cast<uint32_t*>(SFOuput), layout, /*mask=*/nullptr);
+reinterpret_cast<uint32_t*>(SFOuput), layout);
}

// Do per-token (row) quantization from fp16/bf16/fp32 to int8/fp8_e4m3.
@@ -164,11 +164,12 @@ INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(__nv_bfloat16, __nv_fp8_e4m3);

////////////////////////////////////////////////////////////////////////////////////////////////////
// FP4/MXFP8 Quantization

template <typename T, int SF_VEC_SIZE>
void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFScale,
int64_t* output, int32_t* SFOuput, bool useUE8M0,
-QuantizationSFLayout layout, int multiProcessorCount,
-int32_t const* mask, bool enable_pdl, cudaStream_t stream) {
+QuantizationSFLayout layout, int multiProcessorCount, bool enable_pdl,
+cudaStream_t stream) {
#ifdef ENABLE_FP8
if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
// Grid, Block size.
@@ -186,7 +187,7 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
T, SF_VEC_SIZE, false>;
kernel_instance<<<grid, block, 0, stream>>>(b, m, n, n, input, SFScale,
reinterpret_cast<uint32_t*>(output),
-reinterpret_cast<uint32_t*>(SFOuput), layout, mask);
+reinterpret_cast<uint32_t*>(SFOuput), layout);

} else
#endif
@@ -217,42 +218,10 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance, b, m, n, n, input, SFScale,
reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput),
-layout, mask);
+layout);
}
}
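
Both surviving launch paths in this file attach cudaLaunchAttributeProgrammaticStreamSerialization and forward the new enable_pdl flag, so programmatic dependent launch (PDL) is toggled per call rather than baked in. Below is a minimal, self-contained sketch of that launch pattern; the placeholder kernel and launch shape are illustrative assumptions, not code from this PR.

#include <cuda_runtime.h>

__global__ void placeholderKernel(float* out) { out[threadIdx.x] = float(threadIdx.x); }

// Sketch of the cudaLaunchKernelEx + PDL-attribute pattern used above.
// With enable_pdl == false the attribute is still attached but disabled;
// PDL itself only takes effect on devices that support it (sm_90+).
cudaError_t launchWithPdl(float* d_out, bool enable_pdl, cudaStream_t stream) {
  cudaLaunchConfig_t config{};
  config.gridDim = dim3(1);
  config.blockDim = dim3(32);
  config.dynamicSmemBytes = 0;
  config.stream = stream;

  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
  config.attrs = attrs;
  config.numAttrs = 1;

  return cudaLaunchKernelEx(&config, placeholderKernel, d_out);
}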

-template <typename T, int SF_VEC_SIZE>
-void invokeSiluAndMulFP4Quantization(int b, int m, int n, T const* input, float const* SFScale,
-int32_t const* mask, int64_t* output, int32_t* SFOuput,
-QuantizationSFLayout layout, int multiProcessorCount,
-bool enable_pdl, cudaStream_t stream) {
-// Grid, Block size.
-// Each thread converts 8 values.
-dim3 block(std::min(int(n / CVT_ELTS_PER_THREAD), 512));
-// Get number of blocks per SM (assume we can fully utilize the SM).
-int const numBlocksPerSM = std::max(1u, 2048u / block.x);
-dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
-
-// Launch the cvt kernel.
-auto* kernel_instance =
-&silu_mul_quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_FP4, T, SF_VEC_SIZE,
-false>;
-
-cudaLaunchConfig_t config;
-config.gridDim = grid;
-config.blockDim = block;
-config.dynamicSmemBytes = 0;
-config.stream = stream;
-cudaLaunchAttribute attrs[1];
-attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
-config.numAttrs = 1;
-config.attrs = attrs;
-cudaLaunchKernelEx(&config, kernel_instance, b, m, n / 2, n / 2, input, SFScale,
-reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput),
-layout, mask);
-}

__global__ void block_scale_interleave_kernel(int numBatches, int numRows, int numRowsPadded,
int numCols, int numColsPadded, uint8_t const* SFIn,
uint8_t* SFOutput) {
Expand Down Expand Up @@ -325,57 +294,89 @@ void invokeBlockScaleInterleaveReverse(int b, int m, int n, uint8_t const* SFIn,
block_scale_interleave_reverse_kernel<<<grid, block, 0, stream>>>(b, m, n, SFIn, SFOutput);
}

+template <typename T>
+void invokeSiluAndMulNVFP4Quantization(void* output, void* output_scale, void* input,
+void* input_global_scale, void* mask, bool use_silu_and_mul,
+int m_topk, int k, int n_experts, cudaStream_t stream) {
+int device;
+cudaGetDevice(&device);
+int multiProcessorCount;
+cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device);
+
+// Grid, Block size.
+// Each thread converts 8 values.
+int const workSizePerRow = k / CVT_ELTS_PER_THREAD;
+int const totalWorkSize = m_topk * workSizePerRow;
+dim3 block(std::min(workSizePerRow, 512));
+// Get number of blocks per SM (assume we can fully utilize the SM).
+int const numBlocksPerSM = 2048 / block.x;
+dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x),
+multiProcessorCount * numBlocksPerSM));
+while (grid.x <= multiProcessorCount && block.x > 64) {
+grid.x *= 2;
+block.x = (block.x + 1) / 2;
+}
+
+// TODO(kaixih@nvidia): Should relax this to allow any grid size.
+// [email protected]: only deal with mask case
+assert(mask != nullptr);
+grid.x = (grid.x + n_experts - 1) / n_experts * n_experts;
+cvt_fp16_to_fp4_expert<T, false><<<grid, block, 0, stream>>>(
+m_topk, k, reinterpret_cast<T*>(input), reinterpret_cast<float*>(input_global_scale),
+reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(output_scale),
+reinterpret_cast<int32_t*>(mask), use_silu_and_mul, n_experts);
+return;
+}
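
The sizing logic in the new invokeSiluAndMulNVFP4Quantization halves the block and doubles the grid while the grid underfills the device, then rounds the grid up to a multiple of n_experts so blocks partition evenly across experts (the host path asserts mask != nullptr). Below is a host-side sketch of just that arithmetic; CVT_ELTS_PER_THREAD = 8 follows the "each thread converts 8 values" comment, and the worked numbers after the sketch are illustrative.

#include <algorithm>
#include <cstdio>

// Reproduces the grid/block sizing arithmetic from above (host-only sketch).
void computeLaunchShape(int m_topk, int k, int n_experts, int smCount) {
  constexpr int kEltsPerThread = 8;  // CVT_ELTS_PER_THREAD in this file
  int const workSizePerRow = k / kEltsPerThread;
  int const totalWorkSize = m_topk * workSizePerRow;
  int block = std::min(workSizePerRow, 512);
  int const numBlocksPerSM = 2048 / block;  // assume full occupancy
  int grid = std::min((totalWorkSize + block - 1) / block, smCount * numBlocksPerSM);
  // Trade block size for grid size until the device is covered.
  while (grid <= smCount && block > 64) {
    grid *= 2;
    block = (block + 1) / 2;
  }
  // Round up so every expert owns the same number of blocks.
  grid = (grid + n_experts - 1) / n_experts * n_experts;
  printf("grid=%d block=%d\n", grid, block);
}

For example, m_topk = 256, k = 4096, n_experts = 8 on a 132-SM device gives workSizePerRow = 512, block = 512, and grid = min(256, 528) = 256; since 256 > 132 no halving occurs, and 256 is already a multiple of 8.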

// Instantiate the function.
template void invokeFP4Quantization<half, 16>(int b, int m, int n, half const* input,
float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0,
QuantizationSFLayout layout, int multiProcessorCount,
-int32_t const* mask, bool enable_pdl,
-cudaStream_t stream);
+bool enable_pdl, cudaStream_t stream);
template void invokeFP4Quantization<half, 32>(int b, int m, int n, half const* input,
float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0,
QuantizationSFLayout layout, int multiProcessorCount,
-int32_t const* mask, bool enable_pdl,
-cudaStream_t stream);
+bool enable_pdl, cudaStream_t stream);
template void invokeMxFP8Quantization<half>(int b, int m, int n, int padded_n, half const* input,
int64_t* output, int32_t* SFOuput,
QuantizationSFLayout layout, int multiProcessorCount,
bool enable_pdl, cudaStream_t stream);
-template void invokeSiluAndMulFP4Quantization<half, 16>(
-int b, int m, int n, half const* input, float const* globalScale, int32_t const* mask,
-int64_t* output, int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount,
-bool enable_pdl, cudaStream_t stream);
+template void invokeSiluAndMulNVFP4Quantization<half>(void* output, void* output_scale, void* input,
+void* input_global_scale, void* mask,
+bool use_silu_and_mul, int m_topk, int k,
+int n_experts, cudaStream_t stream);
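
With the mask parameter gone, call sites now pass one fewer argument. Below is a hedged usage sketch against the <half, 16> instantiation above; the wrapper name, the layout/PDL choices, and the assumption that this file's header declares invokeFP4Quantization are illustrative, not from this PR.

// NVFP4 packs 16 FP4 values per int64_t output word and emits one scale
// per 16 input elements, so n is expected to be a multiple of 16 here.
void quantizeHalfToNVFP4(half const* d_input, float const* d_globalScale,
                         int64_t* d_output, int32_t* d_sfOutput, int b, int m, int n,
                         QuantizationSFLayout layout, int smCount, cudaStream_t stream) {
  invokeFP4Quantization<half, 16>(b, m, n, d_input, d_globalScale, d_output, d_sfOutput,
                                  /*useUE8M0=*/false, layout, smCount,
                                  /*enable_pdl=*/false, stream);
}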

#ifdef ENABLE_BF16
template void invokeFP4Quantization<__nv_bfloat16, 16>(
int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
-int32_t const* mask, bool enable_pdl, cudaStream_t stream);
+bool enable_pdl, cudaStream_t stream);
template void invokeFP4Quantization<__nv_bfloat16, 32>(
int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
-int32_t const* mask, bool enable_pdl, cudaStream_t stream);
+bool enable_pdl, cudaStream_t stream);
template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n, int padded_n,
__nv_bfloat16 const* input, int64_t* output,
int32_t* SFOuput, QuantizationSFLayout layout,
int multiProcessorCount, bool enable_pdl,
cudaStream_t stream);
-template void invokeSiluAndMulFP4Quantization<__nv_bfloat16, 16>(
-int b, int m, int n, __nv_bfloat16 const* input, float const* globalScale, int32_t const* mask,
-int64_t* output, int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount,
-bool enable_pdl, cudaStream_t stream);
+template void invokeSiluAndMulNVFP4Quantization<__nv_bfloat16>(
+void* output, void* output_scale, void* input, void* input_global_scale, void* mask,
+bool use_silu_and_mul, int m_topk, int k, int n_experts, cudaStream_t stream);

#endif

#ifdef ENABLE_FP8
template void invokeFP4Quantization<__nv_fp8_e4m3, 16>(
int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
-int32_t const* mask, bool enable_pdl, cudaStream_t stream);
+bool enable_pdl, cudaStream_t stream);
template void invokeFP4Quantization<__nv_fp8_e4m3, 32>(
int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
-int32_t const* mask, bool enable_pdl, cudaStream_t stream);
+bool enable_pdl, cudaStream_t stream);

#endif
