@@ -102,7 +102,7 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input,
102102 &config,
103103 quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_MXFP8, T, SF_VEC_SIZE, true >, b,
104104 m, n, padded_n, input, nullptr , reinterpret_cast <uint32_t *>(output),
105- reinterpret_cast <uint32_t *>(SFOuput), layout, /* mask= */ nullptr );
105+ reinterpret_cast <uint32_t *>(SFOuput), layout);
106106}
107107
108108// Do per-token (row) quantization from fp16/bf16/fp32 to int8/fp8_e4m3.
@@ -164,11 +164,12 @@ INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(__nv_bfloat16, __nv_fp8_e4m3);
164164
165165// //////////////////////////////////////////////////////////////////////////////////////////////////
166166// FP4/MXFP8 Quantization
167+
167168template <typename T, int SF_VEC_SIZE>
168169void invokeFP4Quantization (int b, int m, int n, T const * input, float const * SFScale,
169170 int64_t * output, int32_t * SFOuput, bool useUE8M0,
170- QuantizationSFLayout layout, int multiProcessorCount,
171- int32_t const * mask, bool enable_pdl, cudaStream_t stream) {
171+ QuantizationSFLayout layout, int multiProcessorCount, bool enable_pdl,
172+ cudaStream_t stream) {
172173#ifdef ENABLE_FP8
173174 if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
174175 // Grid, Block size.
@@ -186,7 +187,7 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
186187 T, SF_VEC_SIZE, false >;
187188 kernel_instance<<<grid, block, 0 , stream>>> (b, m, n, n, input, SFScale,
188189 reinterpret_cast <uint32_t *>(output),
189- reinterpret_cast <uint32_t *>(SFOuput), layout, mask );
190+ reinterpret_cast <uint32_t *>(SFOuput), layout);
190191
191192 } else
192193#endif
@@ -217,42 +218,10 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
217218 config.attrs = attrs;
218219 cudaLaunchKernelEx (&config, kernel_instance, b, m, n, n, input, SFScale,
219220 reinterpret_cast <uint32_t *>(output), reinterpret_cast <uint32_t *>(SFOuput),
220- layout, mask );
221+ layout);
221222 }
222223}
223224
224- template <typename T, int SF_VEC_SIZE>
225- void invokeSiluAndMulFP4Quantization (int b, int m, int n, T const * input, float const * SFScale,
226- int32_t const * mask, int64_t * output, int32_t * SFOuput,
227- QuantizationSFLayout layout, int multiProcessorCount,
228- bool enable_pdl, cudaStream_t stream) {
229- // Grid, Block size.
230- // Each thread converts 8 values.
231- dim3 block (std::min (int (n / CVT_ELTS_PER_THREAD), 512 ));
232- // Get number of blocks per SM (assume we can fully utilize the SM).
233- int const numBlocksPerSM = std::max (1u , 2048u / block.x );
234- dim3 grid (std::min (int (m), multiProcessorCount * numBlocksPerSM));
235-
236- // Launch the cvt kernel.
237- auto * kernel_instance =
238- &silu_mul_quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_FP4, T, SF_VEC_SIZE,
239- false >;
240-
241- cudaLaunchConfig_t config;
242- config.gridDim = grid;
243- config.blockDim = block;
244- config.dynamicSmemBytes = 0 ;
245- config.stream = stream;
246- cudaLaunchAttribute attrs[1 ];
247- attrs[0 ].id = cudaLaunchAttributeProgrammaticStreamSerialization;
248- attrs[0 ].val .programmaticStreamSerializationAllowed = enable_pdl;
249- config.numAttrs = 1 ;
250- config.attrs = attrs;
251- cudaLaunchKernelEx (&config, kernel_instance, b, m, n / 2 , n / 2 , input, SFScale,
252- reinterpret_cast <uint32_t *>(output), reinterpret_cast <uint32_t *>(SFOuput),
253- layout, mask);
254- }
255-
256225__global__ void block_scale_interleave_kernel (int numBatches, int numRows, int numRowsPadded,
257226 int numCols, int numColsPadded, uint8_t const * SFIn,
258227 uint8_t * SFOutput) {
@@ -325,57 +294,92 @@ void invokeBlockScaleInterleaveReverse(int b, int m, int n, uint8_t const* SFIn,
325294 block_scale_interleave_reverse_kernel<<<grid, block, 0 , stream>>> (b, m, n, SFIn, SFOutput);
326295}
327296
297+ template <typename T>
298+ void invokeSiluAndMulNVFP4Quantization (void * output, void * output_scale, void * input,
299+ void * input_global_scale, void * mask, bool use_silu_and_mul,
300+ int m_topk, int k, int n_experts, cudaStream_t stream) {
301+ int device;
302+ TLLM_CUDA_CHECK (cudaGetDevice (&device));
303+ int multiProcessorCount;
304+ TLLM_CUDA_CHECK (
305+ cudaDeviceGetAttribute (&multiProcessorCount, cudaDevAttrMultiProcessorCount, device));
306+
307+ // Grid, Block size.
308+ // Each thread converts 8 values.
309+ TLLM_CHECK_WITH_INFO (k > 0 , " k must be > 0" );
310+ int const workSizePerRow = max (1 , k / CVT_ELTS_PER_THREAD);
311+ int const totalWorkSize = m_topk * workSizePerRow;
312+ dim3 block (std::min (workSizePerRow, 512 ));
313+ // Get number of blocks per SM (assume we can fully utilize the SM).
314+ int const numBlocksPerSM = 2048 / block.x ;
315+ dim3 grid (std::min (static_cast <int >((totalWorkSize + block.x - 1 ) / block.x ),
316+ multiProcessorCount * numBlocksPerSM));
317+ while (grid.x <= multiProcessorCount && block.x > 64 ) {
318+ grid.x *= 2 ;
319+ block.x = (block.x + 1 ) / 2 ;
320+ }
321+
322+ // TODO(kaixih@nvidia): Should relax this to allow any grid size.
323+ // NOTE(kaixih@nvidia): only handle the masked (expert) case for now.
324+ TLLM_CHECK_WITH_INFO (mask != nullptr , " mask must be non-null for expert NVFP4 path" );
325+ TLLM_CHECK_WITH_INFO (n_experts > 0 , " n_experts must be > 0" );
326+ grid.x = (grid.x + n_experts - 1 ) / n_experts * n_experts;
327+ cvt_fp16_to_fp4_expert<T, false ><<<grid, block, 0 , stream>>> (
328+ m_topk, k, reinterpret_cast <T*>(input), reinterpret_cast <float *>(input_global_scale),
329+ reinterpret_cast <uint32_t *>(output), reinterpret_cast <uint32_t *>(output_scale),
330+ reinterpret_cast <int32_t *>(mask), use_silu_and_mul, n_experts);
331+ return ;
332+ }
333+
328334// Instantiate the function.
329335template void invokeFP4Quantization<half, 16 >(int b, int m, int n, half const * input,
330336 float const * SFScale, int64_t * output,
331337 int32_t * SFOuput, bool useUE8M0,
332338 QuantizationSFLayout layout, int multiProcessorCount,
333- int32_t const * mask, bool enable_pdl,
334- cudaStream_t stream);
339+ bool enable_pdl, cudaStream_t stream);
335340template void invokeFP4Quantization<half, 32 >(int b, int m, int n, half const * input,
336341 float const * SFScale, int64_t * output,
337342 int32_t * SFOuput, bool useUE8M0,
338343 QuantizationSFLayout layout, int multiProcessorCount,
339- int32_t const * mask, bool enable_pdl,
340- cudaStream_t stream);
344+ bool enable_pdl, cudaStream_t stream);
341345template void invokeMxFP8Quantization<half>(int b, int m, int n, int padded_n, half const * input,
342346 int64_t * output, int32_t * SFOuput,
343347 QuantizationSFLayout layout, int multiProcessorCount,
344348 bool enable_pdl, cudaStream_t stream);
345- template void invokeSiluAndMulFP4Quantization <half, 16 >(
346- int b, int m, int n, half const * input, float const * globalScale, int32_t const * mask,
347- int64_t * output, int32_t * SFOuput, QuantizationSFLayout layout , int multiProcessorCount ,
348- bool enable_pdl , cudaStream_t stream);
349+ template void invokeSiluAndMulNVFP4Quantization <half>( void * output, void * output_scale, void * input,
350+ void * input_global_scale, void * mask,
351+ bool use_silu_and_mul, int m_topk , int k ,
352+ int n_experts , cudaStream_t stream);
349353
350354#ifdef ENABLE_BF16
351355template void invokeFP4Quantization<__nv_bfloat16, 16 >(
352356 int b, int m, int n, __nv_bfloat16 const * input, float const * SFScale, int64_t * output,
353357 int32_t * SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
354- int32_t const * mask, bool enable_pdl, cudaStream_t stream);
358+ bool enable_pdl, cudaStream_t stream);
355359template void invokeFP4Quantization<__nv_bfloat16, 32 >(
356360 int b, int m, int n, __nv_bfloat16 const * input, float const * SFScale, int64_t * output,
357361 int32_t * SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
358- int32_t const * mask, bool enable_pdl, cudaStream_t stream);
362+ bool enable_pdl, cudaStream_t stream);
359363template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n, int padded_n,
360364 __nv_bfloat16 const * input, int64_t * output,
361365 int32_t * SFOuput, QuantizationSFLayout layout,
362366 int multiProcessorCount, bool enable_pdl,
363367 cudaStream_t stream);
364- template void invokeSiluAndMulFP4Quantization <__nv_bfloat16, 16 >(
365- int b, int m, int n, __nv_bfloat16 const * input, float const * globalScale, int32_t const * mask,
366- int64_t * output, int32_t * SFOuput, QuantizationSFLayout layout , int multiProcessorCount,
367- bool enable_pdl, cudaStream_t stream);
368+ template void invokeSiluAndMulNVFP4Quantization <__nv_bfloat16>(
369+ void * output, void * output_scale, void * input, void * input_global_scale, void * mask,
370+ bool use_silu_and_mul, int m_topk, int k , int n_experts, cudaStream_t stream);
371+
368372#endif
369373
370374#ifdef ENABLE_FP8
371375template void invokeFP4Quantization<__nv_fp8_e4m3, 16 >(
372376 int b, int m, int n, __nv_fp8_e4m3 const * input, float const * SFScale, int64_t * output,
373377 int32_t * SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
374- int32_t const * mask, bool enable_pdl, cudaStream_t stream);
378+ bool enable_pdl, cudaStream_t stream);
375379template void invokeFP4Quantization<__nv_fp8_e4m3, 32 >(
376380 int b, int m, int n, __nv_fp8_e4m3 const * input, float const * SFScale, int64_t * output,
377381 int32_t * SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
378- int32_t const * mask, bool enable_pdl, cudaStream_t stream);
382+ bool enable_pdl, cudaStream_t stream);
379383
380384#endif
381385
0 commit comments