     QuantizationConfig,
     QuantizeMethodBase,
 )
-from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     FlashinferMoeBackend,
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils.deep_gemm import (
-    fp8_gemm_nt,
     get_col_major_tma_aligned_tensor,
     is_deep_gemm_e8m0_used,
     is_deep_gemm_supported,
-    should_use_deepgemm_for_fp8_linear,
 )
 from vllm.utils.flashinfer import has_flashinfer_moe
 from vllm.utils.import_utils import has_deep_gemm
@@ -554,83 +551,19 @@ def apply(
         # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
         # we will use BF16 dequant when DeepGEMM is not supported.
         if vllm_is_batch_invariant():
-            # Call is_deep_gemm_supported() ahead of time for torch.compile
-            # dynamo has trouble tracing through
-            if self.block_quant and should_use_deepgemm_for_fp8_linear(
-                torch.bfloat16, layer.weight, self.use_deep_gemm
-            ):
-                # use group quant consistent with block size across K
-                assert self.act_q_group_shape is not None
-                q_input, input_scale = QuantFP8(
-                    False,
-                    self.act_q_group_shape,
-                    column_major_scales=True,
-                )(x)
-
-                output_2d = torch.empty(
-                    (q_input.shape[0], layer.weight.shape[0]),
-                    dtype=torch.bfloat16,
-                    device=q_input.device,
-                )
-                fp8_gemm_nt(
-                    (q_input, input_scale),
-                    (layer.weight, layer.weight_scale),
-                    output_2d,
-                )
-                if bias is not None:
-                    output_2d = output_2d + bias
-                return output_2d
-
-            # Dequantize FP8 weights to BF16
-            weight_fp8 = layer.weight.to(torch.bfloat16)
-            weight_scale = layer.weight_scale.to(torch.bfloat16)
-
-            # Handle different quantization granularities
             if self.block_quant:
-                # Block-wise quantization:
-                # - Weight is NOT transposed, shape is [N, K] (output_size, input_size)
-                # - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!)
                 assert self.weight_block_size is not None
-                block_n, block_k = self.weight_block_size  # Note: order is [N, K]
-
-                N, K = weight_fp8.shape
-
-                # determine expected number of blocks along N and K
-                num_blocks_n = (N + block_n - 1) // block_n
-                num_blocks_k = (K + block_k - 1) // block_k
-
-                # scale layout may be [num_blocks_n, num_blocks_k]
-                # or [num_blocks_k, num_blocks_n] depending on backend
-                if weight_scale.dim() != 2:
-                    raise RuntimeError(
-                        f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}"
-                    )
-
-                scale_rows, scale_cols = weight_scale.shape
-                if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n):
-                    if num_blocks_n == num_blocks_k:
-                        # ambiguous square case, warn and skip transpose
-                        logger.warning(
-                            "Batch-invariant FP8: square block-scale %dx%d; "
-                            "skipping transpose to avoid misorientation.",
-                            scale_rows,
-                            scale_cols,
-                        )
-                    else:
-                        # clear KN -> transpose to NK
-                        weight_scale = weight_scale.t()
-
-                # Expand scale to match weight dimensions
-                # scale_expanded should have shape [N, K]
-                scale_expanded = weight_scale.repeat_interleave(
-                    block_n, dim=0
-                ).repeat_interleave(block_k, dim=1)
-                # Trim to exact weight size (in case of padding)
-                scale_expanded = scale_expanded[:N, :K]
-                weight_bf16 = weight_fp8 * scale_expanded
+                return self.w8a8_block_fp8_linear.apply(
+                    input=x,
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale,
+                    input_scale=layer.input_scale,
+                    bias=bias,
+                )
             else:
-                # Per-tensor quantization: weight IS transposed to [K, N]
-                # scale should be scalar or [1] or per-output-channel [N]
+                # per-tensor/channel: dequant to BF16 and run GEMM
+                weight_fp8 = layer.weight.to(torch.bfloat16)
+                weight_scale = layer.weight_scale.to(torch.bfloat16)
                 if weight_scale.numel() == 1:
                     # Per-tensor: simple scalar multiplication
                     weight_bf16 = weight_fp8 * weight_scale
@@ -649,16 +582,7 @@ def apply(
                 else:
                     # Fallback
                     weight_bf16 = weight_fp8 * weight_scale
-
-            # For block quant, weight is [N, K], for per-tensor it's [K, N]
-            # F.linear expects weight to be [N, K], so:
-            if self.block_quant:
-                # Already in correct shape [N, K]
-                output = torch.nn.functional.linear(x, weight_bf16, bias)
-            else:
-                # Need to transpose back: [K, N] -> [N, K]
-                output = torch.nn.functional.linear(x, weight_bf16.t(), bias)
-            return output
+                return torch.nn.functional.linear(x, weight_bf16.t(), bias)

         if self.use_marlin:
             return apply_fp8_marlin_linear(
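Note on the removed fallback: it dequantized block-quantized FP8 weights to BF16 by expanding the per-block scales over the weight and then calling F.linear. Below is a minimal standalone sketch of that scale expansion, with illustrative sizes only (N, K, block_n, block_k are placeholders, not values taken from this diff):

import torch

# Illustrative sizes; in the code above these come from layer.weight and
# self.weight_block_size.
N, K = 512, 384                       # block-quant weight is stored as [N, K]
block_n, block_k = 128, 128
num_blocks_n = (N + block_n - 1) // block_n
num_blocks_k = (K + block_k - 1) // block_k

weight_fp8 = torch.randn(N, K).to(torch.bfloat16)         # stand-in for layer.weight.to(bf16)
weight_scale = torch.rand(num_blocks_n, num_blocks_k).to(torch.bfloat16)

# Expand each per-block scale over its block_n x block_k tile, then trim to the
# exact weight shape in case N or K is not a multiple of the block size.
scale_expanded = weight_scale.repeat_interleave(block_n, dim=0)
scale_expanded = scale_expanded.repeat_interleave(block_k, dim=1)[:N, :K]
weight_bf16 = weight_fp8 * scale_expanded

x = torch.randn(4, K).to(torch.bfloat16)
out = torch.nn.functional.linear(x, weight_bf16)          # weight already [N, K], so no transpose

The new code drops this manual dequantization for the block-quant case and delegates to self.w8a8_block_fp8_linear.apply(...), which presumably handles activation quantization and the block-scale layout internally.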