Skip to content

Commit a3153fd

Browse files
committed
Revert "[Quantization] Add per-expert global scaling factor for fp4 batched quantize (#1835)"
This reverts commit f765a2a and instead ports the equivalent kernels from sglang.
1 parent bea5949 commit a3153fd

File tree

11 files changed

+1298
-347
lines changed

11 files changed

+1298
-347
lines changed

csrc/nv_internal/cpp/kernels/quantization.cu

Lines changed: 86 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "tensorrt_llm/common/quantTypeUtils.cuh"
2424
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
2525
#include "tensorrt_llm/kernels/quantization.cuh"
26+
// #include "tensorrt_llm/kernels/nvfp4_expert_quant.cuh"
2627
#include "tensorrt_llm/kernels/quantization.h"
2728

2829
using namespace tensorrt_llm::common;
@@ -102,7 +103,7 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input,
102103
&config,
103104
quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_MXFP8, T, SF_VEC_SIZE, true>, b,
104105
m, n, padded_n, input, nullptr, reinterpret_cast<uint32_t*>(output),
105-
reinterpret_cast<uint32_t*>(SFOuput), layout, /*mask=*/nullptr);
106+
reinterpret_cast<uint32_t*>(SFOuput), layout);
106107
}
107108

108109
// Do per-token (row) quantization from fp16/bf16/fp32 to int8/fp8_e4m3.
@@ -164,11 +165,12 @@ INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(__nv_bfloat16, __nv_fp8_e4m3);
164165

165166
////////////////////////////////////////////////////////////////////////////////////////////////////
166167
// FP4/MXFP8 Quantization
168+
167169
template <typename T, int SF_VEC_SIZE>
168170
void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFScale,
169171
int64_t* output, int32_t* SFOuput, bool useUE8M0,
170-
QuantizationSFLayout layout, int multiProcessorCount,
171-
int32_t const* mask, bool enable_pdl, cudaStream_t stream) {
172+
QuantizationSFLayout layout, int multiProcessorCount, bool enable_pdl,
173+
cudaStream_t stream) {
172174
#ifdef ENABLE_FP8
173175
if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
174176
// Grid, Block size.
@@ -186,7 +188,7 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
186188
T, SF_VEC_SIZE, false>;
187189
kernel_instance<<<grid, block, 0, stream>>>(b, m, n, n, input, SFScale,
188190
reinterpret_cast<uint32_t*>(output),
189-
reinterpret_cast<uint32_t*>(SFOuput), layout, mask);
191+
reinterpret_cast<uint32_t*>(SFOuput), layout);
190192

191193
} else
192194
#endif
@@ -217,42 +219,10 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
217219
config.attrs = attrs;
218220
cudaLaunchKernelEx(&config, kernel_instance, b, m, n, n, input, SFScale,
219221
reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput),
220-
layout, mask);
222+
layout);
221223
}
222224
}
223225

