
Commit b979ee2

precommit
1 parent a3153fd commit b979ee2

10 files changed: +119 −182 lines changed


csrc/nv_internal/cpp/kernels/quantization.cu

Lines changed: 24 additions & 48 deletions
@@ -296,18 +296,11 @@ void invokeBlockScaleInterleaveReverse(int b, int m, int n, uint8_t const* SFIn,
 }
 
 template <typename T>
-void invokeSiluAndMulNVFP4Quantization(void* output,
-                                       void* output_scale,
-                                       void* input,
-                                       void* input_global_scale,
-                                       void* input_offset_by_experts,
-                                       void* output_scale_offset_by_experts,
-                                       void* mask,
-                                       bool use_silu_and_mul,
-                                       int m_topk,
-                                       int k,
-                                       int n_experts,
-                                       cudaStream_t stream) {
+void invokeSiluAndMulNVFP4Quantization(void* output, void* output_scale, void* input,
+                                       void* input_global_scale, void* input_offset_by_experts,
+                                       void* output_scale_offset_by_experts, void* mask,
+                                       bool use_silu_and_mul, int m_topk, int k, int n_experts,
+                                       cudaStream_t stream) {
   int device;
   cudaGetDevice(&device);
   int multiProcessorCount;
@@ -320,7 +313,8 @@ void invokeSiluAndMulNVFP4Quantization(void* output,
   dim3 block(std::min(workSizePerRow, 512));
   // Get number of blocks per SM (assume we can fully utilize the SM).
   int const numBlocksPerSM = 2048 / block.x;
-  dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x), multiProcessorCount * numBlocksPerSM));
+  dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x),
+                     multiProcessorCount * numBlocksPerSM));
   while (grid.x <= multiProcessorCount && block.x > 64) {
     grid.x *= 2;
     block.x = (block.x + 1) / 2;
@@ -330,19 +324,13 @@ void invokeSiluAndMulNVFP4Quantization(void* output,
   // [email protected]: only deal with mask case
   assert(mask != nullptr);
   // if (mask != nullptr) {
-    grid.x = (grid.x + n_experts - 1) / n_experts * n_experts;
-    cvt_fp16_to_fp4_expert<T, false><<<grid, block, 0, stream>>>(
-        m_topk,
-        k,
-        reinterpret_cast<T*>(input),
-        reinterpret_cast<float*>(input_global_scale),
-        reinterpret_cast<uint32_t*>(output),
-        reinterpret_cast<uint32_t*>(output_scale),
-        reinterpret_cast<int32_t*>(mask),
-        use_silu_and_mul,
-        n_experts);
-    return;
-    // }
+  grid.x = (grid.x + n_experts - 1) / n_experts * n_experts;
+  cvt_fp16_to_fp4_expert<T, false><<<grid, block, 0, stream>>>(
+      m_topk, k, reinterpret_cast<T*>(input), reinterpret_cast<float*>(input_global_scale),
+      reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(output_scale),
+      reinterpret_cast<int32_t*>(mask), use_silu_and_mul, n_experts);
+  return;
+  // }
 }
 
 // Instantiate the function.
@@ -360,17 +348,12 @@ template void invokeMxFP8Quantization<half>(int b, int m, int n, int padded_n, h
                                             int64_t* output, int32_t* SFOuput,
                                             QuantizationSFLayout layout, int multiProcessorCount,
                                             bool enable_pdl, cudaStream_t stream);
-template void invokeSiluAndMulNVFP4Quantization<half>(void* output, void* output_scale,
-                                                      void* input,
-                                                      void* input_global_scale,
-                                                      void* input_offset_by_experts,
-                                                      void* output_scale_offset_by_experts,
-                                                      void* mask,
-                                                      bool use_silu_and_mul,
-                                                      int m_topk,
-                                                      int k,
-                                                      int n_experts,
-                                                      cudaStream_t stream);
+template void invokeSiluAndMulNVFP4Quantization<half>(void* output, void* output_scale, void* input,
+                                                      void* input_global_scale,
+                                                      void* input_offset_by_experts,
+                                                      void* output_scale_offset_by_experts,
+                                                      void* mask, bool use_silu_and_mul, int m_topk,
+                                                      int k, int n_experts, cudaStream_t stream);
 
 #ifdef ENABLE_BF16
 template void invokeFP4Quantization<__nv_bfloat16, 16>(
@@ -386,17 +369,10 @@ template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n, int pa
                                                      int32_t* SFOuput, QuantizationSFLayout layout,
                                                      int multiProcessorCount, bool enable_pdl,
                                                      cudaStream_t stream);
-template void invokeSiluAndMulNVFP4Quantization<__nv_bfloat16>(void* output, void* output_scale,
-                                                               void* input,
-                                                               void* input_global_scale,
-                                                               void* input_offset_by_experts,
-                                                               void* output_scale_offset_by_experts,
-                                                               void* mask,
-                                                               bool use_silu_and_mul,
-                                                               int m_topk,
-                                                               int k,
-                                                               int n_experts,
-                                                               cudaStream_t stream);
+template void invokeSiluAndMulNVFP4Quantization<__nv_bfloat16>(
+    void* output, void* output_scale, void* input, void* input_global_scale,
+    void* input_offset_by_experts, void* output_scale_offset_by_experts, void* mask,
+    bool use_silu_and_mul, int m_topk, int k, int n_experts, cudaStream_t stream);
 
 #endif
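
Note on the grid sizing above: the expression (grid.x + n_experts - 1) / n_experts * n_experts rounds the grid up to the nearest multiple of n_experts, so every expert maps to the same number of blocks. A minimal standalone sketch of that integer rounding, with illustrative values that are not taken from the kernel:

// Ceil-to-multiple in integer arithmetic, as used for grid.x above.
// round_up_to_multiple is a hypothetical helper, not part of this source tree.
#include <cassert>

constexpr int round_up_to_multiple(int x, int n) { return (x + n - 1) / n * n; }

int main() {
  assert(round_up_to_multiple(132, 8) == 136);  // 132 blocks, 8 experts -> padded to 136
  assert(round_up_to_multiple(136, 8) == 136);  // an exact multiple stays unchanged
  return 0;
}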

csrc/nv_internal/tensorrt_llm/kernels/nvfp4_expert_quant.cuh

Lines changed: 31 additions & 38 deletions
@@ -50,7 +50,8 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
   // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
   // reciprocal(SFScaleVal))
   float outputScale =
-      SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f;
+      SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal))
+                   : 0.0f;
 
   if (SFout) {
     // Write the SF to global memory (STG.8).
@@ -81,9 +82,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
 #endif
 }
 
-__device__ __forceinline__ float silu(const float& val) {
-  return val / (1.0f + __expf(-val));
-}
+__device__ __forceinline__ float silu(const float& val) { return val / (1.0f + __expf(-val)); }
 
 template <class Type>
 inline __device__ void silu_and_mul(PackedVec<Type>& x_vec, const PackedVec<Type>& y_vec) {
@@ -116,21 +115,14 @@ __launch_bounds__(512, 4) cvt_fp16_to_fp4(
 #else
 cvt_fp16_to_fp4(
 #endif
-    int32_t numRows,
-    int32_t numCols,
-    Type const* in,
-    float const* SFScale,
-    uint32_t* out,
-    uint32_t* SFout,
-    uint32_t* input_offset_by_experts,
-    uint32_t* output_scale_offset_by_experts,
-    int32_t* mask,
-    int n_experts,
-    bool low_latency) {
+    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out,
+    uint32_t* SFout, uint32_t* input_offset_by_experts, uint32_t* output_scale_offset_by_experts,
+    int32_t* mask, int n_experts, bool low_latency) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
   using PackedVec = PackedVec<Type>;
   static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
-  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
+  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
+                "Vec size is not matched.");
 
   // Input tensor row/col loops.
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -233,19 +225,13 @@ __launch_bounds__(512, 4) cvt_fp16_to_fp4_expert(
 #else
 cvt_fp16_to_fp4_expert(
 #endif
-    int32_t numRows,
-    int32_t numCols,
-    Type const* in,
-    float const* SFScale,
-    uint32_t* out,
-    uint32_t* SFout,
-    int32_t* mask,
-    bool use_silu_and_mul,
-    int n_experts) {
+    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out,
+    uint32_t* SFout, int32_t* mask, bool use_silu_and_mul, int n_experts) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
   using PackedVec = PackedVec<Type>;
   static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
-  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
+  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
+                "Vec size is not matched.");
 
   // Input tensor row/col loops.
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -281,8 +267,8 @@ cvt_fp16_to_fp4_expert(
   int actualColsPerRow = use_silu_and_mul ? colsPerRow * 2 : colsPerRow;
 
   // Each global thread processes one element
-  for (int globalIdx = tid_in_expert + expert_idx * m * colsPerRow; globalIdx < (expert_idx + 1) * m * colsPerRow;
-       globalIdx += actual_stride) {
+  for (int globalIdx = tid_in_expert + expert_idx * m * colsPerRow;
+       globalIdx < (expert_idx + 1) * m * colsPerRow; globalIdx += actual_stride) {
     // Calculate which row and column this global thread should process
     int rowIdx = globalIdx / colsPerRow;
     int colIdx = globalIdx % colsPerRow;
@@ -347,9 +333,10 @@ cvt_fp16_to_fp4_expert(
 //     int n_experts) {
 // #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 //   using PackedVec = PackedVec<Type>;
-//   static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
-//   static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
-//   extern __shared__ uint32_t shared_input_offsets[];
+//   static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE /
+//   CVT_FP4_ELTS_PER_THREAD); static_assert(sizeof(PackedVec) == sizeof(Type) *
+//   CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); extern __shared__ uint32_t
+//   shared_input_offsets[];
 
 //   // Load input offsets into shared memory.
 //   // If n_experts is larger than 4, use vectorized int4 to save instructions.
@@ -360,7 +347,8 @@ cvt_fp16_to_fp4_expert(
 //     }
 //   } else {
 //     for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) {
-//       *reinterpret_cast<int4*>(&shared_input_offsets[i]) = *reinterpret_cast<const int4*>(&input_offset_by_experts[i]);
+//       *reinterpret_cast<int4*>(&shared_input_offsets[i]) = *reinterpret_cast<const
+//       int4*>(&input_offset_by_experts[i]);
 //     }
 //     if (threadIdx.x == 0) {
 //       shared_input_offsets[n_experts] = input_offset_by_experts[n_experts];
@@ -375,7 +363,8 @@ cvt_fp16_to_fp4_expert(
 //   int actualColsPerRow = use_mask ? colsPerRow * 2 : colsPerRow;
 
 //   // Each global thread processes one element
-//   for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) {
+//   for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x *
+//   blockDim.x) {
 //     // Calculate which row and column this global thread should process
 //     int rowIdx = globalIdx / colsPerRow;
 //     int colIdx = globalIdx % colsPerRow;
@@ -424,7 +413,8 @@ cvt_fp16_to_fp4_expert(
 //     int factor = CVT_FP4_SF_VEC_SIZE * 4;
 //     int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
 //     int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
-//     uint32_t* SFout_in_expert = SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
+//     uint32_t* SFout_in_expert = SFout + output_scale_offset_by_experts[expert_idx] *
+//     numCols_SFout;
 
 //     auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(
 //         rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
@@ -461,14 +451,16 @@ constexpr auto UINT8 = at::ScalarType::Byte;
 //                       torch::Tensor const& input_offset_by_experts,
 //                       torch::Tensor const& output_scale_offset_by_experts) {
 //   auto sm_version = getSMVersion();
-//   TORCH_CHECK(sm_version == 100 || sm_version == 103, "fp4_quant is only supported on sm100a/sm103a");
+//   TORCH_CHECK(sm_version == 100 || sm_version == 103, "fp4_quant is only supported on
+//   sm100a/sm103a");
 
 //   CHECK_INPUT(output, "output must be a CUDA tensor");
 //   CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
 //   CHECK_INPUT(input, "input must be a CUDA tensor");
 //   CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
 //   CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts must be a CUDA tensor");
-//   CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts must be a CUDA tensor");
+//   CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts must be a CUDA
+//   tensor");
 
 //   TORCH_CHECK(output.dim() == 2);
 //   TORCH_CHECK(output_scale.dim() == 2);
@@ -545,7 +537,8 @@ constexpr auto UINT8 = at::ScalarType::Byte;
 //                       torch::Tensor const& mask,
 //                       bool use_silu_and_mul) {
 //   auto sm_version = getSMVersion();
-//   TORCH_CHECK(sm_version == 100 || sm_version == 103, "fp4_quant is only supported on sm100a/sm103a");
+//   TORCH_CHECK(sm_version == 100 || sm_version == 103, "fp4_quant is only supported on
+//   sm100a/sm103a");
 
 //   CHECK_INPUT(output, "output must be a CUDA tensor");
 //   CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
@@ -618,4 +611,4 @@ constexpr auto UINT8 = at::ScalarType::Byte;
 //   } else {
 //     TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
 //   }
-// }
+// }
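
For reference, the silu reflowed above computes val / (1.0f + expf(-val)). Below is a scalar host-side sketch of the gated activation, assuming silu_and_mul(x, y) applies silu to x and multiplies elementwise by y (the conventional reading of the name; the device version operates on PackedVec<Type> lanes rather than plain floats, and these reference names are illustrative, not from this repository):

// Scalar reference for the device-side silu / silu_and_mul pair.
#include <cmath>

inline float silu_ref(float v) { return v / (1.0f + std::exp(-v)); }

inline void silu_and_mul_ref(float* x, const float* y, int n) {
  for (int i = 0; i < n; ++i) {
    x[i] = silu_ref(x[i]) * y[i];  // gate activation times the paired input
  }
}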

csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh

Lines changed: 9 additions & 16 deletions
@@ -838,7 +838,8 @@ quantize_with_block_size(
 }
 
 template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
-__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols, SFType* SFout) {
+__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols,
+                                                       SFType* SFout) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
   static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2);
 
@@ -882,9 +883,7 @@ __device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, i
   return nullptr;
 }
 
-__device__ __forceinline__ float silu(const float& val) {
-  return val / (1.0f + __expf(-val));
-}
+__device__ __forceinline__ float silu(const float& val) { return val / (1.0f + __expf(-val)); }
 
 template <class Type>
 inline __device__ void silu_and_mul(PackedVec<Type>& x_vec, const PackedVec<Type>& y_vec) {
@@ -917,19 +916,13 @@ __launch_bounds__(512, 4) cvt_fp16_to_fp4_expert(
 #else
 cvt_fp16_to_fp4_expert(
 #endif
-    int32_t numRows,
-    int32_t numCols,
-    Type const* in,
-    float const* SFScale,
-    uint32_t* out,
-    uint32_t* SFout,
-    int32_t* mask,
-    bool use_silu_and_mul,
-    int n_experts) {
+    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out,
+    uint32_t* SFout, int32_t* mask, bool use_silu_and_mul, int n_experts) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
   using PackedVec = PackedVec<Type>;
   static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
-  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
+  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
+                "Vec size is not matched.");
 
   // Input tensor row/col loops.
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -965,8 +958,8 @@ cvt_fp16_to_fp4_expert(
   int actualColsPerRow = use_silu_and_mul ? colsPerRow * 2 : colsPerRow;
 
   // Each global thread processes one element
-  for (int globalIdx = tid_in_expert + expert_idx * m * colsPerRow; globalIdx < (expert_idx + 1) * m * colsPerRow;
-       globalIdx += actual_stride) {
+  for (int globalIdx = tid_in_expert + expert_idx * m * colsPerRow;
+       globalIdx < (expert_idx + 1) * m * colsPerRow; globalIdx += actual_stride) {
     // Calculate which row and column this global thread should process
     int rowIdx = globalIdx / colsPerRow;
     int colIdx = globalIdx % colsPerRow;
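
The loop shown above walks a flattened per-expert tile: each globalIdx inside an expert's m x colsPerRow range splits into a row and a column by integer division and remainder. A minimal sketch of that decomposition, with illustrative names that are not part of the kernel:

// Flattened-index decomposition used by the element loop above.
struct RowCol { int row; int col; };

inline RowCol decompose(int globalIdx, int colsPerRow) {
  return {globalIdx / colsPerRow, globalIdx % colsPerRow};
}
// e.g. with colsPerRow = 128, globalIdx = 515 maps to row 4, col 3.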

csrc/nv_internal/tensorrt_llm/kernels/quantization.h

Lines changed: 5 additions & 13 deletions
@@ -63,19 +63,11 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* glo
                            bool enable_pdl = false, cudaStream_t stream = 0);
 
 template <typename T>
-void invokeSiluAndMulNVFP4Quantization(void* output,
-                                       void* output_scale,
-                                       void* input,
-                                       void* input_global_scale,
-                                       void* input_offset_by_experts,
-                                       void* output_scale_offset_by_experts,
-                                       void* mask,
-                                       bool use_silu_and_mul,
-                                       int m_topk,
-                                       int k,
-                                       int n_experts,
-                                       cudaStream_t stream);
-
+void invokeSiluAndMulNVFP4Quantization(void* output, void* output_scale, void* input,
+                                       void* input_global_scale, void* input_offset_by_experts,
+                                       void* output_scale_offset_by_experts, void* mask,
+                                       bool use_silu_and_mul, int m_topk, int k, int n_experts,
+                                       cudaStream_t stream);
 
 void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded,
                                 uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount,
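
A hypothetical call site for the declaration above, using the half instantiation from quantization.cu. Buffer allocation and sizing are elided; the include path, the omission of any enclosing namespace qualifier, and the nullptr per-expert offsets (the mask-only path shown earlier does not read them) are assumptions, not project code:

#include <cuda_fp16.h>
#include <cuda_runtime.h>
// Assumed include path based on the repository layout in this commit.
#include "tensorrt_llm/kernels/quantization.h"

void example_launch(void* output, void* output_scale, void* input, void* input_global_scale,
                    void* mask, int m_topk, int k, int n_experts, cudaStream_t stream) {
  // Offsets-by-expert passed as nullptr: an assumption that only holds for the
  // mask path, which (per the quantization.cu diff above) ignores them.
  invokeSiluAndMulNVFP4Quantization<half>(
      output, output_scale, input, input_global_scale,
      /*input_offset_by_experts=*/nullptr, /*output_scale_offset_by_experts=*/nullptr,
      mask, /*use_silu_and_mul=*/true, m_topk, k, n_experts, stream);
}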
