Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,6 @@ __device__ inline void quantize_with_block_size_impl(int32_t numbatches, int32_t
static constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / ELTS_PER_THREAD;
static_assert(sizeof(PackedVec) == sizeof(Type) * ELTS_PER_THREAD, "Vec size is not matched.");

float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
bool isSfSwizzledLayout = layout == QuantizationSFLayout::SWIZZLED_128x4 ||
layout == QuantizationSFLayout::SWIZZLED_8x4;
int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8;
Expand All @@ -810,6 +809,7 @@ __device__ inline void quantize_with_block_size_impl(int32_t numbatches, int32_t
asm volatile("griddepcontrol.wait;");
for (int rowIdx = blockIdx.x; rowIdx < numPaddedRowsForSf; rowIdx += gridDim.x) {
for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[batchIdx];
for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
std::optional<int> optionalBatchIdx = batchIdx;
std::optional<int> optionalNumRows = numRows;
Expand Down
8 changes: 4 additions & 4 deletions tests/utils/test_fp4_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
DTYPES = [torch.float16, torch.bfloat16]
# The batch dimension doesn't need to be multiple of 128
SHAPES = [(128, 64), (256, 128), (120, 64), (200, 256)]
BATCH_SHAPES = [(2, 128, 64), (3, 256, 128), (1, 120, 64)]
BATCH_SHAPES = [(1, 256, 128), (2, 128, 64), (3, 256, 128), (1, 120, 64)]
SEEDS = [42]
CUDA_DEVICES = ["cuda:0"]

Expand Down Expand Up @@ -334,7 +334,7 @@ def test_nvfp4_batched_quantize(

b, m, n = batch_shape
x = torch.randn(batch_shape, dtype=dtype)
tensor_amax = torch.abs(x).max().to(torch.float32)
tensor_amax = torch.abs(x).amax(dim=(1, 2)).to(torch.float32)
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
mask = None
# Test the batched quantization
Expand All @@ -357,7 +357,7 @@ def test_nvfp4_batched_quantize(

# Compare with single tensor quantization for each batch
for i in range(b):
single_out, single_scale = fp4_quantize(x[i], global_scale, 16, False, True)
single_out, single_scale = fp4_quantize(x[i], global_scale[i], 16, False, True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

While this change to use per-batch global_scale[i] is correct, it reveals a latent bug in the test logic for the use_mask=True case on line 366. The out_scale[i] tensor is 1D, but the unswizzle_sf function expects a 2D tensor. This will likely cause the test to fail. You should reshape out_scale[i] using the shape of single_scale before passing it to unswizzle_sf.

Specifically, line 366 should be:

scale_ans = unswizzle_sf(out_scale[i].reshape(single_scale.shape), m, n)

if use_mask:
torch.testing.assert_close(
out[i][: mask[i]], single_out[: mask[i]], rtol=1e-5, atol=1e-5
Expand Down Expand Up @@ -414,7 +414,7 @@ def test_silu_and_mul_nvfp4_batched_quantize(
for i in range(b):
x_silu_mul = silu_and_mul(x[i])
single_out, single_scale = fp4_quantize(
x_silu_mul, global_scale, 16, False, True
x_silu_mul, global_scale[i], 16, False, True
)
Comment on lines 416 to 418
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the other test, this change to use global_scale[i] is correct, but it exposes issues on lines 427-428. out_scale[i] and ref_out_scale[i] are 1D tensors, but unswizzle_sf expects 2D tensors. They should be reshaped using single_scale.shape before being passed to unswizzle_sf.

Specifically, the lines should be:

scale_ans = unswizzle_sf(out_scale[i].reshape(single_scale.shape), m, n)
ref_out_scale_expert = unswizzle_sf(ref_out_scale[i].reshape(single_scale.shape), m, n)

torch.testing.assert_close(
out[i][: mask[i]], single_out[: mask[i]], rtol=1e-5, atol=1e-5
Expand Down