@@ -194,9 +194,9 @@ void fp4_batched_quantize(Tensor self, Optional<Tensor> const& mask, Tensor glob
194194#undef LAUNCH_FP4_QUANTIZE_KERNEL
195195}
196196
197- void silu_and_mul_fp4_batched_quantize (Tensor const & self, Tensor const & mask,
198- Tensor const & globalScale, Tensor valueE2M1,
199- Tensor scaleFP8SF, int64_t sfVecSize) {
197+ void silu_and_mul_nvfp4_batched_quantize (Tensor const & self, Tensor const & mask,
198+ Tensor const & globalScale, Tensor valueE2M1,
199+ Tensor scaleFP8SF, int64_t sfVecSize) {
200200 // TODO(shuw): mask can be none
201201 CHECK_CUDA (self);
202202 CHECK_CONTIGUOUS (self);
@@ -225,18 +225,18 @@ void silu_and_mul_fp4_batched_quantize(Tensor const& self, Tensor const& mask,
225225 const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount ();
226226 auto layout = tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4;
227227
228- #define LAUNCH_SILU_AND_MUL_FP4_QUANTIZE_KERNEL (T, SF_VEC_SIZE ) \
228+ #define LAUNCH_SILU_AND_MUL_NVFP4_QUANTIZE_KERNEL (T, SF_VEC_SIZE ) \
229229 tensorrt_llm::kernels::invokeSiluAndMulFP4Quantization<T, SF_VEC_SIZE>( \
230230 b, m, k_by_2, reinterpret_cast <T*>(self->data ), static_cast <float *>(globalScale->data ), \
231231 static_cast <int32_t *>(mask->data ), reinterpret_cast <int64_t *>(valueE2M1->data ), \
232232 reinterpret_cast <int32_t *>(scaleFP8SF->data ), layout, mMultiProcessorCount , \
233233 get_stream (self->device ));
234234
235235 if (self->dtype == dl_float16) {
236- LAUNCH_SILU_AND_MUL_FP4_QUANTIZE_KERNEL (half, 16 )
236+ LAUNCH_SILU_AND_MUL_NVFP4_QUANTIZE_KERNEL (half, 16 )
237237 } else if (self->dtype == dl_bfloat16) {
238238#ifdef ENABLE_BF16
239- LAUNCH_SILU_AND_MUL_FP4_QUANTIZE_KERNEL (__nv_bfloat16, 16 )
239+ LAUNCH_SILU_AND_MUL_NVFP4_QUANTIZE_KERNEL (__nv_bfloat16, 16 )
240240#else
241241 TVM_FFI_LOG_AND_THROW (NotImplementedError)
242242 << " BFloat16 must be enabled to quantize an bf16 tensor to fp4." ;
@@ -246,9 +246,10 @@ void silu_and_mul_fp4_batched_quantize(Tensor const& self, Tensor const& mask,
246246 << " fp4_quantize only supports input tensor with dtypes fp16/bf16." ;
247247 }
248248
249- #undef LAUNCH_SILU_AND_MUL_FP4_QUANTIZE_KERNEL
249+ #undef LAUNCH_SILU_AND_MUL_NVFP4_QUANTIZE_KERNEL
250250}
251251
252252TVM_FFI_DLL_EXPORT_TYPED_FUNC (fp4_quantize, fp4_quantize);
253253TVM_FFI_DLL_EXPORT_TYPED_FUNC (fp4_batched_quantize, fp4_batched_quantize);
254- TVM_FFI_DLL_EXPORT_TYPED_FUNC (silu_and_mul_fp4_batched_quantize, silu_and_mul_fp4_batched_quantize);
254+ TVM_FFI_DLL_EXPORT_TYPED_FUNC (silu_and_mul_nvfp4_batched_quantize,
255+ silu_and_mul_nvfp4_batched_quantize);
0 commit comments