Skip to content

Commit 739df61

Browse files
authored
silu_and_mul nvfp4 quantization fusion rework (#1927)
<!-- .github/pull_request_template.md --> ## 📌 Description This PR reverts #1774 and #1835 which have some issues with some shapes under CUDA graph. The kernels ported in this PR come from SGLang. [[NVIDIA] [1/N] Nvfp4 Masked Gemm: Add quant op for the flashinfer grouped gemm](https://github.com/sgl-project/sglang/pull/9200/files) and [[NVIDIA] [2/N] Optimize silu_and_mul_scaled_fp4_grouped_quant perf](https://github.com/sgl-project/sglang/pull/9556/files) by @kaixih. ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** - Added grouped FP4 quantization (scaled_fp4_grouped_quantize) and an NV-focused Silu+Mul expert quantization entry (silu_and_mul_scaled_nvfp4_experts_quantize). * **API Changes** - Replaced legacy batched APIs with new expert/grouped APIs; removed legacy mask parameter from FP4/MXFP8 quantization signatures and adjusted FP4 output layouts/types. * **Documentation** - Updated docs to list new functions and remove deprecated symbols. * **Tests** - Updated tests to validate new quantization paths, shapes, dtypes, and layouts. 
<!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Shu Wang. <[email protected]>
1 parent b73f04c commit 739df61

File tree

11 files changed

+595
-333
lines changed

11 files changed

+595
-333
lines changed

csrc/nv_internal/cpp/kernels/quantization.cu

Lines changed: 57 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input,
102102
&config,
103103
quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_MXFP8, T, SF_VEC_SIZE, true>, b,
104104
m, n, padded_n, input, nullptr, reinterpret_cast<uint32_t*>(output),
105-
reinterpret_cast<uint32_t*>(SFOuput), layout, /*mask=*/nullptr);
105+
reinterpret_cast<uint32_t*>(SFOuput), layout);
106106
}
107107

108108
// Do per-token (row) quantization from fp16/bf16/fp32 to int8/fp8_e4m3.
@@ -164,11 +164,12 @@ INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(__nv_bfloat16, __nv_fp8_e4m3);
164164

165165
////////////////////////////////////////////////////////////////////////////////////////////////////
166166
// FP4/MXFP8 Quantization
167+
167168
template <typename T, int SF_VEC_SIZE>
168169
void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFScale,
169170
int64_t* output, int32_t* SFOuput, bool useUE8M0,
170-
QuantizationSFLayout layout, int multiProcessorCount,
171-
int32_t const* mask, bool enable_pdl, cudaStream_t stream) {
171+
QuantizationSFLayout layout, int multiProcessorCount, bool enable_pdl,
172+
cudaStream_t stream) {
172173
#ifdef ENABLE_FP8
173174
if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
174175
// Grid, Block size.
@@ -186,7 +187,7 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
186187
T, SF_VEC_SIZE, false>;
187188
kernel_instance<<<grid, block, 0, stream>>>(b, m, n, n, input, SFScale,
188189
reinterpret_cast<uint32_t*>(output),
189-
reinterpret_cast<uint32_t*>(SFOuput), layout, mask);
190+
reinterpret_cast<uint32_t*>(SFOuput), layout);
190191

191192
} else
192193
#endif
@@ -217,42 +218,10 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
217218
config.attrs = attrs;
218219
cudaLaunchKernelEx(&config, kernel_instance, b, m, n, n, input, SFScale,
219220
reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput),
220-
layout, mask);
221+
layout);
221222
}
222223
}
223224