224-
template <typename T, int SF_VEC_SIZE>
225-
void invokeSiluAndMulFP4Quantization(int b, int m, int n, T const* input, float const* SFScale,
226-
int32_t const* mask, int64_t* output, int32_t* SFOuput,
227-
QuantizationSFLayout layout, int multiProcessorCount,
228-
bool enable_pdl, cudaStream_t stream) {
229-
// Grid, Block size.
230-
// Each thread converts 8 values.
231-
dim3 block(std::min(int(n / CVT_ELTS_PER_THREAD), 512));
232-
// Get number of blocks per SM (assume we can fully utilize the SM).
233-
int const numBlocksPerSM = std::max(1u, 2048u / block.x);
234-
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
235-
236-
// Launch the cvt kernel.
237-
auto* kernel_instance =
238-
&silu_mul_quantize_with_block_size<BlockScaleQuantizationType::FP16_TO_FP4, T, SF_VEC_SIZE,
239-
false>;
240-
241-
cudaLaunchConfig_t config;
242-
config.gridDim = grid;
243-
config.blockDim = block;
244-
config.dynamicSmemBytes = 0;
245-
config.stream = stream;
246-
cudaLaunchAttribute attrs[1];
247-
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
248-
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
249-
config.numAttrs = 1;
250-
config.attrs = attrs;
251-
cudaLaunchKernelEx(&config, kernel_instance, b, m, n / 2, n / 2, input, SFScale,
252-
reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput),
253-
layout, mask);
254-
}
255-
256226
__global__ void block_scale_interleave_kernel(int numBatches, int numRows, int numRowsPadded,
257227
int numCols, int numColsPadded, uint8_t const* SFIn,
258228
uint8_t* SFOutput) {
@@ -325,57 +295,120 @@ void invokeBlockScaleInterleaveReverse(int b, int m, int n, uint8_t const* SFIn,
325295
block_scale_interleave_reverse_kernel<<<grid, block, 0, stream>>>(b, m, n, SFIn, SFOutput);
326296
}
327297

// Launches the fused (optional SiLU-and-mul) + NVFP4 per-expert quantization kernel.
//
// All pointer parameters are device pointers unless noted otherwise:
//   output                         - quantized FP4 output, written as packed uint32 words.
//   output_scale                   - block scale factors, written as packed uint32 words.
//   input                          - input activations of element type T.
//   input_global_scale             - per-expert global scaling factors (float).
//   input_offset_by_experts        - accepted but not consumed on this (masked) code
//                                    path; kept for interface compatibility.
//   output_scale_offset_by_experts - accepted but not consumed here as well.
//   mask                           - per-expert row mask (int32); must be non-null,
//                                    since only the masked kernel variant is launched.
//   use_silu_and_mul (host value)  - when true the kernel applies silu-and-mul before
//                                    quantizing. NOTE(review): semantics live in
//                                    cvt_fp16_to_fp4_expert — confirm there.
//   m_topk, k                      - logical row count and row width of the input.
//   n_experts                      - number of experts; the grid is rounded up to a
//                                    multiple of this value.
//   stream                         - CUDA stream for the launch.
template <typename T>
void invokeSiluAndMulNVFP4Quantization(void* output,
                                       void* output_scale,
                                       void* input,
                                       void* input_global_scale,
                                       void* input_offset_by_experts,
                                       void* output_scale_offset_by_experts,
                                       void* mask,
                                       bool use_silu_and_mul,
                                       int m_topk,
                                       int k,
                                       int n_experts,
                                       cudaStream_t stream) {
  // These two are only used by the (currently disabled) offset-based path.
  (void)input_offset_by_experts;
  (void)output_scale_offset_by_experts;

  // Each thread converts CVT_ELTS_PER_THREAD (8) values, so a narrower row would
  // produce a zero-sized block and a division by zero below.
  assert(k >= CVT_ELTS_PER_THREAD);
  assert(n_experts > 0);

  int device = 0;
  cudaError_t status = cudaGetDevice(&device);
  assert(status == cudaSuccess);
  int multiProcessorCount = 0;
  status = cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device);
  assert(status == cudaSuccess);
  (void)status;  // silence unused warning when NDEBUG disables assert

  // Grid/block sizing: each thread converts CVT_ELTS_PER_THREAD values.
  int const workSizePerRow = k / CVT_ELTS_PER_THREAD;
  int const totalWorkSize = m_topk * workSizePerRow;
  dim3 block(std::min(workSizePerRow, 512));
  // Assume up to 2048 resident threads per SM when estimating blocks per SM.
  int const numBlocksPerSM = 2048 / block.x;
  dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x),
                     multiProcessorCount * numBlocksPerSM));
  // If the launch would underfill the device, trade threads for blocks (halve the
  // block, double the grid) down to a 64-thread floor.
  while (grid.x <= multiProcessorCount && block.x > 64) {
    grid.x *= 2;
    block.x = (block.x + 1) / 2;
  }

  // TODO(kaixih@nvidia): Should relax this to allow any grid size.
  // Only the masked (per-expert) variant is supported on this path.
  assert(mask != nullptr);
  // Round the grid up to a multiple of n_experts so blocks partition evenly
  // across experts inside the kernel.
  grid.x = (grid.x + n_experts - 1) / n_experts * n_experts;
  cvt_fp16_to_fp4_expert<T, false><<<grid, block, 0, stream>>>(
      m_topk, k, reinterpret_cast<T*>(input), reinterpret_cast<float*>(input_global_scale),
      reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(output_scale),
      reinterpret_cast<int32_t*>(mask), use_silu_and_mul, n_experts);
}
328348
// Instantiate the function.
329349
template void invokeFP4Quantization<half, 16>(int b, int m, int n, half const* input,
330350
float const* SFScale, int64_t* output,
331351
int32_t* SFOuput, bool useUE8M0,
332352
QuantizationSFLayout layout, int multiProcessorCount,
333-
int32_t const* mask, bool enable_pdl,
334-
cudaStream_t stream);
353+
bool enable_pdl, cudaStream_t stream);
335354
template void invokeFP4Quantization<half, 32>(int b, int m, int n, half const* input,
336355
float const* SFScale, int64_t* output,
337356
int32_t* SFOuput, bool useUE8M0,
338357
QuantizationSFLayout layout, int multiProcessorCount,
339-
int32_t const* mask, bool enable_pdl,
340-
cudaStream_t stream);
358+
bool enable_pdl, cudaStream_t stream);
341359
template void invokeMxFP8Quantization<half>(int b, int m, int n, int padded_n, half const* input,
342360
int64_t* output, int32_t* SFOuput,
343361
QuantizationSFLayout layout, int multiProcessorCount,
344362
bool enable_pdl, cudaStream_t stream);
345-
template void invokeSiluAndMulFP4Quantization<half, 16>(
346-
int b, int m, int n, half const* input, float const* globalScale, int32_t const* mask,
347-
int64_t* output, int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount,
348-
bool enable_pdl, cudaStream_t stream);
363+
template void invokeSiluAndMulNVFP4Quantization<half>(void* output, void* output_scale,
364+
void* input,
365+
void* input_global_scale,
366+
void* input_offset_by_experts,
367+
void* output_scale_offset_by_experts,
368+
void* mask,
369+
bool use_silu_and_mul,
370+
int m_topk,
371+
int k,
372+
int n_experts,
373+
cudaStream_t stream);
349374

