Commit 35d801f

[Feature] Refactor batch invariant fp8 DeepGEMM (#27606)
Signed-off-by: yewentao256 <[email protected]>
1 parent 0bf29fa commit 35d801f

File tree

1 file changed: +11 -87 lines changed
  • vllm/model_executor/layers/quantization/fp8.py

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 11 additions & 87 deletions
@@ -43,7 +43,6 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     FlashinferMoeBackend,
@@ -95,11 +94,9 @@
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils.deep_gemm import (
-    fp8_gemm_nt,
     get_col_major_tma_aligned_tensor,
     is_deep_gemm_e8m0_used,
     is_deep_gemm_supported,
-    should_use_deepgemm_for_fp8_linear,
 )
 from vllm.utils.flashinfer import has_flashinfer_moe
 from vllm.utils.import_utils import has_deep_gemm
@@ -554,83 +551,19 @@ def apply(
         # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
         # we will use BF16 dequant when DeepGEMM is not supported.
         if vllm_is_batch_invariant():
-            # Call is_deep_gemm_supported() ahead of time for torch.compile
-            # dynamo has trouble tracing through
-            if self.block_quant and should_use_deepgemm_for_fp8_linear(
-                torch.bfloat16, layer.weight, self.use_deep_gemm
-            ):
-                # use group quant consistent with block size across K
-                assert self.act_q_group_shape is not None
-                q_input, input_scale = QuantFP8(
-                    False,
-                    self.act_q_group_shape,
-                    column_major_scales=True,
-                )(x)
-
-                output_2d = torch.empty(
-                    (q_input.shape[0], layer.weight.shape[0]),
-                    dtype=torch.bfloat16,
-                    device=q_input.device,
-                )
-                fp8_gemm_nt(
-                    (q_input, input_scale),
-                    (layer.weight, layer.weight_scale),
-                    output_2d,
-                )
-                if bias is not None:
-                    output_2d = output_2d + bias
-                return output_2d
-
-            # Dequantize FP8 weights to BF16
-            weight_fp8 = layer.weight.to(torch.bfloat16)
-            weight_scale = layer.weight_scale.to(torch.bfloat16)
-
-            # Handle different quantization granularities
             if self.block_quant:
-                # Block-wise quantization:
-                # - Weight is NOT transposed, shape is [N, K] (output_size, input_size)
-                # - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!)
                 assert self.weight_block_size is not None
-                block_n, block_k = self.weight_block_size  # Note: order is [N, K]
-
-                N, K = weight_fp8.shape
-
-                # determine expected number of blocks along N and K
-                num_blocks_n = (N + block_n - 1) // block_n
-                num_blocks_k = (K + block_k - 1) // block_k
-
-                # scale layout may be [num_blocks_n, num_blocks_k]
-                # or [num_blocks_k, num_blocks_n] depending on backend
-                if weight_scale.dim() != 2:
-                    raise RuntimeError(
-                        f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}"
-                    )
-
-                scale_rows, scale_cols = weight_scale.shape
-                if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n):
-                    if num_blocks_n == num_blocks_k:
-                        # ambiguous square case, warn and skip transpose
-                        logger.warning(
-                            "Batch-invariant FP8: square block-scale %dx%d; "
-                            "skipping transpose to avoid misorientation.",
-                            scale_rows,
-                            scale_cols,
-                        )
-                    else:
-                        # clear KN -> transpose to NK
-                        weight_scale = weight_scale.t()
-
-                # Expand scale to match weight dimensions
-                # scale_expanded should have shape [N, K]
-                scale_expanded = weight_scale.repeat_interleave(
-                    block_n, dim=0
-                ).repeat_interleave(block_k, dim=1)
-                # Trim to exact weight size (in case of padding)
-                scale_expanded = scale_expanded[:N, :K]
-                weight_bf16 = weight_fp8 * scale_expanded
+                return self.w8a8_block_fp8_linear.apply(
+                    input=x,
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale,
+                    input_scale=layer.input_scale,
+                    bias=bias,
+                )
             else:
-                # Per-tensor quantization: weight IS transposed to [K, N]
-                # scale should be scalar or [1] or per-output-channel [N]
+                # per-tensor/channel: dequant to BF16 and run GEMM
+                weight_fp8 = layer.weight.to(torch.bfloat16)
+                weight_scale = layer.weight_scale.to(torch.bfloat16)
                 if weight_scale.numel() == 1:
                     # Per-tensor: simple scalar multiplication
                     weight_bf16 = weight_fp8 * weight_scale
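
For reference, the block-quant branch deleted above dequantized the weight by hand: each entry of the per-block scale grid was broadcast over its block_n x block_k tile with repeat_interleave before a plain BF16 GEMM, and the commit replaces all of that with a single call to self.w8a8_block_fp8_linear.apply. Below is a minimal standalone sketch of the old scale expansion, assuming the scale is already laid out as [num_blocks_n, num_blocks_k] (the deleted code also handled the transposed layout); names and shapes are illustrative, not vLLM's API.

import torch

def dequant_block_fp8_weight(
    weight_fp8: torch.Tensor,    # [N, K], FP8 (e4m3) storage
    weight_scale: torch.Tensor,  # [num_blocks_n, num_blocks_k], float32
    block_n: int,
    block_k: int,
) -> torch.Tensor:
    """Broadcast each per-block scale over its block_n x block_k tile, then dequantize."""
    n, k = weight_fp8.shape
    scale_expanded = weight_scale.repeat_interleave(block_n, dim=0)
    scale_expanded = scale_expanded.repeat_interleave(block_k, dim=1)
    # Trim in case N or K is not an exact multiple of the block size.
    scale_expanded = scale_expanded[:n, :k]
    return weight_fp8.to(torch.bfloat16) * scale_expanded.to(torch.bfloat16)

# Example: a 256x512 weight with 128x128 blocks carries a 2x4 scale grid.
w = torch.randn(256, 512).to(torch.float8_e4m3fn)
s = torch.rand(2, 4)
w_bf16 = dequant_block_fp8_weight(w, s, block_n=128, block_k=128)

Delegating to w8a8_block_fp8_linear removes this hand-rolled dequant, and with it the scale-orientation guessing, from the batch-invariant path.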
@@ -649,16 +582,7 @@ def apply(
                 else:
                     # Fallback
                     weight_bf16 = weight_fp8 * weight_scale
-
-            # For block quant, weight is [N, K], for per-tensor it's [K, N]
-            # F.linear expects weight to be [N, K], so:
-            if self.block_quant:
-                # Already in correct shape [N, K]
-                output = torch.nn.functional.linear(x, weight_bf16, bias)
-            else:
-                # Need to transpose back: [K, N] -> [N, K]
-                output = torch.nn.functional.linear(x, weight_bf16.t(), bias)
-            return output
+                return torch.nn.functional.linear(x, weight_bf16.t(), bias)
 
         if self.use_marlin:
             return apply_fp8_marlin_linear(
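
The surviving per-tensor/per-channel fallback keeps the BF16-dequant GEMM and now always transposes, since in that layout the FP8 weight is stored as [K, N] while F.linear expects [N, K]. A self-contained sketch of that path follows; the helper name and shapes are hypothetical, not vLLM's actual code.

import torch
import torch.nn.functional as F

def per_tensor_fp8_linear(
    x: torch.Tensor,             # [batch, K] BF16 activations
    weight_fp8: torch.Tensor,    # [K, N], FP8 storage
    weight_scale: torch.Tensor,  # scalar, [1], or per-channel [N]
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Dequantize to BF16 and run the GEMM, mirroring the fallback branch above."""
    w = weight_fp8.to(torch.bfloat16) * weight_scale.to(torch.bfloat16)
    # F.linear wants [N, K], so transpose the [K, N] dequantized weight.
    return F.linear(x, w.t(), bias)

x = torch.randn(8, 512, dtype=torch.bfloat16)
w = torch.randn(512, 256).to(torch.float8_e4m3fn)  # [K, N]
scale = torch.tensor(0.05)
y = per_tensor_fp8_linear(x, w, scale)  # -> [8, 256]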
