-
Notifications
You must be signed in to change notification settings - Fork 584
silu_and_mul nvfp4 quanization fusion rework #1927
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 4 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
a3153fd
Revert "[Quantization] Add per-expert global scaling factor for fp4 bβ¦
wenscarl b979ee2
precommit
wenscarl c9f89ef
Add missing doc
wenscarl a2378d5
CleanUp
wenscarl 856f918
Address comments
wenscarl d972fbe
Address comments
wenscarl 48259af
Improve
wenscarl File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -102,7 +102,7 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input, | |
| &config, | ||
| quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_MXFP8, T, SF_VEC_SIZE, true>, b, | ||
| m, n, padded_n, input, nullptr, reinterpret_cast<uint32_t*>(output), | ||
| reinterpret_cast<uint32_t*>(SFOuput), layout, /*mask=*/nullptr); | ||
| reinterpret_cast<uint32_t*>(SFOuput), layout); | ||
| } | ||
|
|
||
| // Do per-token (row) quantization from fp16/bf16/fp32 to int8/fp8_e4m3. | ||
|
|
@@ -164,11 +164,12 @@ INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(__nv_bfloat16, __nv_fp8_e4m3); | |
|
|
||
| //////////////////////////////////////////////////////////////////////////////////////////////////// | ||
| // FP4/MXFP8 Quantization | ||
|
|
||
| template <typename T, int SF_VEC_SIZE> | ||
| void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFScale, | ||
| int64_t* output, int32_t* SFOuput, bool useUE8M0, | ||
| QuantizationSFLayout layout, int multiProcessorCount, | ||
| int32_t const* mask, bool enable_pdl, cudaStream_t stream) { | ||
| QuantizationSFLayout layout, int multiProcessorCount, bool enable_pdl, | ||
| cudaStream_t stream) { | ||
| #ifdef ENABLE_FP8 | ||
| if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) { | ||
| // Grid, Block size. | ||
|
|
@@ -186,7 +187,7 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS | |
| T, SF_VEC_SIZE, false>; | ||
| kernel_instance<<<grid, block, 0, stream>>>(b, m, n, n, input, SFScale, | ||
| reinterpret_cast<uint32_t*>(output), | ||
| reinterpret_cast<uint32_t*>(SFOuput), layout, mask); | ||
| reinterpret_cast<uint32_t*>(SFOuput), layout); | ||
|
|
||
| } else | ||
| #endif | ||
|
|
@@ -217,42 +218,10 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS | |
| config.attrs = attrs; | ||
| cudaLaunchKernelEx(&config, kernel_instance, b, m, n, n, input, SFScale, | ||
| reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput), | ||
| layout, mask); | ||
| layout); | ||
| } | ||
| } | ||
|
|
||
| template <typename T, int SF_VEC_SIZE> | ||
| void invokeSiluAndMulFP4Quantization(int b, int m, int n, T const* input, float const* SFScale, | ||
| int32_t const* mask, int64_t* output, int32_t* SFOuput, | ||
| QuantizationSFLayout layout, int multiProcessorCount, | ||
| bool enable_pdl, cudaStream_t stream) { | ||
| // Grid, Block size. | ||
| // Each thread converts 8 values. | ||
| dim3 block(std::min(int(n / CVT_ELTS_PER_THREAD), 512)); | ||
| // Get number of blocks per SM (assume we can fully utilize the SM). | ||
| int const numBlocksPerSM = std::max(1u, 2048u / block.x); | ||
| dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); | ||
|
|
||
| // Launch the cvt kernel. | ||
| auto* kernel_instance = | ||
| &silu_mul_quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_FP4, T, SF_VEC_SIZE, | ||
| false>; | ||
|
|
||
| cudaLaunchConfig_t config; | ||
| config.gridDim = grid; | ||
| config.blockDim = block; | ||
| config.dynamicSmemBytes = 0; | ||
| config.stream = stream; | ||
| cudaLaunchAttribute attrs[1]; | ||
| attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; | ||
| attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; | ||
| config.numAttrs = 1; | ||
| config.attrs = attrs; | ||
| cudaLaunchKernelEx(&config, kernel_instance, b, m, n / 2, n / 2, input, SFScale, | ||
| reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput), | ||
| layout, mask); | ||
| } | ||
|
|
||
| __global__ void block_scale_interleave_kernel(int numBatches, int numRows, int numRowsPadded, | ||
| int numCols, int numColsPadded, uint8_t const* SFIn, | ||
| uint8_t* SFOutput) { | ||
|
|
@@ -325,57 +294,96 @@ void invokeBlockScaleInterleaveReverse(int b, int m, int n, uint8_t const* SFIn, | |
| block_scale_interleave_reverse_kernel<<<grid, block, 0, stream>>>(b, m, n, SFIn, SFOutput); | ||
| } | ||
|
|
||
| template <typename T> | ||
| void invokeSiluAndMulNVFP4Quantization(void* output, void* output_scale, void* input, | ||
| void* input_global_scale, void* input_offset_by_experts, | ||
| void* output_scale_offset_by_experts, void* mask, | ||
| bool use_silu_and_mul, int m_topk, int k, int n_experts, | ||
| cudaStream_t stream) { | ||
| int device; | ||
| cudaGetDevice(&device); | ||
| int multiProcessorCount; | ||
| cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device); | ||
|
|
||
| // Grid, Block size. | ||
| // Each thread converts 8 values. | ||
| int const workSizePerRow = k / CVT_ELTS_PER_THREAD; | ||
| int const totalWorkSize = m_topk * workSizePerRow; | ||
| dim3 block(std::min(workSizePerRow, 512)); | ||
| // Get number of blocks per SM (assume we can fully utilize the SM). | ||
| int const numBlocksPerSM = 2048 / block.x; | ||
| dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x), | ||
| multiProcessorCount * numBlocksPerSM)); | ||
| while (grid.x <= multiProcessorCount && block.x > 64) { | ||
| grid.x *= 2; | ||
| block.x = (block.x + 1) / 2; | ||
| } | ||
|
|
||
| // TODO(kaixih@nvidia): Should relax this to allow any grid size. | ||
| // [email protected]: only deal with mask case | ||
| assert(mask != nullptr); | ||
| // if (mask != nullptr) { | ||
| grid.x = (grid.x + n_experts - 1) / n_experts * n_experts; | ||
| cvt_fp16_to_fp4_expert<T, false><<<grid, block, 0, stream>>>( | ||
| m_topk, k, reinterpret_cast<T*>(input), reinterpret_cast<float*>(input_global_scale), | ||
| reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(output_scale), | ||
| reinterpret_cast<int32_t*>(mask), use_silu_and_mul, n_experts); | ||
| return; | ||
| // } | ||
wenscarl marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| // Instantiate the function. | ||
| template void invokeFP4Quantization<half, 16>(int b, int m, int n, half const* input, | ||
| float const* SFScale, int64_t* output, | ||
| int32_t* SFOuput, bool useUE8M0, | ||
| QuantizationSFLayout layout, int multiProcessorCount, | ||
| int32_t const* mask, bool enable_pdl, | ||
| cudaStream_t stream); | ||
| bool enable_pdl, cudaStream_t stream); | ||
| template void invokeFP4Quantization<half, 32>(int b, int m, int n, half const* input, | ||
| float const* SFScale, int64_t* output, | ||
| int32_t* SFOuput, bool useUE8M0, | ||
| QuantizationSFLayout layout, int multiProcessorCount, | ||
| int32_t const* mask, bool enable_pdl, | ||
| cudaStream_t stream); | ||
| bool enable_pdl, cudaStream_t stream); | ||
| template void invokeMxFP8Quantization<half>(int b, int m, int n, int padded_n, half const* input, | ||
| int64_t* output, int32_t* SFOuput, | ||
| QuantizationSFLayout layout, int multiProcessorCount, | ||
| bool enable_pdl, cudaStream_t stream); | ||
| template void invokeSiluAndMulFP4Quantization<half, 16>( | ||
| int b, int m, int n, half const* input, float const* globalScale, int32_t const* mask, | ||
| int64_t* output, int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount, | ||
| bool enable_pdl, cudaStream_t stream); | ||
| template void invokeSiluAndMulNVFP4Quantization<half>(void* output, void* output_scale, void* input, | ||
| void* input_global_scale, | ||
| void* input_offset_by_experts, | ||
| void* output_scale_offset_by_experts, | ||
| void* mask, bool use_silu_and_mul, int m_topk, | ||
| int k, int n_experts, cudaStream_t stream); | ||
|
|
||
| #ifdef ENABLE_BF16 | ||
| template void invokeFP4Quantization<__nv_bfloat16, 16>( | ||
| int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output, | ||
| int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount, | ||
| int32_t const* mask, bool enable_pdl, cudaStream_t stream); | ||
| bool enable_pdl, cudaStream_t stream); | ||
| template void invokeFP4Quantization<__nv_bfloat16, 32>( | ||
| int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output, | ||
| int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount, | ||
| int32_t const* mask, bool enable_pdl, cudaStream_t stream); | ||
| bool enable_pdl, cudaStream_t stream); | ||
| template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n, int padded_n, | ||
| __nv_bfloat16 const* input, int64_t* output, | ||
| int32_t* SFOuput, QuantizationSFLayout layout, | ||
| int multiProcessorCount, bool enable_pdl, | ||
| cudaStream_t stream); | ||
| template void invokeSiluAndMulFP4Quantization<__nv_bfloat16, 16>( | ||
| int b, int m, int n, __nv_bfloat16 const* input, float const* globalScale, int32_t const* mask, | ||
| int64_t* output, int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount, | ||
| bool enable_pdl, cudaStream_t stream); | ||
| template void invokeSiluAndMulNVFP4Quantization<__nv_bfloat16>( | ||
| void* output, void* output_scale, void* input, void* input_global_scale, | ||
| void* input_offset_by_experts, void* output_scale_offset_by_experts, void* mask, | ||
| bool use_silu_and_mul, int m_topk, int k, int n_experts, cudaStream_t stream); | ||
|
|
||
| #endif | ||
|
|
||
| #ifdef ENABLE_FP8 | ||
| template void invokeFP4Quantization<__nv_fp8_e4m3, 16>( | ||
| int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output, | ||
| int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount, | ||
| int32_t const* mask, bool enable_pdl, cudaStream_t stream); | ||
| bool enable_pdl, cudaStream_t stream); | ||
| template void invokeFP4Quantization<__nv_fp8_e4m3, 32>( | ||
| int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output, | ||
| int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount, | ||
| int32_t const* mask, bool enable_pdl, cudaStream_t stream); | ||
| bool enable_pdl, cudaStream_t stream); | ||
|
|
||
| #endif | ||
|
|
||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.