Commit 6b55e4d

Address comments
Signed-off-by: Shu Wang. <[email protected]>
1 parent 856f918 commit 6b55e4d

3 files changed: +8 -5 lines changed

csrc/nv_internal/cpp/kernels/quantization.cu

Lines changed: 1 addition & 1 deletion
@@ -319,7 +319,7 @@ void invokeSiluAndMulNVFP4Quantization(void* output, void* output_scale, void* i
 
   // TODO(kaixih@nvidia): Should relax this to allow any grid size.
   // [email protected]: only deal with mask case
-  assert(mask != nullptr);
+  TLLM_CHECK_WITH_INFO(mask != nullptr, "mask must be non-null for expert NVFP4 path");
   grid.x = (grid.x + n_experts - 1) / n_experts * n_experts;
   cvt_fp16_to_fp4_expert<T, false><<<grid, block, 0, stream>>>(
       m_topk, k, reinterpret_cast<T*>(input), reinterpret_cast<float*>(input_global_scale),
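
Why the switch away from assert matters: assert() is compiled out in NDEBUG (release) builds, so the mask precondition would go unchecked exactly where the kernel runs in production, while a TLLM_CHECK_WITH_INFO-style check always executes and attaches a message. A rough sketch of the same hardening pattern in Python terms, since assert statements are likewise stripped under python -O (check_with_info and quantize_experts are illustrative names, not part of this commit):

def check_with_info(cond: bool, msg: str) -> None:
    # Unlike `assert`, which is removed when Python runs with -O, this check
    # always executes and raises with a human-readable message.
    if not cond:
        raise RuntimeError(msg)

def quantize_experts(inputs, mask):
    # Hypothetical caller mirroring the hardened precondition from the diff.
    check_with_info(mask is not None, "mask must be non-null for expert NVFP4 path")
    return inputs  # the kernel launch would go here

try:
    quantize_experts(inputs=[], mask=None)
except RuntimeError as err:
    print(err)  # mask must be non-null for expert NVFP4 path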

csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ void fp4_batched_quantize(Tensor self, Tensor globalScale, Tensor valueE2M1, Ten
   tensorrt_llm::kernels::invokeFP4Quantization<T, SF_VEC_SIZE>(                                    \
       b, m, k, reinterpret_cast<T*>(self->data), static_cast<float*>(globalScale->data),           \
       reinterpret_cast<int64_t*>(valueE2M1->data), reinterpret_cast<int32_t*>(scaleFP8SF->data),   \
-      sfUseUE8M0, layout, mMultiProcessorCount, get_stream(self->device));
+      sfUseUE8M0, layout, mMultiProcessorCount, /*enable_pdl=*/false, get_stream(self->device));
 
   if (self->dtype == dl_float16) {
     LAUNCH_FP4_QUANTIZE_KERNEL(half, 16)
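
Why the call site spells the new argument as /*enable_pdl=*/false: the launcher gained an extra boolean (presumably programmatic dependent launch) just before the stream argument, and passing it explicitly with an inline name keeps the literal readable and preserves the previous behavior (PDL off). A rough sketch of the analogous call-site discipline in Python, using a hypothetical stand-in rather than the real binding:

def invoke_fp4_quantization(b, m, k, *, sf_use_ue8m0=False, enable_pdl=False, stream=None):
    # Hypothetical stand-in: keyword-only booleans cannot be swapped positionally,
    # which is the readability problem the /*enable_pdl=*/false comment solves in C++.
    return (b, m, k, sf_use_ue8m0, enable_pdl, stream)

# The new flag defaults to off, so existing callers keep their behavior:
invoke_fp4_quantization(8, 100, 256, sf_use_ue8m0=False, enable_pdl=False)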

flashinfer/fp4_quantization.py

Lines changed: 6 additions & 3 deletions
@@ -362,10 +362,13 @@ def _fake_fp4_batched_quantize_sm100(
     sf_vec_size: int = 16,
     sf_use_ue8m0: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    m, k = input.shape
+    b, m, k = input.shape
     return (
-        input.new_empty([m, k // 2], dtype=torch.int64),  # float4_e2m1_x2
-        input.new_empty([m * k // sf_vec_size], dtype=torch.int32),  # Scale factors
+        input.new_empty([b, m, k // 2], dtype=torch.uint8),  # FLOAT4_E2M1X2
+        input.new_empty(
+            [b, _compute_swizzled_layout_sf_size(m, k // sf_vec_size, 128)],
+            dtype=torch.uint8,
+        ),  # swizzled SF buffer
     )
 
 @register_custom_op(
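
A note on the fake (meta) implementation above: it only has to report output shapes and dtypes for tracing, and the size of the second buffer comes from the swizzled scale-factor layout. A rough sketch of the shape arithmetic, assuming rows are padded to a multiple of 128 and the per-row scale-factor count to a multiple of 4 (the real _compute_swizzled_layout_sf_size may differ; swizzled_sf_size and fake_batched_quantize_shapes are illustrative names):

import torch

def swizzled_sf_size(rows: int, sf_cols: int, row_tile: int = 128, col_tile: int = 4) -> int:
    # Assumed padding rule: round rows up to row_tile and the per-row scale-factor
    # count up to col_tile, with one byte stored per scale factor.
    padded_rows = (rows + row_tile - 1) // row_tile * row_tile
    padded_cols = (sf_cols + col_tile - 1) // col_tile * col_tile
    return padded_rows * padded_cols

def fake_batched_quantize_shapes(b: int, m: int, k: int, sf_vec_size: int = 16):
    # Mirrors the fake op: packed FP4 values (two per uint8 byte) plus a swizzled
    # scale-factor buffer per batch entry, allocated on the meta device for tracing.
    values = torch.empty([b, m, k // 2], dtype=torch.uint8, device="meta")
    scales = torch.empty([b, swizzled_sf_size(m, k // sf_vec_size)], dtype=torch.uint8, device="meta")
    return values, scales

values, scales = fake_batched_quantize_shapes(b=8, m=100, k=256)
print(values.shape, scales.shape)  # torch.Size([8, 100, 128]) torch.Size([8, 2048])

Reporting uint8 here, rather than the old int64/int32 placeholders, keeps buffers allocated from the traced fake shapes consistent with what the CUDA kernel actually writes.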