224-
template <typename T, int SF_VEC_SIZE>
225-
void invokeSiluAndMulFP4Quantization(int b, int m, int n, T const* input, float const* SFScale,
226-
int32_t const* mask, int64_t* output, int32_t* SFOuput,
227-
QuantizationSFLayout layout, int multiProcessorCount,
228-
bool enable_pdl, cudaStream_t stream) {
229-
// Grid, Block size.
230-
// Each thread converts 8 values.
231-
dim3 block(std::min(int(n / CVT_ELTS_PER_THREAD), 512));
232-
// Get number of blocks per SM (assume we can fully utilize the SM).
233-
int const numBlocksPerSM = std::max(1u, 2048u / block.x);
234-
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
235-
236-
// Launch the cvt kernel.
237-
auto* kernel_instance =
238-
&silu_mul_quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_FP4, T, SF_VEC_SIZE,
239-
false>;
240-
241-
cudaLaunchConfig_t config;
242-
config.gridDim = grid;
243-
config.blockDim = block;
244-
config.dynamicSmemBytes = 0;
245-
config.stream = stream;
246-
cudaLaunchAttribute attrs[1];
247-
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
248-
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
249-
config.numAttrs = 1;
250-
config.attrs = attrs;
251-
cudaLaunchKernelEx(&config, kernel_instance, b, m, n / 2, n / 2, input, SFScale,
252-
reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput),
253-
layout, mask);
254-
}
255-
256225
__global__ void block_scale_interleave_kernel(int numBatches, int numRows, int numRowsPadded,
257226
int numCols, int numColsPadded, uint8_t const* SFIn,
258227
uint8_t* SFOutput) {
@@ -325,57 +294,92 @@ void invokeBlockScaleInterleaveReverse(int b, int m, int n, uint8_t const* SFIn,
325294
block_scale_interleave_reverse_kernel<<<grid, block, 0, stream>>>(b, m, n, SFIn, SFOutput);
326295
}
327296

297+
template <typename T>
298+
void invokeSiluAndMulNVFP4Quantization(void* output, void* output_scale, void* input,
299+
void* input_global_scale, void* mask, bool use_silu_and_mul,
300+
int m_topk, int k, int n_experts, cudaStream_t stream) {
301+
int device;
302+
TLLM_CUDA_CHECK(cudaGetDevice(&device));
303+
int multiProcessorCount;
304+
TLLM_CUDA_CHECK(
305+
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device));
306+
307+
// Grid, Block size.
308+
// Each thread converts 8 values.
309+
TLLM_CHECK_WITH_INFO(k > 0, "k must be > 0");
310+
int const workSizePerRow = max(1, k / CVT_ELTS_PER_THREAD);
311+
int const totalWorkSize = m_topk * workSizePerRow;
312+
dim3 block(std::min(workSizePerRow, 512));
313+
// Get number of blocks per SM (assume we can fully utilize the SM).
314+
int const numBlocksPerSM = 2048 / block.x;
315+
dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x),
316+
multiProcessorCount * numBlocksPerSM));
317+
while (grid.x <= multiProcessorCount && block.x > 64) {
318+
grid.x *= 2;
319+
block.x = (block.x + 1) / 2;
320+
}
321+
322+
// TODO(kaixih@nvidia): Should relax this to allow any grid size.
323+
// [email protected]: only deal with mask case
324+
TLLM_CHECK_WITH_INFO(mask != nullptr, "mask must be non-null for expert NVFP4 path");
325+
TLLM_CHECK_WITH_INFO(n_experts > 0, "n_experts must be > 0");
326+
grid.x = (grid.x + n_experts - 1) / n_experts * n_experts;
327+
cvt_fp16_to_fp4_expert<T, false><<<grid, block, 0, stream>>>(
328+
m_topk, k, reinterpret_cast<T*>(input), reinterpret_cast<float*>(input_global_scale),
329+
reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(output_scale),
330+
reinterpret_cast<int32_t*>(mask), use_silu_and_mul, n_experts);
331+
return;
332+
}
333+
328334
// Instantiate the function.
329335
template void invokeFP4Quantization<half, 16>(int b, int m, int n, half const* input,
330336
float const* SFScale, int64_t* output,
331337
int32_t* SFOuput, bool useUE8M0,
332338
QuantizationSFLayout layout, int multiProcessorCount,
333-
int32_t const* mask, bool enable_pdl,
334-
cudaStream_t stream);
339+
bool enable_pdl, cudaStream_t stream);
335340
template void invokeFP4Quantization<half, 32>(int b, int m, int n, half const* input,
336341
float const* SFScale, int64_t* output,
337342
int32_t* SFOuput, bool useUE8M0,
338343
QuantizationSFLayout layout, int multiProcessorCount,
339-
int32_t const* mask, bool enable_pdl,
340-
cudaStream_t stream);
344+
bool enable_pdl, cudaStream_t stream);
341345
template void invokeMxFP8Quantization<half>(int b, int m, int n, int padded_n, half const* input,
342346
int64_t* output, int32_t* SFOuput,
343347
QuantizationSFLayout layout, int multiProcessorCount,
344348
bool enable_pdl, cudaStream_t stream);
345-
template void invokeSiluAndMulFP4Quantization<half, 16>(
346-
int b, int m, int n, half const* input, float const* globalScale, int32_t const* mask,
347-
int64_t* output, int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount,
348-
bool enable_pdl, cudaStream_t stream);
349+
template void invokeSiluAndMulNVFP4Quantization<half>(void* output, void* output_scale, void* input,
350+
void* input_global_scale, void* mask,
351+
bool use_silu_and_mul, int m_topk, int k,
352+
int n_experts, cudaStream_t stream);
349353

350354
#ifdef ENABLE_BF16
351355
template void invokeFP4Quantization<__nv_bfloat16, 16>(
352356
int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
353357
int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
354-
int32_t const* mask, bool enable_pdl, cudaStream_t stream);
358+
bool enable_pdl, cudaStream_t stream);
355359
template void invokeFP4Quantization<__nv_bfloat16, 32>(
356360
int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
357361
int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
358-
int32_t const* mask, bool enable_pdl, cudaStream_t stream);
362+
bool enable_pdl, cudaStream_t stream);
359363
template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n, int padded_n,
360364
__nv_bfloat16 const* input, int64_t* output,
361365
int32_t* SFOuput, QuantizationSFLayout layout,
362366
int multiProcessorCount, bool enable_pdl,
363367
cudaStream_t stream);
364-
template void invokeSiluAndMulFP4Quantization<__nv_bfloat16, 16>(
365-
int b, int m, int n, __nv_bfloat16 const* input, float const* globalScale, int32_t const* mask,
366-
int64_t* output, int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount,
367-
bool enable_pdl, cudaStream_t stream);
368+
template void invokeSiluAndMulNVFP4Quantization<__nv_bfloat16>(
369+
void* output, void* output_scale, void* input, void* input_global_scale, void* mask,
370+
bool use_silu_and_mul, int m_topk, int k, int n_experts, cudaStream_t stream);
371+
368372
#endif
369373

370374
#ifdef ENABLE_FP8
371375
template void invokeFP4Quantization<__nv_fp8_e4m3, 16>(
372376
int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output,
373377
int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
374-
int32_t const* mask, bool enable_pdl, cudaStream_t stream);
378+
bool enable_pdl, cudaStream_t stream);
375379
template void invokeFP4Quantization<__nv_fp8_e4m3, 32>(
376380
int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output,
377381
int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
378-
int32_t const* mask, bool enable_pdl, cudaStream_t stream);
382+
bool enable_pdl, cudaStream_t stream);
379383

380384
#endif
381385

0 commit comments

Comments
 (0)