Merged
vllm/model_executor/layers/quantization/fp8.py (98 changes: 11 additions & 87 deletions)
@@ -43,7 +43,6 @@
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
@@ -95,11 +94,9 @@
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
get_col_major_tma_aligned_tensor,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear,
)
from vllm.utils.flashinfer import has_flashinfer_moe
from vllm.utils.import_utils import has_deep_gemm
@@ -554,83 +551,19 @@ def apply(
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
# we will use BF16 dequant when DeepGEMM is not supported.
if vllm_is_batch_invariant():
# Call is_deep_gemm_supported() ahead of time for torch.compile
# dynamo has trouble tracing through
if self.block_quant and should_use_deepgemm_for_fp8_linear(
torch.bfloat16, layer.weight, self.use_deep_gemm
):
# use group quant consistent with block size across K
assert self.act_q_group_shape is not None
q_input, input_scale = QuantFP8(
False,
self.act_q_group_shape,
column_major_scales=True,
)(x)

output_2d = torch.empty(
(q_input.shape[0], layer.weight.shape[0]),
dtype=torch.bfloat16,
device=q_input.device,
)
fp8_gemm_nt(
(q_input, input_scale),
(layer.weight, layer.weight_scale),
output_2d,
)
if bias is not None:
output_2d = output_2d + bias
return output_2d

# Dequantize FP8 weights to BF16
weight_fp8 = layer.weight.to(torch.bfloat16)
weight_scale = layer.weight_scale.to(torch.bfloat16)

# Handle different quantization granularities
if self.block_quant:
# Block-wise quantization:
# - Weight is NOT transposed, shape is [N, K] (output_size, input_size)
# - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!)
assert self.weight_block_size is not None
block_n, block_k = self.weight_block_size # Note: order is [N, K]

N, K = weight_fp8.shape

# determine expected number of blocks along N and K
num_blocks_n = (N + block_n - 1) // block_n
num_blocks_k = (K + block_k - 1) // block_k

# scale layout may be [num_blocks_n, num_blocks_k]
# or [num_blocks_k, num_blocks_n] depending on backend
if weight_scale.dim() != 2:
raise RuntimeError(
f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}"
)

scale_rows, scale_cols = weight_scale.shape
if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n):
if num_blocks_n == num_blocks_k:
# ambiguous square case, warn and skip transpose
logger.warning(
"Batch-invariant FP8: square block-scale %dx%d; "
"skipping transpose to avoid misorientation.",
scale_rows,
scale_cols,
)
else:
# clear KN -> transpose to NK
weight_scale = weight_scale.t()

# Expand scale to match weight dimensions
# scale_expanded should have shape [N, K]
scale_expanded = weight_scale.repeat_interleave(
block_n, dim=0
).repeat_interleave(block_k, dim=1)
# Trim to exact weight size (in case of padding)
scale_expanded = scale_expanded[:N, :K]
weight_bf16 = weight_fp8 * scale_expanded
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
else:
# Per-tensor quantization: weight IS transposed to [K, N]
# scale should be scalar or [1] or per-output-channel [N]
# per-tensor/channel: dequant to BF16 and run GEMM
weight_fp8 = layer.weight.to(torch.bfloat16)
weight_scale = layer.weight_scale.to(torch.bfloat16)
if weight_scale.numel() == 1:
# Per-tensor: simple scalar multiplication
weight_bf16 = weight_fp8 * weight_scale
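
For reference, the block-wise dequantization that the deleted branch performed inline, and that the new call to self.w8a8_block_fp8_linear.apply is now expected to cover, amounts to expanding the per-block scale to the full [N, K] weight shape and multiplying it into the FP8 weight. A minimal sketch of that expansion (the helper name and signature below are illustrative, not a vLLM API):

import torch

def dequant_block_fp8_weight(weight_fp8: torch.Tensor,
                             weight_scale: torch.Tensor,
                             block_n: int,
                             block_k: int) -> torch.Tensor:
    # weight_fp8: [N, K] FP8 weight; weight_scale: [num_blocks_n, num_blocks_k]
    # (illustrative helper, not part of vLLM)
    N, K = weight_fp8.shape
    # expand each block scale over its block_n x block_k tile
    scale_expanded = weight_scale.repeat_interleave(block_n, dim=0)
    scale_expanded = scale_expanded.repeat_interleave(block_k, dim=1)
    # trim padded blocks so the scale matches the exact weight shape
    scale_expanded = scale_expanded[:N, :K]
    return weight_fp8.to(torch.bfloat16) * scale_expanded.to(torch.bfloat16)

The deleted branch also had to guess whether the scale arrived as [num_blocks_n, num_blocks_k] or transposed, warning and skipping the transpose in the ambiguous square case; routing through w8a8_block_fp8_linear.apply avoids re-deriving that orientation here.
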
@@ -649,16 +582,7 @@ def apply(
else:
# Fallback
weight_bf16 = weight_fp8 * weight_scale

# For block quant, weight is [N, K], for per-tensor it's [K, N]
# F.linear expects weight to be [N, K], so:
if self.block_quant:
# Already in correct shape [N, K]
output = torch.nn.functional.linear(x, weight_bf16, bias)
else:
# Need to transpose back: [K, N] -> [N, K]
output = torch.nn.functional.linear(x, weight_bf16.t(), bias)
return output
return torch.nn.functional.linear(x, weight_bf16.t(), bias)

if self.use_marlin:
return apply_fp8_marlin_linear(
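
On the per-tensor/per-channel side, the weight is stored as [K, N] rather than the block-quant [N, K] layout, which is why the simplified return transposes the dequantized weight before torch.nn.functional.linear (which expects [N, K]). A minimal standalone sketch of that fallback, assuming bfloat16 activations and a scalar or per-output-channel [N] scale that broadcasts over the output dimension (the function below is hypothetical, not the vLLM implementation):

import torch
import torch.nn.functional as F

def per_tensor_fp8_linear(x: torch.Tensor,
                          weight_kn_fp8: torch.Tensor,
                          weight_scale: torch.Tensor,
                          bias: torch.Tensor | None = None) -> torch.Tensor:
    # weight_kn_fp8: [K, N] FP8 weight; weight_scale: scalar, [1], or [N]
    # (illustrative sketch of the BF16-dequant fallback, not the vLLM code)
    weight_bf16 = weight_kn_fp8.to(torch.bfloat16)
    scale = weight_scale.to(torch.bfloat16)
    if scale.numel() == 1:
        weight_bf16 = weight_bf16 * scale              # per-tensor scalar scale
    else:
        weight_bf16 = weight_bf16 * scale.view(1, -1)  # per-output-channel scale on N
    # F.linear expects weight of shape [N, K], so transpose the [K, N] layout
    return F.linear(x, weight_bf16.t(), bias)

A call such as per_tensor_fp8_linear(x, layer.weight, layer.weight_scale, bias) mirrors this path only in spirit; the real code above also retains a fallback branch for other scale shapes.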