@@ -470,6 +470,113 @@ def __init__(self, quant_config: Fp8Config):
             block_shape=self.quant_config.weight_block_size,
             allow_deep_gemm=self.allow_deep_gemm)
 
+    def _maybe_pad_rocm_aiter_block_scaled_fused_moe_weights(
+            self,
+            w2_weight,
+            w2_weight_scale_inv,
+            w13_weight,
+            w13_weight_scale_inv,
+            block_k=128,
+            block_n=128):
+        """
+        Pads the MoE weights and scales to align with block quantization
+        requirements.
+
+        aiter.fmoe_fp8_blockscale_g1u1 only supports out dtype = bf16,
+        inter_dim % 256 == 0, and fc_scale_blkn / fc_scale_blkk of 128.
+        """
+
+        if not self.rocm_aiter_moe_enabled:
+            return (w2_weight, w2_weight_scale_inv, w13_weight,
+                    w13_weight_scale_inv)
+
+        if (self.rocm_aiter_moe_enabled
+                and (w2_weight.shape[-1] % 256 == 0
+                     and w13_weight.shape[-2] % 256 == 0)):
+            return (w2_weight, w2_weight_scale_inv, w13_weight,
+                    w13_weight_scale_inv)
+
+        logger.info_once(
+            "ROCm AITER: padding MoE weights and scales for block quantization."
+        )
+        # For now this is enabled for DeepSeekV3 and Qwen3.
+        assert block_k == 128, "block_k must be 128"
+        assert block_n == 128, "block_n must be 128"
+        assert block_k == block_n, (
+            "block_k and block_n must be the same value: 128")
+
+        num_experts, hidden_size, inter_dim = w2_weight.shape
+        padded_inter_dim = ((inter_dim + 255) // 256) * 256
+        # inter_dim_block_scale = layer.w2_weight_scale_inv.shape[2]
+        #   = ((intermediate_size_per_partition + block_n - 1) // block_n)
+        inter_dim_block_scale = (inter_dim + block_n - 1) // block_n
+        padded_inter_dim_block_scale = ((padded_inter_dim + block_n - 1) //
+                                        block_n)
+
+        # k_block_scale is also known as hidden_size_block.
+        # Pad w2_weight to
+        #   [num_experts, hidden_size, padded_inter_dim]
+        # Padding logic:
+        #   [expert(local_expert:EP), hidden_size, inter_dim]
+        # after right-padding inter_dim with 0.0 to a multiple of 256:
+        #   [expert(local_expert:EP), hidden_size, padded_inter_dim]
+        if padded_inter_dim > inter_dim:
+            pad_size = padded_inter_dim - inter_dim
+            w2_weight = F.pad(w2_weight, (0, pad_size), value=0.0)
+
+        # Pad w2_weight_scale_inv to
+        #   [num_experts, k_block_scale, padded_inter_dim_block_scale]
+        # Padding logic:
+        #   [expert(local_expert:EP), k_block_scale, inter_dim_block_scale]
+        # after right-padding inter_dim_block_scale with 1.0:
+        #   [expert(local_expert:EP), k_block_scale, padded_inter_dim_block_scale]  # noqa: E501
+        if padded_inter_dim_block_scale > inter_dim_block_scale:
+            pad_size = padded_inter_dim_block_scale - inter_dim_block_scale
+            w2_weight_scale_inv = F.pad(w2_weight_scale_inv, (0, pad_size),
+                                        value=1.0)
+
+        # Pad w13_weight to
+        #   [num_experts, 2 * padded_inter_dim, hidden_size]
+        # Padding logic:
+        #   [expert(local_expert:EP), inter_dim * 2, hidden_size]
+        # after reshape:
+        #   [expert(local_expert:EP), 2, inter_dim, hidden_size]
+        # after right padding with 0.0:
+        #   [expert(local_expert:EP), 2, padded_inter_dim, hidden_size]
+        # after reshape:
+        #   [expert(local_expert:EP), 2 * padded_inter_dim, hidden_size]
+        w13_weight = w13_weight.view(num_experts, 2, inter_dim, hidden_size)
+        if padded_inter_dim > inter_dim:
+            pad_size = padded_inter_dim - inter_dim
+            w13_weight = F.pad(w13_weight, (0, 0, 0, pad_size), value=0.0)
+        w13_weight = w13_weight.view(num_experts, 2 * padded_inter_dim,
+                                     hidden_size)
+
+        # Pad w13_weight_scale_inv to
+        #   [num_experts, 2 * padded_inter_dim_block_scale, k_block_scale]
+        # Padding logic:
+        #   k_block_scale = ((hidden_size + block_k - 1) // block_k)
+        #   [expert(local_expert:EP), inter_dim_block_scale * 2, k_block_scale]  # noqa: E501
+        # after reshape:
+        #   [expert(local_expert:EP), 2, inter_dim_block_scale, k_block_scale]  # noqa: E501
+        # after right padding with 1.0:
+        #   [expert(local_expert:EP), 2, padded_inter_dim_block_scale, k_block_scale]  # noqa: E501
+        # after reshape:
+        #   [expert(local_expert:EP), 2 * padded_inter_dim_block_scale, k_block_scale]  # noqa: E501
+        # k_block_scale = (hidden_size + block_k - 1) // block_k
+        k_block_scale = w13_weight_scale_inv.shape[2]
+        w13_weight_scale_inv = w13_weight_scale_inv.view(
+            num_experts, 2, inter_dim_block_scale, k_block_scale)
+        if padded_inter_dim_block_scale > inter_dim_block_scale:
+            pad_size = padded_inter_dim_block_scale - inter_dim_block_scale
+            w13_weight_scale_inv = F.pad(w13_weight_scale_inv,
+                                         (0, 0, 0, pad_size),
+                                         value=1.0)
+        w13_weight_scale_inv = w13_weight_scale_inv.view(
+            num_experts, 2 * padded_inter_dim_block_scale, k_block_scale)
+
+        return w2_weight, w2_weight_scale_inv, w13_weight, w13_weight_scale_inv
+
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
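
For reference, here is a minimal standalone sketch (not part of this diff) of the shape arithmetic the new helper relies on: `inter_dim` is rounded up to the next multiple of 256, and the number of scale columns is re-derived with `block_n = 128`. The function name `padded_dims` is hypothetical.

```python
# Minimal sketch, not part of this PR: the round-up arithmetic used by the
# padding helper, assuming block_n = block_k = 128 and 256-alignment.
def padded_dims(inter_dim: int, block_n: int = 128, align: int = 256):
    padded_inter_dim = ((inter_dim + align - 1) // align) * align
    inter_dim_block_scale = (inter_dim + block_n - 1) // block_n
    padded_inter_dim_block_scale = (padded_inter_dim + block_n - 1) // block_n
    return padded_inter_dim, inter_dim_block_scale, padded_inter_dim_block_scale


# 1536 is already a multiple of 256, so nothing changes;
# 1408 is rounded up to 1536 and the scale columns grow from 11 to 12.
assert padded_dims(1536) == (1536, 12, 12)
assert padded_dims(1408) == (1536, 11, 12)
```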
@@ -623,6 +730,15 @@ def process_weights_after_loading(self, layer: Module) -> None:
             w2_weight = layer.w2_weight
             w2_weight_scale_inv = layer.w2_weight_scale_inv
 
+            (w2_weight, w2_weight_scale_inv, w13_weight, w13_weight_scale_inv
+             ) = self._maybe_pad_rocm_aiter_block_scaled_fused_moe_weights(
+                 w2_weight,
+                 w2_weight_scale_inv,
+                 w13_weight,
+                 w13_weight_scale_inv,
+                 block_n=self.quant_config.weight_block_size[0],
+                 block_k=self.quant_config.weight_block_size[1])
+
             # torch.compile() cannot use Parameter subclasses.
             layer.w13_weight = Parameter(w13_weight, requires_grad=False)
             layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
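
As a sanity check of the hookup above, a hedged, self-contained sketch (dummy shapes, not taken from the PR) that mirrors the `F.pad` calls: weights are right-padded with 0.0 and the matching scale columns with 1.0, so the padded region dequantizes to zero.

```python
# Illustrative only: mimics the padding done by the new helper on dummy
# tensors with hypothetical sizes (4 experts, hidden 512, inter_dim 1408).
import torch
import torch.nn.functional as F

E, H, I, BLK = 4, 512, 1408, 128          # experts, hidden, inter_dim, block size
PI = ((I + 255) // 256) * 256             # padded inter_dim -> 1536
IB, PIB = (I + BLK - 1) // BLK, (PI + BLK - 1) // BLK
KB = (H + BLK - 1) // BLK                 # k_block_scale

w2 = torch.zeros(E, H, I)
w2_scale = torch.ones(E, KB, IB)
w13 = torch.zeros(E, 2 * I, H)
w13_scale = torch.ones(E, 2 * IB, KB)

# Weights get zero padding; the corresponding scales are padded with 1.0,
# so the padded columns dequantize to exact zeros.
w2 = F.pad(w2, (0, PI - I), value=0.0)
w2_scale = F.pad(w2_scale, (0, PIB - IB), value=1.0)
w13 = F.pad(w13.view(E, 2, I, H), (0, 0, 0, PI - I),
            value=0.0).view(E, 2 * PI, H)
w13_scale = F.pad(w13_scale.view(E, 2, IB, KB), (0, 0, 0, PIB - IB),
                  value=1.0).view(E, 2 * PIB, KB)

assert w2.shape == (E, H, PI) and w13.shape == (E, 2 * PI, H)
assert w2_scale.shape == (E, KB, PIB) and w13_scale.shape == (E, 2 * PIB, KB)
```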