350375
#ifdef ENABLE_BF16
351376
template void invokeFP4Quantization<__nv_bfloat16, 16>(
352377
int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
353378
int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
354-
int32_t const* mask, bool enable_pdl, cudaStream_t stream);
379+
bool enable_pdl, cudaStream_t stream);
355380
template void invokeFP4Quantization<__nv_bfloat16, 32>(
356381
int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
357382
int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
358-
int32_t const* mask, bool enable_pdl, cudaStream_t stream);
383+
bool enable_pdl, cudaStream_t stream);
359384
template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n, int padded_n,
360385
__nv_bfloat16 const* input, int64_t* output,
361386
int32_t* SFOuput, QuantizationSFLayout layout,
362387
int multiProcessorCount, bool enable_pdl,
363388
cudaStream_t stream);
364-
template void invokeSiluAndMulFP4Quantization<__nv_bfloat16, 16>(
365-
int b, int m, int n, __nv_bfloat16 const* input, float const* globalScale, int32_t const* mask,
366-
int64_t* output, int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount,
367-
bool enable_pdl, cudaStream_t stream);
389+
template void invokeSiluAndMulNVFP4Quantization<__nv_bfloat16>(void* output, void* output_scale,
390+
void* input,
391+
void* input_global_scale,
392+
void* input_offset_by_experts,
393+
void* output_scale_offset_by_experts,
394+
void* mask,
395+
bool use_silu_and_mul,
396+
int m_topk,
397+
int k,
398+
int n_experts,
399+
cudaStream_t stream);
400+
368401
#endif
369402

370403
#ifdef ENABLE_FP8
371404
template void invokeFP4Quantization<__nv_fp8_e4m3, 16>(
372405
int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output,
373406
int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
374-
int32_t const* mask, bool enable_pdl, cudaStream_t stream);
407+
bool enable_pdl, cudaStream_t stream);
375408
template void invokeFP4Quantization<__nv_fp8_e4m3, 32>(
376409
int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output,
377410
int32_t* SFOuput, bool useUE8M0, QuantizationSFLayout layout, int multiProcessorCount,
378-
int32_t const* mask, bool enable_pdl, cudaStream_t stream);
411+
bool enable_pdl, cudaStream_t stream);
379412

380413
#endif
381414

0 commit comments

Comments
 (0)