#include "tensorrt_llm/common/quantTypeUtils.cuh"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include "tensorrt_llm/kernels/quantization.cuh"
+// #include "tensorrt_llm/kernels/nvfp4_expert_quant.cuh"
#include "tensorrt_llm/kernels/quantization.h"

using namespace tensorrt_llm::common;
@@ -102,7 +103,7 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input,
      &config,
      quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_MXFP8, T, SF_VEC_SIZE, true>, b,
      m, n, padded_n, input, nullptr, reinterpret_cast<uint32_t*>(output),
-     reinterpret_cast<uint32_t*>(SFOuput), layout, /*mask=*/nullptr);
+     reinterpret_cast<uint32_t*>(SFOuput), layout);
}

// Do per-token (row) quantization from fp16/bf16/fp32 to int8/fp8_e4m3.
@@ -164,11 +165,12 @@ INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(__nv_bfloat16, __nv_fp8_e4m3);

////////////////////////////////////////////////////////////////////////////////////////////////////
// FP4/MXFP8 Quantization
+
template <typename T, int SF_VEC_SIZE>
void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFScale,
                           int64_t* output, int32_t* SFOuput, bool useUE8M0,
-                          QuantizationSFLayout layout, int multiProcessorCount,
-                          int32_t const* mask, bool enable_pdl, cudaStream_t stream) {
+                          QuantizationSFLayout layout, int multiProcessorCount, bool enable_pdl,
+                          cudaStream_t stream) {
#ifdef ENABLE_FP8
  if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
    // Grid, Block size.
@@ -186,7 +188,7 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
        T, SF_VEC_SIZE, false>;
    kernel_instance<<<grid, block, 0, stream>>>(b, m, n, n, input, SFScale,
                                                reinterpret_cast<uint32_t*>(output),
-                                               reinterpret_cast<uint32_t*>(SFOuput), layout, mask);
+                                               reinterpret_cast<uint32_t*>(SFOuput), layout);

  } else
#endif
@@ -217,42 +219,10 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
    config.attrs = attrs;
    cudaLaunchKernelEx(&config, kernel_instance, b, m, n, n, input, SFScale,
                       reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput),
-                      layout, mask);
+                      layout);
  }
}

-template <typename T, int SF_VEC_SIZE>
-void invokeSiluAndMulFP4Quantization(int b, int m, int n, T const* input, float const* SFScale,
-                                     int32_t const* mask, int64_t* output, int32_t* SFOuput,
-                                     QuantizationSFLayout layout, int multiProcessorCount,
-                                     bool enable_pdl, cudaStream_t stream) {
-  // Grid, Block size.
-  // Each thread converts 8 values.
-  dim3 block(std::min(int(n / CVT_ELTS_PER_THREAD), 512));
-  // Get number of blocks per SM (assume we can fully utilize the SM).
-  int const numBlocksPerSM = std::max(1u, 2048u / block.x);
-  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
-
-  // Launch the cvt kernel.
-  auto* kernel_instance =
-      &silu_mul_quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_FP4, T, SF_VEC_SIZE,
-                                         false>;
-
-  cudaLaunchConfig_t config;
-  config.gridDim = grid;
-  config.blockDim = block;
-  config.dynamicSmemBytes = 0;
-  config.stream = stream;
-  cudaLaunchAttribute attrs[1];
-  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-  attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
-  config.numAttrs = 1;
-  config.attrs = attrs;
-  cudaLaunchKernelEx(&config, kernel_instance, b, m, n / 2, n / 2, input, SFScale,
-                     reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput),
-                     layout, mask);
-}
-
__global__ void block_scale_interleave_kernel(int numBatches, int numRows, int numRowsPadded,
                                              int numCols, int numColsPadded, uint8_t const* SFIn,
                                              uint8_t* SFOutput) {
@@ -325,57 +295,120 @@ void invokeBlockScaleInterleaveReverse(int b, int m, int n, uint8_t const* SFIn,
  block_scale_interleave_reverse_kernel<<<grid, block, 0, stream>>>(b, m, n, SFIn, SFOutput);
}

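+// Quantize expert activations to NVFP4, optionally fusing the preceding
+// SiLU-and-mul. Note: this implementation requires a non-null `mask`
+// (asserted below); the expert-offset pointers are accepted but not read
+// in this mask-only path.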
+template <typename T>
+void invokeSiluAndMulNVFP4Quantization(void* output,
+                                       void* output_scale,
+                                       void* input,
+                                       void* input_global_scale,
+                                       void* input_offset_by_experts,
+                                       void* output_scale_offset_by_experts,
+                                       void* mask,
+                                       bool use_silu_and_mul,
+                                       int m_topk,
+                                       int k,
+                                       int n_experts,
+                                       cudaStream_t stream) {
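+  // Query the device's SM count; it caps the grid size computed below.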
+  int device;
+  cudaGetDevice(&device);
+  int multiProcessorCount;
+  cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device);
+
+  // Grid, Block size.
+  // Each thread converts 8 values.
+  int const workSizePerRow = k / CVT_ELTS_PER_THREAD;
+  int const totalWorkSize = m_topk * workSizePerRow;
+  dim3 block(std::min(workSizePerRow, 512));
+  // Get number of blocks per SM (assume we can fully utilize the SM).
+  int const numBlocksPerSM = 2048 / block.x;
+  dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x),
+                     multiProcessorCount * numBlocksPerSM));
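+  // If the grid would not span all SMs, trade threads for blocks: halve the
+  // block size (down to 64 threads) and double the block count.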
+  while (grid.x <= multiProcessorCount && block.x > 64) {
+    grid.x *= 2;
+    block.x = (block.x + 1) / 2;
+  }
+
+  // TODO(kaixih@nvidia): Should relax this to allow any grid size.
+  // [email protected]: only deal with the mask case for now.
+  assert(mask != nullptr);
+  // if (mask != nullptr) {
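+  // Round the grid up to the next multiple of n_experts so blocks divide
+  // evenly among the experts.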
+  grid.x = (grid.x + n_experts - 1) / n_experts * n_experts;
+  cvt_fp16_to_fp4_expert<T, false><<<grid, block, 0, stream>>>(
+      m_topk,
+      k,
+      reinterpret_cast<T*>(input),
+      reinterpret_cast<float*>(input_global_scale),
+      reinterpret_cast<uint32_t*>(output),
+      reinterpret_cast<uint32_t*>(output_scale),
+      reinterpret_cast<int32_t*>(mask),
+      use_silu_and_mul,
+      n_experts);
+  return;
+  // }
+}
+
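+// Minimal usage sketch (caller-side names such as `act` and `mask_ptr` are
+// hypothetical). The expert-offset arguments may be null since this path does
+// not read them, but `mask` must be non-null:
+//
+//   invokeSiluAndMulNVFP4Quantization<half>(out, out_sf, act, global_sf,
+//                                           /*input_offset_by_experts=*/nullptr,
+//                                           /*output_scale_offset_by_experts=*/nullptr,
+//                                           mask_ptr, /*use_silu_and_mul=*/true,
+//                                           m_topk, k, n_experts, stream);
+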
// Instantiate the function.
template void invokeFP4Quantization<half, 16>(int b, int m, int n, half const* input,
                                              float const* SFScale, int64_t* output,
                                              int32_t* SFOuput, bool useUE8M0,
                                              QuantizationSFLayout layout, int multiProcessorCount,
-                                              int32_t const* mask, bool enable_pdl,
-                                              cudaStream_t stream);
+                                              bool enable_pdl, cudaStream_t stream);
template void invokeFP4Quantization<half, 32>(int b, int m, int n, half const* input,
                                              float const* SFScale, int64_t* output,
                                              int32_t* SFOuput, bool useUE8M0,
                                              QuantizationSFLayout layout, int multiProcessorCount,
-                                              int32_t const* mask, bool enable_pdl,
-                                              cudaStream_t stream);
+                                              bool enable_pdl, cudaStream_t stream);
template void invokeMxFP8Quantization<half>(int b, int m, int n, int padded_n, half const* input,
                                            int64_t* output, int32_t* SFOuput,
                                            QuantizationSFLayout layout, int multiProcessorCount,
                                            bool enable_pdl, cudaStream_t stream);
-template void invokeSiluAndMulFP4Quantization<half, 16>(
-    int b, int m, int n, half const* input, float const* globalScale, int32_t const* mask,
-    int64_t* output, int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount,
-    bool enable_pdl, cudaStream_t stream);
+template void invokeSiluAndMulNVFP4Quantization<half>(void* output, void* output_scale,
+                                                      void* input,
+                                                      void* input_global_scale,
+                                                      void* input_offset_by_experts,
+                                                      void* output_scale_offset_by_experts,
+                                                      void* mask,
+                                                      bool use_silu_and_mul,
+                                                      int m_topk,
+                                                      int k,
+                                                      int n_experts,
+                                                      cudaStream_t stream);

#ifdef ENABLE_BF16
template void invokeFP4Quantization<__nv_bfloat16, 16>(
    int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
    int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
-    int32_t const* mask, bool enable_pdl, cudaStream_t stream);
+    bool enable_pdl, cudaStream_t stream);
template void invokeFP4Quantization<__nv_bfloat16, 32>(
    int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
    int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
-    int32_t const* mask, bool enable_pdl, cudaStream_t stream);
+    bool enable_pdl, cudaStream_t stream);
template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n, int padded_n,
                                                     __nv_bfloat16 const* input, int64_t* output,
                                                     int32_t* SFOuput, QuantizationSFLayout layout,
                                                     int multiProcessorCount, bool enable_pdl,
                                                     cudaStream_t stream);
-template void invokeSiluAndMulFP4Quantization<__nv_bfloat16, 16>(
-    int b, int m, int n, __nv_bfloat16 const* input, float const* globalScale, int32_t const* mask,
-    int64_t* output, int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount,
-    bool enable_pdl, cudaStream_t stream);
+template void invokeSiluAndMulNVFP4Quantization<__nv_bfloat16>(void* output, void* output_scale,
+                                                               void* input,
+                                                               void* input_global_scale,
+                                                               void* input_offset_by_experts,
+                                                               void* output_scale_offset_by_experts,
+                                                               void* mask,
+                                                               bool use_silu_and_mul,
+                                                               int m_topk,
+                                                               int k,
+                                                               int n_experts,
+                                                               cudaStream_t stream);
+
#endif

#ifdef ENABLE_FP8
template void invokeFP4Quantization<__nv_fp8_e4m3, 16>(
    int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output,
    int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
-    int32_t const* mask, bool enable_pdl, cudaStream_t stream);
+    bool enable_pdl, cudaStream_t stream);
template void invokeFP4Quantization<__nv_fp8_e4m3, 32>(
    int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output,
    int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
-    int32_t const* mask, bool enable_pdl, cudaStream_t stream);
+    bool enable_pdl, cudaStream_t stream);

#endif
