
Commit ccf25ca

qli88tjtanaa and tjtanaavllm committed
[Bugfix] Add padding for block-scale fused-moe weights for AITER lib
Co-authored-by: tjtanaavllm <[email protected]>
Signed-off-by: Qiang Li <[email protected]>
1 parent aa49f14 commit ccf25ca
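
The underlying constraint, per the new helper's docstring, is that aiter.fmoe_fp8_blockscale_g1u1 only supports a bf16 output dtype, an intermediate dimension that is a multiple of 256, and 128x128 scale blocks. A minimal sketch of the size arithmetic the patch performs, using a hypothetical inter_dim of 384 (e.g. a tensor-parallel shard that breaks the 256 alignment):

    # Alignment arithmetic mirrored from the patch; inter_dim = 384 is a
    # made-up per-partition value chosen so that padding is required.
    block_n = 128                                    # scale-block size
    inter_dim = 384                                  # not a multiple of 256

    padded_inter_dim = ((inter_dim + 255) // 256) * 256            # -> 512
    inter_dim_block_scale = (inter_dim + block_n - 1) // block_n   # -> 3
    padded_inter_dim_block_scale = (
        (padded_inter_dim + block_n - 1) // block_n)               # -> 4

    assert padded_inter_dim % 256 == 0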


vllm/model_executor/layers/quantization/fp8.py

Lines changed: 116 additions & 0 deletions
@@ -470,6 +470,113 @@ def __init__(self, quant_config: Fp8Config):
             block_shape=self.quant_config.weight_block_size,
             allow_deep_gemm=self.allow_deep_gemm)
 
+    def _maybe_pad_rocm_aiter_block_scaled_fused_moe_weights(
+            self,
+            w2_weight,
+            w2_weight_scale_inv,
+            w13_weight,
+            w13_weight_scale_inv,
+            block_k=128,
+            block_n=128):
+        """
+        Pads the MoE weights and scales to align with block quantization
+        requirements.
+
+        aiter.fmoe_fp8_blockscale_g1u1 only supports out dtype = bf16,
+        inter_dim % 256 == 0, and fc_scale_blkn / fc_scale_blkk of 128.
+        """
+
+        if not self.rocm_aiter_moe_enabled:
+            return (w2_weight, w2_weight_scale_inv, w13_weight,
+                    w13_weight_scale_inv)
+
+        if (self.rocm_aiter_moe_enabled
+                and (w2_weight.shape[-1] % 256 == 0
+                     and w13_weight.shape[-2] % 256 == 0)):
+            return (w2_weight, w2_weight_scale_inv, w13_weight,
+                    w13_weight_scale_inv)
+
+        logger.info_once(
+            "ROCm AITER: padding MoE weights and scales for block quantization."
+        )
+        # For now this is enabled for DeepSeekV3 and Qwen3.
+        assert block_k == 128, "block_k must be 128"
+        assert block_n == 128, "block_n must be 128"
+        assert block_k == block_n, (
+            "block_k and block_n must be the same value: 128")
+
+        num_experts, hidden_size, inter_dim = w2_weight.shape
+        padded_inter_dim = ((inter_dim + 255) // 256) * 256
+        # inter_dim_block_scale = layer.w2_weight_scale_inv.shape[2]
+        #   = (intermediate_size_per_partition + block_n - 1) // block_n
+        inter_dim_block_scale = (inter_dim + block_n - 1) // block_n
+        padded_inter_dim_block_scale = ((padded_inter_dim + block_n - 1) //
+                                        block_n)
+
+        # k_block_scale is also known as hidden_size_block.
+        # Pad w2_weight to
+        # [num_experts, hidden_size, padded_inter_dim]
+        # Padding logic:
+        # [expert(local_expert:EP), hidden_size, inter_dim]
+        # after padding inter_dim with 0.0 to a multiple of 256:
+        # [expert(local_expert:EP), hidden_size, padded_inter_dim]
+        if padded_inter_dim > inter_dim:
+            pad_size = padded_inter_dim - inter_dim
+            w2_weight = F.pad(w2_weight, (0, pad_size), value=0.0)
+
+        # Pad w2_weight_scale_inv to
+        # [num_experts, k_block_scale, padded_inter_dim_block_scale]
+        # Padding logic:
+        # [expert(local_expert:EP), k_block_scale, inter_dim_block_scale]
+        # after padding inter_dim with 1.0:
+        # [expert(local_expert:EP), k_block_scale, padded_inter_dim_block_scale]  # noqa: E501
+        if padded_inter_dim_block_scale > inter_dim_block_scale:
+            pad_size = padded_inter_dim_block_scale - inter_dim_block_scale
+            w2_weight_scale_inv = F.pad(w2_weight_scale_inv, (0, pad_size),
+                                        value=1.0)
+
+        # Pad w13_weight to
+        # [num_experts, 2 * padded_inter_dim, hidden_size]
+        # Padding logic:
+        # [expert(local_expert:EP), inter_dim*2, dim]
+        # after reshape:
+        # [expert(local_expert:EP), 2, inter_dim, dim]
+        # after right padding:
+        # [expert(local_expert:EP), 2, padded_inter_dim, dim]
+        # after reshape:
+        # [expert(local_expert:EP), 2 * padded_inter_dim, dim]
+        w13_weight = w13_weight.view(num_experts, 2, inter_dim, hidden_size)
+        if padded_inter_dim > inter_dim:
+            pad_size = padded_inter_dim - inter_dim
+            w13_weight = F.pad(w13_weight, (0, 0, 0, pad_size), value=0.0)
+        w13_weight = w13_weight.view(num_experts, 2 * padded_inter_dim,
+                                     hidden_size)
+
+        # Pad w13_weight_scale_inv to
+        # [num_experts, 2 * padded_inter_dim_block_scale, k_block_scale]
+        # Padding logic:
+        # k_block_scale = (hidden_size + block_k - 1) // block_k
+        # [expert(local_expert:EP), inter_dim_block_scale*2, k_block_scale]  # noqa: E501
+        # after reshape:
+        # [expert(local_expert:EP), 2, inter_dim_block_scale, k_block_scale]  # noqa: E501
+        # after right padding with 1.0:
+        # [expert(local_expert:EP), 2, padded_inter_dim_block_scale, k_block_scale]  # noqa: E501
+        # after reshape:
+        # [expert(local_expert:EP), 2 * padded_inter_dim_block_scale, k_block_scale]  # noqa: E501
+        k_block_scale = w13_weight_scale_inv.shape[
+            2]  # = (hidden_size + block_k - 1) // block_k
+        w13_weight_scale_inv = w13_weight_scale_inv.view(
+            num_experts, 2, inter_dim_block_scale, k_block_scale)
+        if padded_inter_dim_block_scale > inter_dim_block_scale:
+            pad_size = padded_inter_dim_block_scale - inter_dim_block_scale
+            w13_weight_scale_inv = F.pad(w13_weight_scale_inv,
+                                         (0, 0, 0, pad_size),
+                                         value=1.0)
+        w13_weight_scale_inv = w13_weight_scale_inv.view(
+            num_experts, 2 * padded_inter_dim_block_scale, k_block_scale)
+
+        return w2_weight, w2_weight_scale_inv, w13_weight, w13_weight_scale_inv
+
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
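
One detail of the new helper deserves emphasis: w13_weight stacks the gate and up projections along dim 1, which is why the code views it as [num_experts, 2, inter_dim, hidden_size] before padding; zero rows must be appended to each projection separately, not just once at the end of the stacked tensor. A runnable toy demonstration with made-up shapes (not real model sizes):

    import torch
    import torch.nn.functional as F

    E, inter_dim, hidden = 2, 3, 4
    padded_inter_dim = 5  # hypothetical aligned size

    w13 = torch.randn(E, 2 * inter_dim, hidden)

    # View as [E, 2, inter_dim, hidden] so each stacked projection is
    # padded independently, then flatten back.
    w13_4d = w13.view(E, 2, inter_dim, hidden)
    w13_padded = F.pad(w13_4d, (0, 0, 0, padded_inter_dim - inter_dim),
                       value=0.0)
    w13_padded = w13_padded.view(E, 2 * padded_inter_dim, hidden)

    # The appended rows of *both* halves are zero.
    halves = w13_padded.view(E, 2, padded_inter_dim, hidden)
    assert torch.all(halves[:, :, inter_dim:] == 0)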
@@ -623,6 +730,15 @@ def process_weights_after_loading(self, layer: Module) -> None:
         w2_weight = layer.w2_weight
         w2_weight_scale_inv = layer.w2_weight_scale_inv
 
+        (w2_weight, w2_weight_scale_inv, w13_weight, w13_weight_scale_inv
+         ) = self._maybe_pad_rocm_aiter_block_scaled_fused_moe_weights(
+             w2_weight,
+             w2_weight_scale_inv,
+             w13_weight,
+             w13_weight_scale_inv,
+             block_n=self.quant_config.weight_block_size[0],
+             block_k=self.quant_config.weight_block_size[1])
+
         # torch.compile() cannot use Parameter subclasses.
         layer.w13_weight = Parameter(w13_weight, requires_grad=False)
         layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
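
A design choice worth noting in the hunks above: padded weight regions are filled with 0.0 while padded scale blocks are filled with 1.0. Block dequantization multiplies each weight block by its scale, so a neutral scale of 1.0 leaves the zero padding at exactly zero and avoids degenerate (zero) scale entries; this rationale is inferred from the code, and the toy values below are illustrative only:

    import torch

    # One row, two blocks of width 2; the second block is pure padding.
    w_q = torch.tensor([[4.0, 8.0, 0.0, 0.0]])
    scales = torch.tensor([[0.5, 1.0]])   # padded block's scale is 1.0

    w_deq = w_q.view(1, 2, 2) * scales.view(1, 2, 1)
    print(w_deq.reshape(1, -1))           # tensor([[2., 4., 0., 0.]])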
