Skip to content

Commit c7eaba6

Browse files
committed
Address func name
1 parent bd01434 commit c7eaba6

File tree

7 files changed

+26
-25
lines changed

7 files changed

+26
-25
lines changed

csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,9 @@ void fp4_batched_quantize(Tensor self, Optional<Tensor> const& mask, Tensor glob
194194
#undef LAUNCH_FP4_QUANTIZE_KERNEL
195195
}
196196

197-
void silu_and_mul_fp4_batched_quantize(Tensor const& self, Tensor const& mask,
198-
Tensor const& globalScale, Tensor valueE2M1,
199-
Tensor scaleFP8SF, int64_t sfVecSize) {
197+
void silu_and_mul_nvfp4_batched_quantize(Tensor const& self, Tensor const& mask,
198+
Tensor const& globalScale, Tensor valueE2M1,
199+
Tensor scaleFP8SF, int64_t sfVecSize) {
200200
// TODO(shuw): mask can be none
201201
CHECK_CUDA(self);
202202
CHECK_CONTIGUOUS(self);
@@ -225,18 +225,18 @@ void silu_and_mul_fp4_batched_quantize(Tensor const& self, Tensor const& mask,
225225
const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
226226
auto layout = tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4;
227227

228-
#define LAUNCH_SILU_AND_MUL_FP4_QUANTIZE_KERNEL(T, SF_VEC_SIZE) \
228+
#define LAUNCH_SILU_AND_MUL_NVFP4_QUANTIZE_KERNEL(T, SF_VEC_SIZE) \
229229
tensorrt_llm::kernels::invokeSiluAndMulFP4Quantization<T, SF_VEC_SIZE>( \
230230
b, m, k_by_2, reinterpret_cast<T*>(self->data), static_cast<float*>(globalScale->data), \
231231
static_cast<int32_t*>(mask->data), reinterpret_cast<int64_t*>(valueE2M1->data), \
232232
reinterpret_cast<int32_t*>(scaleFP8SF->data), layout, mMultiProcessorCount, \
233233
get_stream(self->device));
234234

235235
if (self->dtype == dl_float16) {
236-
LAUNCH_SILU_AND_MUL_FP4_QUANTIZE_KERNEL(half, 16)
236+
LAUNCH_SILU_AND_MUL_NVFP4_QUANTIZE_KERNEL(half, 16)
237237
} else if (self->dtype == dl_bfloat16) {
238238
#ifdef ENABLE_BF16
239-
LAUNCH_SILU_AND_MUL_FP4_QUANTIZE_KERNEL(__nv_bfloat16, 16)
239+
LAUNCH_SILU_AND_MUL_NVFP4_QUANTIZE_KERNEL(__nv_bfloat16, 16)
240240
#else
241241
TVM_FFI_LOG_AND_THROW(NotImplementedError)
242242
<< "BFloat16 must be enabled to quantize an bf16 tensor to fp4.";
@@ -246,9 +246,10 @@ void silu_and_mul_fp4_batched_quantize(Tensor const& self, Tensor const& mask,
246246
<< "fp4_quantize only supports input tensor with dtypes fp16/bf16.";
247247
}
248248

249-
#undef LAUNCH_SILU_AND_MUL_FP4_QUANTIZE_KERNEL
249+
#undef LAUNCH_SILU_AND_MUL_NVFP4_QUANTIZE_KERNEL
250250
}
251251

252252
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp4_quantize, fp4_quantize);
253253
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp4_batched_quantize, fp4_batched_quantize);
254-
TVM_FFI_DLL_EXPORT_TYPED_FUNC(silu_and_mul_fp4_batched_quantize, silu_and_mul_fp4_batched_quantize);
254+
TVM_FFI_DLL_EXPORT_TYPED_FUNC(silu_and_mul_nvfp4_batched_quantize,
255+
silu_and_mul_nvfp4_batched_quantize);

csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,6 @@ void fp4_quantize(Tensor self, Optional<Tensor> const& globalScale, Tensor value
3434
void fp4_batched_quantize(Tensor self, Optional<Tensor> const& mask, Tensor globalScale,
3535
Tensor valueE2M1, Tensor scaleFP8SF, int64_t sfVecSize, bool sfUseUE8M0);
3636

37-
void silu_and_mul_fp4_batched_quantize(Tensor const& self, Tensor const& mask,
38-
Tensor const& globalScale, Tensor valueE2M1,
39-
Tensor scaleFP8SF, int64_t sfVecSize);
37+
void silu_and_mul_nvfp4_batched_quantize(Tensor const& self, Tensor const& mask,
38+
Tensor const& globalScale, Tensor valueE2M1,
39+
Tensor scaleFP8SF, int64_t sfVecSize);

docs/api/fp4_quantization.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Core Quantization Functions
1818
nvfp4_batched_quantize
1919
nvfp4_block_scale_interleave
2020
e2m1_and_ufp8sf_scale_to_float
21-
silu_and_mul_fp4_batched_quantize
21+
silu_and_mul_nvfp4_batched_quantize
2222

2323
Matrix Shuffling Utilities
2424
--------------------------

flashinfer/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from .activation import gelu_tanh_and_mul as gelu_tanh_and_mul
2626
from .activation import silu_and_mul as silu_and_mul
2727
from .activation import (
28-
silu_and_mul_fp4_batched_quantize as silu_and_mul_fp4_batched_quantize,
28+
silu_and_mul_nvfp4_batched_quantize as silu_and_mul_nvfp4_batched_quantize,
2929
)
3030
from .attention import BatchAttention as BatchAttention
3131
from .attention import (

flashinfer/activation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def silu_and_mul(
142142
return out
143143

144144

145-
def silu_and_mul_fp4_batched_quantize(
145+
def silu_and_mul_nvfp4_batched_quantize(
146146
a,
147147
mask,
148148
a_global_sf,
@@ -166,7 +166,7 @@ def silu_and_mul_fp4_batched_quantize(
166166
device_arch = f"{major * 10 + minor}"
167167
a_fp4, a_sf = get_fp4_quantization_module(
168168
device_arch
169-
).silu_and_mul_fp4_batched_quantize_sm100(
169+
).silu_and_mul_nvfp4_batched_quantize_sm100(
170170
a,
171171
mask,
172172
a_global_sf,

flashinfer/fp4_quantization.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -375,10 +375,10 @@ def _fp4_batched_quantize_sm100(
375375
)
376376

377377
@register_custom_op(
378-
"flashinfer::silu_and_mul_fp4_batched_quantize_sm100",
378+
"flashinfer::silu_and_mul_nvfp4_batched_quantize_sm100",
379379
mutates_args=("",),
380380
)
381-
def silu_and_mul_fp4_batched_quantize_sm100(
381+
def silu_and_mul_nvfp4_batched_quantize_sm100(
382382
input: torch.Tensor,
383383
mask: torch.Tensor,
384384
global_scale: Optional[torch.Tensor] = None,
@@ -429,7 +429,7 @@ def silu_and_mul_fp4_batched_quantize_sm100(
429429
dtype=torch.uint8,
430430
device=input.device,
431431
)
432-
module.silu_and_mul_fp4_batched_quantize(
432+
module.silu_and_mul_nvfp4_batched_quantize(
433433
input,
434434
mask,
435435
global_scale,
@@ -439,8 +439,8 @@ def silu_and_mul_fp4_batched_quantize_sm100(
439439
)
440440
return out_val, out_sf
441441

442-
@register_fake_op("flashinfer::silu_and_mul_fp4_batched_quantize_sm100")
443-
def _silu_and_mul_fp4_batched_quantize_sm100(
442+
@register_fake_op("flashinfer::silu_and_mul_nvfp4_batched_quantize_sm100")
443+
def _silu_and_mul_nvfp4_batched_quantize_sm100(
444444
input: torch.Tensor,
445445
mask: torch.Tensor,
446446
global_scale: Optional[torch.Tensor] = None,
@@ -518,7 +518,7 @@ def _fake_e2m1_and_ufp8sf_scale_to_float_sm100(
518518
e2m1_and_ufp8sf_scale_to_float_sm100=e2m1_and_ufp8sf_scale_to_float_sm100,
519519
mxfp4_dequantize_host=mxfp4_dequantize_host,
520520
fp4_batched_quantize_sm100=fp4_batched_quantize_sm100,
521-
silu_and_mul_fp4_batched_quantize_sm100=silu_and_mul_fp4_batched_quantize_sm100,
521+
silu_and_mul_nvfp4_batched_quantize_sm100=silu_and_mul_nvfp4_batched_quantize_sm100,
522522
)
523523

524524

tests/test_fp4_quantize.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
mxfp4_quantize,
1212
mxfp4_dequantize,
1313
nvfp4_batched_quantize,
14-
silu_and_mul_fp4_batched_quantize,
14+
silu_and_mul_nvfp4_batched_quantize,
1515
)
1616
from flashinfer.utils import is_sm100a_supported
1717

@@ -377,13 +377,13 @@ def test_nvfp4_batched_quantize(
377377
@pytest.mark.parametrize("seed", SEEDS)
378378
@pytest.mark.parametrize("device", CUDA_DEVICES)
379379
@torch.inference_mode()
380-
def test_silu_and_mul_fp4_batched_quantize(
380+
def test_silu_and_mul_nvfp4_batched_quantize(
381381
dtype: torch.dtype,
382382
batch_shape: tuple[int, int, int],
383383
seed: int,
384384
device: str,
385385
) -> None:
386-
"""Test silu_and_mul_fp4_batched_quantize function."""
386+
"""Test silu_and_mul_nvfp4_batched_quantize function."""
387387
if not is_sm100a_supported(torch.device(device)):
388388
pytest.skip("Nvfp4 Requires compute capability of 10 or above")
389389
torch.set_default_device(device)
@@ -399,7 +399,7 @@ def test_silu_and_mul_fp4_batched_quantize(
399399
tensor_amax = ref_y.abs().amax(dim=(1, 2)).to(torch.float32)
400400
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
401401

402-
out, out_scale = silu_and_mul_fp4_batched_quantize(x, mask, global_scale)
402+
out, out_scale = silu_and_mul_nvfp4_batched_quantize(x, mask, global_scale)
403403
ref_out, ref_out_scale = nvfp4_batched_quantize(
404404
ref_y,
405405
global_scale,

0 commit comments

Comments
 (0)