16 changes: 8 additions & 8 deletions tests/v1/generation/test_batch_invariance.py
@@ -10,9 +10,9 @@
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform

hopper_only = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.is_device_capability(90)),
reason="Requires CUDA and Hopper (SM90)",
skip_unsupported = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
reason="Requires CUDA and >= Hopper (SM90)",
)
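
For context, the rename tracks a semantic change: `is_device_capability(90)` matches Hopper exactly, while `has_device_capability(90)` accepts Hopper or anything newer. A minimal illustrative sketch of that difference (not the `current_platform` implementation; capability is assumed to be encoded as major * 10 + minor):

# Illustrative only -- capability encoded as major * 10 + minor (SM90 -> 90, SM100 -> 100).
def is_device_capability(required: int, actual: int) -> bool:
    return actual == required        # exact match: Hopper only

def has_device_capability(required: int, actual: int) -> bool:
    return actual >= required        # at-least match: Hopper or newer

assert is_device_capability(90, actual=90)
assert not is_device_capability(90, actual=100)   # the old marker skipped SM100
assert has_device_capability(90, actual=100)      # the new marker lets SM100 run the tests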


@@ -74,7 +74,7 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
return base_prompt


@hopper_only
@skip_unsupported
@pytest.mark.timeout(1000)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
"""
@@ -219,7 +219,7 @@ def _extract_step_logprobs(request_output):
return None, None


@hopper_only
@skip_unsupported
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
@pytest.mark.forked
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
@@ -434,7 +434,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
pytest.fail(msg)


@hopper_only
@skip_unsupported
def test_simple_generation():
"""
Simple test that runs the model with a basic prompt and prints the output.
@@ -480,7 +480,7 @@ def test_simple_generation():
llm.shutdown()


@hopper_only
@skip_unsupported
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
@pytest.mark.forked
def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
@@ -707,7 +707,7 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
os.environ["VLLM_BATCH_INVARIANT"] = old_value


@hopper_only
@skip_unsupported
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
@pytest.mark.forked
def test_decode_logprobs_match_prefill_logprobs(backend):
18 changes: 9 additions & 9 deletions tests/v1/generation/test_rms_norm_batch_invariant.py
@@ -14,13 +14,13 @@
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform

hopper_only = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.is_device_capability(90)),
reason="Requires CUDA and Hopper (SM90)",
skip_unsupported = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
reason="Requires CUDA and >= Hopper (SM90)",
)


@hopper_only
@skip_unsupported
@pytest.mark.parametrize("batch_size", [1, 4, 16, 64])
@pytest.mark.parametrize("hidden_size", [512, 2048, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -69,7 +69,7 @@ def test_rms_norm_batch_invariant_vs_standard(
)


@hopper_only
@skip_unsupported
@pytest.mark.parametrize("batch_size", [1, 16, 128])
@pytest.mark.parametrize("seq_len", [1, 32, 512])
@pytest.mark.parametrize("hidden_size", [2048, 4096])
@@ -111,7 +111,7 @@ def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
)


@hopper_only
@skip_unsupported
def test_rms_norm_numerical_stability():
"""
Test RMS norm numerical stability with extreme values.
@@ -171,7 +171,7 @@ def test_rms_norm_numerical_stability():
)


@hopper_only
@skip_unsupported
def test_rms_norm_formula():
"""
Test that RMS norm follows the correct mathematical formula.
@@ -204,7 +204,7 @@ def test_rms_norm_formula():
)
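
For reference, the formula such a test checks is y = x / sqrt(mean(x^2) + eps) * weight, with the mean taken over the hidden dimension. A small plain-PyTorch sketch of that reference (not the test's actual code; the eps value and dtype handling are assumptions):

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # compute in float32 for stability, then cast back to the input dtype
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

x = torch.randn(4, 4096, dtype=torch.bfloat16)
out = rms_norm_reference(x, torch.ones(4096, dtype=torch.bfloat16))
assert out.shape == x.shape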


@hopper_only
@skip_unsupported
@pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384])
def test_rms_norm_different_hidden_sizes(hidden_size: int):
"""
@@ -242,7 +242,7 @@ def test_rms_norm_different_hidden_sizes(hidden_size: int):
)


@hopper_only
@skip_unsupported
def test_rms_norm_determinism():
"""
Test that batch-invariant RMS norm produces deterministic results.
58 changes: 54 additions & 4 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -41,6 +41,7 @@
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
@@ -94,9 +95,11 @@
from vllm.scalar_type import scalar_types
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
get_col_major_tma_aligned_tensor,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear,
)
from vllm.utils.flashinfer import has_flashinfer_moe

@@ -539,8 +542,34 @@ def apply(
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
# If batch invariant mode is enabled, dequantize and use BF16 compute
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
# we will use BF16 dequant when DeepGEMM is not supported.
if vllm_is_batch_invariant():
if self.block_quant and should_use_deepgemm_for_fp8_linear(
torch.bfloat16, layer.weight, None
):
# use group quant consistent with block size across K
assert self.act_q_group_shape is not None
q_input, input_scale = QuantFP8(
False,
self.act_q_group_shape,
column_major_scales=True,
)(x)
Comment on lines +553 to +557

Member:
I don't think this is the intended way to use QuantFP8. Is there a reason why this cannot be put in the shared block w8a8 utils so it can be reused for other backends like compressed tensors?

Member Author:
Let's talk offline

output_2d = torch.empty(
(q_input.shape[0], layer.weight.shape[0]),
dtype=torch.bfloat16,
device=q_input.device,
)
fp8_gemm_nt(
(q_input, input_scale),
(layer.weight, layer.weight_scale),
output_2d,
)
if bias is not None:
output_2d = output_2d + bias
return output_2d
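
The branch above quantizes the activations per group along K (to match the weight's block size) and then calls DeepGEMM's fp8_gemm_nt on the FP8 operands. Conceptually, the group-quantization step computes roughly the following (a plain-PyTorch sketch of the idea, not the QuantFP8 kernel; values stay in float32 here, the group size of 128 is an assumption, and the column-major scale layout DeepGEMM expects is ignored):

import torch

# Conceptual sketch of per-token group quantization along K (group size = block_k);
# NOT the vLLM QuantFP8 kernel -- a real kernel would store q in float8_e4m3.
def group_quant_fp8_sketch(x: torch.Tensor, group_k: int = 128, fp8_max: float = 448.0):
    M, K = x.shape
    assert K % group_k == 0
    xg = x.float().view(M, K // group_k, group_k)
    scale = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (xg / scale).clamp(-fp8_max, fp8_max)           # values now fit the e4m3 range
    return q.view(M, K), scale.view(M, K // group_k)

x = torch.randn(4, 1024, dtype=torch.bfloat16)
q, s = group_quant_fp8_sketch(x)
# multiplying back by the per-group scales recovers x up to float rounding
x_rec = (q.view(4, -1, 128) * s.unsqueeze(-1)).view(4, 1024)
assert torch.allclose(x_rec, x.float(), atol=1e-3)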

# Dequantize FP8 weights to BF16
weight_fp8 = layer.weight.to(torch.bfloat16)
weight_scale = layer.weight_scale.to(torch.bfloat16)
@@ -555,9 +584,30 @@

N, K = weight_fp8.shape

# Scale is stored transposed: [num_blocks_k, num_blocks_n]
# We need to transpose it to [num_blocks_n, num_blocks_k] first
weight_scale = weight_scale.t()
# determine expected number of blocks along N and K
num_blocks_n = (N + block_n - 1) // block_n
num_blocks_k = (K + block_k - 1) // block_k

# scale layout may be [num_blocks_n, num_blocks_k]
# or [num_blocks_k, num_blocks_n] depending on backend
if weight_scale.dim() != 2:
raise RuntimeError(
f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}"
)

scale_rows, scale_cols = weight_scale.shape
if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n):
if num_blocks_n == num_blocks_k:
# ambiguous square case, warn and skip transpose
logger.warning(
"Batch-invariant FP8: square block-scale %dx%d; "
"skipping transpose to avoid misorientation.",
scale_rows,
scale_cols,
)
else:
# clear KN -> transpose to NK
weight_scale = weight_scale.t()

# Expand scale to match weight dimensions
# scale_expanded should have shape [N, K]
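
Putting the orientation check and the expansion together, the intended transformation is roughly the following (a self-contained sketch, not the code in this PR; expand_block_scale and the 128-element block sizes are our own names/assumptions, and the final slice handles shapes that are not multiples of the block size):

import torch

# Sketch: orient a block scale to [num_blocks_n, num_blocks_k], then expand it
# to the full [N, K] weight shape.
def expand_block_scale(weight_scale: torch.Tensor, N: int, K: int,
                       block_n: int = 128, block_k: int = 128) -> torch.Tensor:
    nbn = (N + block_n - 1) // block_n
    nbk = (K + block_k - 1) // block_k
    if weight_scale.shape == (nbk, nbn) and nbn != nbk:
        weight_scale = weight_scale.t()          # stored KN -> needed NK
    # square case (nbn == nbk) is ambiguous: left as-is, matching the warning above
    scale_expanded = weight_scale.repeat_interleave(block_n, dim=0)
    scale_expanded = scale_expanded.repeat_interleave(block_k, dim=1)
    return scale_expanded[:N, :K]

scale = torch.rand(4, 8)                          # [num_blocks_k=4, num_blocks_n=8]
full = expand_block_scale(scale, N=1024, K=512)   # -> [1024, 512]
assert full.shape == (1024, 512)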