@@ -50,7 +50,8 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
   // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
   // reciprocal(SFScaleVal))
   float outputScale =
-      SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f;
+      SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal))
+                   : 0.0f;
 
   if (SFout) {
     // Write the SF to global memory (STG.8).
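
Note on the hunk above: the reflowed ternary computes, for a nonzero block scale, outputScale = 1 / (SFValue * (1 / SFScaleVal)), i.e. effectively SFScaleVal / SFValue. A minimal host-side sketch of that arithmetic, with plain division standing in for the approximate flush-to-zero reciprocal (the helper name below is illustrative, not part of the diff):

// Host-side sketch of the output-scale recipe above; plain float division
// approximates reciprocal_approximate_ftz, and SFValue is assumed to already be
// the fp8-rounded per-block scale mentioned in the recipe comment.
inline float output_scale_ref(float SFValue, float SFScaleVal) {
  // reciprocal(SFValue * reciprocal(SFScaleVal)) == SFScaleVal / SFValue
  return SFValue != 0.0f ? 1.0f / (SFValue * (1.0f / SFScaleVal)) : 0.0f;
}
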
@@ -81,9 +82,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
 #endif
 }
 
-__device__ __forceinline__ float silu(const float& val) {
-  return val / (1.0f + __expf(-val));
-}
+__device__ __forceinline__ float silu(const float& val) { return val / (1.0f + __expf(-val)); }
 
 template <class Type>
 inline __device__ void silu_and_mul(PackedVec<Type>& x_vec, const PackedVec<Type>& y_vec) {
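
The collapsed silu above is the standard SiLU, x * sigmoid(x) = x / (1 + e^(-x)). For reference, a scalar sketch of the silu_and_mul pattern that the kernel applies lane-wise over each PackedVec (the helper name is illustrative):

#include <math.h>

// Scalar reference for the gated activation: out = silu(x) * y.
// The device kernel uses __expf; expf is the host analogue used here.
inline float silu_and_mul_ref(float x, float y) {
  float s = x / (1.0f + expf(-x));  // silu(x) == x * sigmoid(x)
  return s * y;
}
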
@@ -116,21 +115,14 @@ __launch_bounds__(512, 4) cvt_fp16_to_fp4(
 #else
 cvt_fp16_to_fp4(
 #endif
-    int32_t numRows,
-    int32_t numCols,
-    Type const* in,
-    float const* SFScale,
-    uint32_t* out,
-    uint32_t* SFout,
-    uint32_t* input_offset_by_experts,
-    uint32_t* output_scale_offset_by_experts,
-    int32_t* mask,
-    int n_experts,
-    bool low_latency) {
+    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out,
+    uint32_t* SFout, uint32_t* input_offset_by_experts, uint32_t* output_scale_offset_by_experts,
+    int32_t* mask, int n_experts, bool low_latency) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
   using PackedVec = PackedVec<Type>;
   static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
-  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
+  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
+                "Vec size is not matched.");
 
   // Input tensor row/col loops.
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
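
For context on the reflowed static_assert: with the usual NVFP4 constants (a 16-element scale-factor block and 8 elements per thread, values assumed here and not shown in this hunk), two threads cooperate on each scale factor and a PackedVec of 8 halves is 16 bytes. A compile-time sketch of that arithmetic:

#include <cuda_fp16.h>

// Assumed values for illustration only; the real constants are defined elsewhere in this file.
constexpr int kSfVecSize = 16;     // stands in for CVT_FP4_SF_VEC_SIZE
constexpr int kEltsPerThread = 8;  // stands in for CVT_FP4_ELTS_PER_THREAD
static_assert(kSfVecSize / kEltsPerThread == 2, "two threads share one scale factor");
static_assert(kEltsPerThread * sizeof(__half) == 16, "a PackedVec of 8 halves is 16 bytes");
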
@@ -233,19 +225,13 @@ __launch_bounds__(512, 4) cvt_fp16_to_fp4_expert(
 #else
 cvt_fp16_to_fp4_expert(
 #endif
-    int32_t numRows,
-    int32_t numCols,
-    Type const* in,
-    float const* SFScale,
-    uint32_t* out,
-    uint32_t* SFout,
-    int32_t* mask,
-    bool use_silu_and_mul,
-    int n_experts) {
+    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out,
+    uint32_t* SFout, int32_t* mask, bool use_silu_and_mul, int n_experts) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
   using PackedVec = PackedVec<Type>;
   static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
-  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
+  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
+                "Vec size is not matched.");
 
   // Input tensor row/col loops.
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -281,8 +267,8 @@ cvt_fp16_to_fp4_expert(
   int actualColsPerRow = use_silu_and_mul ? colsPerRow * 2 : colsPerRow;
 
   // Each global thread processes one element
-  for (int globalIdx = tid_in_expert + expert_idx * m * colsPerRow; globalIdx < (expert_idx + 1) * m * colsPerRow;
-       globalIdx += actual_stride) {
+  for (int globalIdx = tid_in_expert + expert_idx * m * colsPerRow;
+       globalIdx < (expert_idx + 1) * m * colsPerRow; globalIdx += actual_stride) {
     // Calculate which row and column this global thread should process
     int rowIdx = globalIdx / colsPerRow;
     int colIdx = globalIdx % colsPerRow;
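
The reflowed loop above confines each thread to one expert's slice of the flat element space and decomposes the flat index into a row and a packed column. A host-side sketch of that index math (names mirror the kernel; the function itself is illustrative):

// Each expert owns m * colsPerRow flat elements; a flat index splits into a row
// of the whole input and a packed column (one PackedVec wide) within that row.
void enumerate_expert_elements(int expert_idx, int m, int colsPerRow,
                               int tid_in_expert, int actual_stride) {
  int begin = tid_in_expert + expert_idx * m * colsPerRow;
  int end = (expert_idx + 1) * m * colsPerRow;
  for (int globalIdx = begin; globalIdx < end; globalIdx += actual_stride) {
    int rowIdx = globalIdx / colsPerRow;  // row in the full input tensor
    int colIdx = globalIdx % colsPerRow;  // packed-column index within the row
    (void)rowIdx;
    (void)colIdx;  // the kernel loads, optionally applies silu_and_mul, and quantizes here
  }
}
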
@@ -347,9 +333,10 @@ cvt_fp16_to_fp4_expert(
 //     int n_experts) {
 // #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 //   using PackedVec = PackedVec<Type>;
-//   static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
-//   static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
-//   extern __shared__ uint32_t shared_input_offsets[];
+//   static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE /
+//   CVT_FP4_ELTS_PER_THREAD); static_assert(sizeof(PackedVec) == sizeof(Type) *
+//   CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); extern __shared__ uint32_t
+//   shared_input_offsets[];
 
 //   // Load input offsets into shared memory.
 //   // If n_experts is larger than 4, use vectorized int4 to save instructions.
@@ -360,7 +347,8 @@ cvt_fp16_to_fp4_expert(
 //     }
 //   } else {
 //     for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) {
-//       *reinterpret_cast<int4*>(&shared_input_offsets[i]) = *reinterpret_cast<const int4*>(&input_offset_by_experts[i]);
+//       *reinterpret_cast<int4*>(&shared_input_offsets[i]) = *reinterpret_cast<const
+//       int4*>(&input_offset_by_experts[i]);
 //     }
 //     if (threadIdx.x == 0) {
 //       shared_input_offsets[n_experts] = input_offset_by_experts[n_experts];
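
The rewrapped line above is the vectorized offset load in the commented-out kernel: one int4 load moves four uint32_t offsets into shared memory, and thread 0 picks up the trailing (n_experts + 1)-th entry. A standalone sketch of the same pattern (it assumes n_experts is a multiple of 4 and 16-byte-aligned offset arrays, as the surrounding branch implies):

__device__ void load_offsets_vec4(uint32_t* shared_input_offsets,
                                  const uint32_t* input_offset_by_experts,
                                  int n_experts) {
  // Each int4 load copies four consecutive uint32_t offsets.
  for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) {
    *reinterpret_cast<int4*>(&shared_input_offsets[i]) =
        *reinterpret_cast<const int4*>(&input_offset_by_experts[i]);
  }
  // The offsets array holds n_experts + 1 entries; load the last one separately.
  if (threadIdx.x == 0) {
    shared_input_offsets[n_experts] = input_offset_by_experts[n_experts];
  }
}
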
@@ -375,7 +363,8 @@ cvt_fp16_to_fp4_expert(
 //   int actualColsPerRow = use_mask ? colsPerRow * 2 : colsPerRow;
 
 //   // Each global thread processes one element
-//   for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) {
+//   for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x *
+//   blockDim.x) {
 //     // Calculate which row and column this global thread should process
 //     int rowIdx = globalIdx / colsPerRow;
 //     int colIdx = globalIdx % colsPerRow;
@@ -424,7 +413,8 @@ cvt_fp16_to_fp4_expert(
 //   int factor = CVT_FP4_SF_VEC_SIZE * 4;
 //   int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
 //   int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
-//   uint32_t* SFout_in_expert = SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
+//   uint32_t* SFout_in_expert = SFout + output_scale_offset_by_experts[expert_idx] *
+//   numCols_SFout;
 
 //   auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(
 //       rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
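
The padding math above rounds numCols up to a multiple of factor = CVT_FP4_SF_VEC_SIZE * 4 before computing the scale-factor output width. A worked example with an assumed CVT_FP4_SF_VEC_SIZE of 16 (so factor = 64): numCols = 200 pads to 256, and numCols_SFout = 256 / 16 / 4 = 4.

// Compile-time check of the round-up idiom used above, with illustrative numbers.
constexpr int round_up(int x, int multiple) { return (x + multiple - 1) / multiple * multiple; }
static_assert(round_up(200, 64) == 256, "200 columns pad to 256");
static_assert(round_up(200, 64) / 16 / 4 == 4, "four packed uint32 scale-factor words per row");
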
@@ -461,14 +451,16 @@ constexpr auto UINT8 = at::ScalarType::Byte;
 //     torch::Tensor const& input_offset_by_experts,
 //     torch::Tensor const& output_scale_offset_by_experts) {
 //   auto sm_version = getSMVersion();
-//   TORCH_CHECK(sm_version == 100 || sm_version == 103, "fp4_quant is only supported on sm100a/sm103a");
+//   TORCH_CHECK(sm_version == 100 || sm_version == 103, "fp4_quant is only supported on
+//   sm100a/sm103a");
 
 //   CHECK_INPUT(output, "output must be a CUDA tensor");
 //   CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
 //   CHECK_INPUT(input, "input must be a CUDA tensor");
 //   CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
 //   CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts must be a CUDA tensor");
-//   CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts must be a CUDA tensor");
+//   CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts must be a CUDA
+//   tensor");
 
 //   TORCH_CHECK(output.dim() == 2);
 //   TORCH_CHECK(output_scale.dim() == 2);
@@ -545,7 +537,8 @@ constexpr auto UINT8 = at::ScalarType::Byte;
 //     torch::Tensor const& mask,
 //     bool use_silu_and_mul) {
 //   auto sm_version = getSMVersion();
-//   TORCH_CHECK(sm_version == 100 || sm_version == 103, "fp4_quant is only supported on sm100a/sm103a");
+//   TORCH_CHECK(sm_version == 100 || sm_version == 103, "fp4_quant is only supported on
+//   sm100a/sm103a");
 
 //   CHECK_INPUT(output, "output must be a CUDA tensor");
 //   CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
@@ -618,4 +611,4 @@ constexpr auto UINT8 = at::ScalarType::Byte;
 //   } else {
 //     TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
 //   }
-// }
+// }