From ccf25caaab6910de131e8dc4be0ba15f59e4fbc0 Mon Sep 17 00:00:00 2001
From: Qiang Li
Date: Thu, 5 Jun 2025 19:31:33 +0000
Subject: [PATCH 1/5] [Bugfix] Add padding for block-scale fused-moe weights
 for AITER lib

Co-authored-by: tjtanaavllm
Signed-off-by: Qiang Li
---
 .../model_executor/layers/quantization/fp8.py | 116 ++++++++++++++++++
 1 file changed, 116 insertions(+)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 5ac22b6a0aee..8eb9bc8275b3 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -470,6 +470,113 @@ def __init__(self, quant_config: Fp8Config):
             block_shape=self.quant_config.weight_block_size,
             allow_deep_gemm=self.allow_deep_gemm)
 
+    def _maybe_pad_rocm_aiter_block_scaled_fused_moe_weights(
+            self,
+            w2_weight,
+            w2_weight_scale_inv,
+            w13_weight,
+            w13_weight_scale_inv,
+            block_k=128,
+            block_n=128):
+        """
+        Pads the MoE weights and scales to align with block quantization
+        requirements.
+
+        aiter.fmoe_fp8_blockscale_g1u1 only supports out dtype = bf16,
+        inter_dim % 256 == 0, and fc_scale_blkn == fc_scale_blkk == 128.
+        """
+
+        if (not self.rocm_aiter_moe_enabled):
+            return (w2_weight, w2_weight_scale_inv, w13_weight,
+                    w13_weight_scale_inv)
+
+        if (self.rocm_aiter_moe_enabled
+                and (w2_weight.shape[-1] % 256 == 0
+                     and w13_weight.shape[-2] % 256 == 0)):
+            return (w2_weight, w2_weight_scale_inv, w13_weight,
+                    w13_weight_scale_inv)
+
+        logger.info_once(
+            "ROCm AITER Padding MoE weights and scales for block quantization."
+        )
+        # For now this is enabled for DeepSeekV3 and Qwen3.
+        assert block_k == 128, "block_k must be 128"
+        assert block_n == 128, "block_n must be 128"
+        assert block_k == block_n, (
+            "block_k and block_n must be the same value: 128")
+
+        num_experts, hidden_size, inter_dim = w2_weight.shape
+        padded_inter_dim = ((inter_dim + 255) // 256) * 256
+        # inter_dim_block_scale = layer.w2_weight_scale_inv.shape[2]
+        #   = ((intermediate_size_per_partition + block_n - 1) // block_n)
+        inter_dim_block_scale = (inter_dim + block_n - 1) // block_n
+        padded_inter_dim_block_scale = ((padded_inter_dim + block_n - 1) //
+                                        block_n)
+
+        # k_block_scale is also known as hidden_size_block
+        # Pad w2_weight to
+        # [num_experts, hidden_size, padded_inter_dim]
+        # Padding Logic:
+        # [expert(local_expert:EP), hidden_size, inter_dim]
+        # after padding inter_dim with 0.0 to multiple of 256
+        # [expert(local_expert:EP), hidden_size, padded_inter_dim]
+        if padded_inter_dim > inter_dim:
+            pad_size = padded_inter_dim - inter_dim
+            w2_weight = F.pad(w2_weight, (0, pad_size), value=0.0)
+
+        # Pad w2_weight_scale_inv to
+        # [num_experts, k_block_scale, padded_inter_dim_block_scale]
+        # Padding Logic:
+        # [expert(local_expert:EP), k_block_scale, inter_dim_block_scale]
+        # after padding inter_dim with 1.0
+        # [expert(local_expert:EP), k_block_scale, padded_inter_dim_block_scale] # noqa: E501
+        if padded_inter_dim_block_scale > inter_dim_block_scale:
+            pad_size = padded_inter_dim_block_scale - inter_dim_block_scale
+            w2_weight_scale_inv = F.pad(w2_weight_scale_inv, (0, pad_size),
+                                        value=1.0)
+
+        # Pad w13_weight to
+        # [num_experts, 2 * padded_inter_dim, hidden_size]
+        # Padding Logic:
+        # [expert(local_expert:EP), inter_dim*2, dim]
+        # after reshape
+        # [expert(local_expert:EP), 2, inter_dim, dim]
+        # after right padding
+        # [expert(local_expert:EP), 2, padded_inter_dim, dim]
+        # after reshape
+        # [expert(local_expert:EP), 2 * padded_inter_dim, dim]
+        w13_weight = w13_weight.view(num_experts, 2, inter_dim, hidden_size)
+        if padded_inter_dim > inter_dim:
+            pad_size = padded_inter_dim - inter_dim
+            w13_weight = F.pad(w13_weight, (0, 0, 0, pad_size), value=0.0)
+        w13_weight = w13_weight.view(num_experts, 2 * padded_inter_dim,
+                                     hidden_size)
+
+        # Pad w13_weight_scale_inv to
+        # [num_experts, 2 * padded_inter_dim_block_scale, k_block_scale]
+        # Padding Logic:
+        # k_block_scale = ((hidden_size + block_k - 1) // block_k)
+        # [expert(local_expert:EP), inter_dim_block_scale*2, k_block_scale] # noqa: E501
+        # after reshape
+        # [expert(local_expert:EP), 2, inter_dim_block_scale, k_block_scale] # noqa: E501
+        # after right padding with 1.0
+        # [expert(local_expert:EP), 2, padded_inter_dim_block_scale, k_block_scale] # noqa: E501
+        # after reshape
+        # [expert(local_expert:EP), 2 * padded_inter_dim_block_scale, k_block_scale] # noqa: E501
+        k_block_scale = w13_weight_scale_inv.shape[
+            2]  # k_block_scale = (hidden_size + block_k - 1) // block_k
+        w13_weight_scale_inv = w13_weight_scale_inv.view(
+            num_experts, 2, inter_dim_block_scale, k_block_scale)
+        if padded_inter_dim_block_scale > inter_dim_block_scale:
+            pad_size = padded_inter_dim_block_scale - inter_dim_block_scale
+            w13_weight_scale_inv = F.pad(w13_weight_scale_inv,
+                                         (0, 0, 0, pad_size),
+                                         value=1.0)
+        w13_weight_scale_inv = w13_weight_scale_inv.view(
+            num_experts, 2 * padded_inter_dim_block_scale, k_block_scale)
+
+        return w2_weight, w2_weight_scale_inv, w13_weight, w13_weight_scale_inv
+
     def create_weights(self, layer: Module, num_experts: int,
                        hidden_size: int,
                        intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
@@ -623,6 +730,15 @@ def process_weights_after_loading(self, layer: Module) -> None:
             w2_weight = layer.w2_weight
             w2_weight_scale_inv = layer.w2_weight_scale_inv
 
+            (w2_weight, w2_weight_scale_inv, w13_weight, w13_weight_scale_inv
+             ) = self._maybe_pad_rocm_aiter_block_scaled_fused_moe_weights(
+                 w2_weight,
+                 w2_weight_scale_inv,
+                 w13_weight,
+                 w13_weight_scale_inv,
+                 block_n=self.quant_config.weight_block_size[0],
+                 block_k=self.quant_config.weight_block_size[1])
+
             # torch.compile() cannot use Parameter subclasses.
             layer.w13_weight = Parameter(w13_weight, requires_grad=False)
             layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,

From 45604cca797ed7bc3ea8ed604ae602526d9255ae Mon Sep 17 00:00:00 2001
From: Qiang Li
Date: Thu, 5 Jun 2025 19:51:08 +0000
Subject: [PATCH 2/5] [Bugfix] Add None check for optional list

Signed-off-by: Qiang Li
---
 vllm/model_executor/layers/quantization/fp8.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 8eb9bc8275b3..9e0d9fd4bcb4 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -736,8 +736,10 @@ def process_weights_after_loading(self, layer: Module) -> None:
                  w2_weight_scale_inv,
                  w13_weight,
                  w13_weight_scale_inv,
-                 block_n=self.quant_config.weight_block_size[0],
-                 block_k=self.quant_config.weight_block_size[1])
+                 block_n=self.quant_config.weight_block_size[0] \
+                 if self.quant_config.weight_block_size[0] is not None else 128,
+                 block_k=self.quant_config.weight_block_size[1] \
+                 if self.quant_config.weight_block_size[1] is not None else 128)
 
             # torch.compile() cannot use Parameter subclasses.
             layer.w13_weight = Parameter(w13_weight, requires_grad=False)

From 1c2adb55094eefebbb7e5497c680c51359e40d54 Mon Sep 17 00:00:00 2001
From: Qiang Li
Date: Thu, 5 Jun 2025 20:03:29 +0000
Subject: [PATCH 3/5] Make sure block quant is used before doing possible padding

Signed-off-by: Qiang Li
---
 .../model_executor/layers/quantization/fp8.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 9e0d9fd4bcb4..69ebf2037e71 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -730,16 +730,15 @@ def process_weights_after_loading(self, layer: Module) -> None:
             w2_weight = layer.w2_weight
             w2_weight_scale_inv = layer.w2_weight_scale_inv
 
-            (w2_weight, w2_weight_scale_inv, w13_weight, w13_weight_scale_inv
-             ) = self._maybe_pad_rocm_aiter_block_scaled_fused_moe_weights(
-                 w2_weight,
-                 w2_weight_scale_inv,
-                 w13_weight,
-                 w13_weight_scale_inv,
-                 block_n=self.quant_config.weight_block_size[0] \
-                 if self.quant_config.weight_block_size[0] is not None else 128,
-                 block_k=self.quant_config.weight_block_size[1] \
-                 if self.quant_config.weight_block_size[1] is not None else 128)
+            if self.block_quant:
+                (w2_weight, w2_weight_scale_inv, w13_weight, w13_weight_scale_inv
+                 ) = self._maybe_pad_rocm_aiter_block_scaled_fused_moe_weights(
+                     w2_weight,
+                     w2_weight_scale_inv,
+                     w13_weight,
+                     w13_weight_scale_inv,
+                     block_n=self.quant_config.weight_block_size[0],
+                     block_k=self.quant_config.weight_block_size[1])
 
             # torch.compile() cannot use Parameter subclasses.
             layer.w13_weight = Parameter(w13_weight, requires_grad=False)

From bda243edfe101d2785720a754c7858e2d8dd5a90 Mon Sep 17 00:00:00 2001
From: Qiang Li
Date: Thu, 5 Jun 2025 20:12:35 +0000
Subject: [PATCH 4/5] Replace block_quant with raw check to stop CI complaints

Signed-off-by: Qiang Li
---
 vllm/model_executor/layers/quantization/fp8.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 69ebf2037e71..e872febc3756 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -730,7 +730,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             w2_weight = layer.w2_weight
             w2_weight_scale_inv = layer.w2_weight_scale_inv
 
-            if self.block_quant:
+            if self.quant_config.weight_block_size is not None:
                 (w2_weight, w2_weight_scale_inv, w13_weight, w13_weight_scale_inv
                  ) = self._maybe_pad_rocm_aiter_block_scaled_fused_moe_weights(
                      w2_weight,

From 57e5540cca6cc8d50bf5694c9c3ab2287acbd0e4 Mon Sep 17 00:00:00 2001
From: Qiang Li
Date: Thu, 5 Jun 2025 20:20:54 +0000
Subject: [PATCH 5/5] yapf

Signed-off-by: Qiang Li
---
 vllm/model_executor/layers/quantization/fp8.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index e872febc3756..79043acffd06 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -731,14 +731,15 @@ def process_weights_after_loading(self, layer: Module) -> None:
             w2_weight_scale_inv = layer.w2_weight_scale_inv
 
             if self.quant_config.weight_block_size is not None:
-                (w2_weight, w2_weight_scale_inv, w13_weight, w13_weight_scale_inv
-                 ) = self._maybe_pad_rocm_aiter_block_scaled_fused_moe_weights(
-                     w2_weight,
-                     w2_weight_scale_inv,
-                     w13_weight,
-                     w13_weight_scale_inv,
-                     block_n=self.quant_config.weight_block_size[0],
-                     block_k=self.quant_config.weight_block_size[1])
+                (w2_weight, w2_weight_scale_inv, w13_weight,
+                 w13_weight_scale_inv
+                 ) = self._maybe_pad_rocm_aiter_block_scaled_fused_moe_weights(
+                     w2_weight,
+                     w2_weight_scale_inv,
+                     w13_weight,
+                     w13_weight_scale_inv,
+                     block_n=self.quant_config.weight_block_size[0],
+                     block_k=self.quant_config.weight_block_size[1])
 
             # torch.compile() cannot use Parameter subclasses.
             layer.w13_weight = Parameter(w13_weight, requires_grad=False)
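
To make the shape bookkeeping in patch 1 easier to follow outside the diff, here is a minimal standalone sketch of the same padding scheme on toy tensors (the sizes and variable names here are made up for illustration and are not part of the patches): the intermediate dimension is rounded up to a multiple of 256, weights are padded with 0.0, block scales with 1.0, and w13 is first reshaped into its two fused halves so each half is padded at its own end.

# Standalone illustration of the padding scheme (toy sizes, not vLLM code).
import torch
import torch.nn.functional as F

E, hidden, inter, blk = 2, 256, 320, 128       # inter % 256 != 0 -> needs padding
pad_inter = ((inter + 255) // 256) * 256       # 320 -> 512
inter_blks = (inter + blk - 1) // blk          # ceil(320 / 128) = 3
pad_inter_blks = (pad_inter + blk - 1) // blk  # ceil(512 / 128) = 4
hidden_blks = (hidden + blk - 1) // blk        # ceil(256 / 128) = 2

w2 = torch.randn(E, hidden, inter)                     # [E, hidden, inter]
w2_scale = torch.rand(E, hidden_blks, inter_blks)      # per-block inverse scales
w13 = torch.randn(E, 2 * inter, hidden)                # fused gate/up projection
w13_scale = torch.rand(E, 2 * inter_blks, hidden_blks)

# w2: zero-pad the intermediate (last) dim up to the multiple of 256.
w2 = F.pad(w2, (0, pad_inter - inter), value=0.0)
# w2 scales: pad the inter-block dim with 1.0 so padded blocks stay inert.
w2_scale = F.pad(w2_scale, (0, pad_inter_blks - inter_blks), value=1.0)

# w13: split the two fused halves, pad each half's inter dim, then re-fuse.
w13 = w13.view(E, 2, inter, hidden)
w13 = F.pad(w13, (0, 0, 0, pad_inter - inter), value=0.0)
w13 = w13.view(E, 2 * pad_inter, hidden)
# Same treatment for the w13 block scales.
w13_scale = w13_scale.view(E, 2, inter_blks, hidden_blks)
w13_scale = F.pad(w13_scale, (0, 0, 0, pad_inter_blks - inter_blks), value=1.0)
w13_scale = w13_scale.view(E, 2 * pad_inter_blks, hidden_blks)

print(w2.shape)         # torch.Size([2, 256, 512])
print(w2_scale.shape)   # torch.Size([2, 2, 4])
print(w13.shape)        # torch.Size([2, 1024, 256])
print(w13_scale.shape)  # torch.Size([2, 8, 2])

Padding the weights with zeros and the inverse scales with ones keeps the extra columns numerically inert, so the padded intermediate channels contribute nothing to the MoE output while the inter_dim % 256 == 0 requirement of aiter.fmoe_fp8_blockscale_g1u1 is satisfied.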