diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py index f258f8e709e..8adde876a0b 100644 --- a/tests/ut/ops/test_moe_comm_method.py +++ b/tests/ut/ops/test_moe_comm_method.py @@ -226,8 +226,8 @@ def test_fused_experts_method(self, mock_unified_apply_mlp, w2 = w2.contiguous() result = comm_impl.fused_experts(hidden_states=hidden_states, - w1=w1, - w2=w2, + w1=[w1], + w2=[w2], topk_weights=topk_weights, topk_ids=topk_ids, activation="silu") diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py index d200fa6ca00..47a99d1b10d 100644 --- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py +++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py @@ -44,11 +44,22 @@ def __init__(self, model, **args): self.init_redundancy_expert = get_ascend_config( ).init_redundancy_expert + for i in range(self.num_dense_layers, + self.model.config.num_hidden_layers): + self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_list"] = \ + self.model.model.layers[i].mlp.experts.w13_weight_list + self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_list"] = \ + self.model.model.layers[i].mlp.experts.w2_weight_list + self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_scale_fp32_list"] = \ + self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list + self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \ + self.model.model.layers[i].mlp.experts.w2_weight_scale_list # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here if self.model.quant_config is not None: self.expert_weight_names = [ - "w13_weight", "w2_weight", "w13_weight_scale", - "w13_weight_offset", "w2_weight_scale", "w2_weight_offset" + "w13_weight_list", "w2_weight_list", + "w13_weight_scale_fp32_list", "w13_weight_offset", + "w2_weight_scale_list", "w2_weight_offset" ] else: self.expert_weight_names = ["w13_weight", "w2_weight"] @@ -84,9 +95,14 @@ def init_buffer_tensor(self, num_buffer_tensor): for name in self.expert_weight_names: complete_name = "model.layers." + str( self.num_dense_layers) + ".mlp.experts." + name - expert_tensor = self.param_dict[complete_name].data[0] - if name in ["w13_weight", "w2_weight"]: + if name in [ + "w13_weight_list", "w2_weight_list", + "w13_weight_scale_fp32_list", "w2_weight_scale_list" + ]: + expert_tensor = self.param_dict[complete_name][0] expert_tensor = expert_tensor.clone() + else: + expert_tensor = self.param_dict[complete_name][0].data[0] buffer_tensor = torch.empty_like(expert_tensor) self.buffer_tensor_list[buffer_id].append(buffer_tensor) @@ -97,12 +113,23 @@ def init_expert_param_per_layer(self): layer_idx = self.num_dense_layers + moe_layer_id self.expert_param_per_layer[layer_idx] = list() for local_expert_id in range(num_local_expert): - self.expert_param_per_layer[layer_idx].append([ - self.param_dict["model.layers." + str(layer_idx) + - ".mlp.experts." + - name].data[local_expert_id] - for name in self.expert_weight_names - ]) + per_expert_param = list() + for name in self.expert_weight_names: + if name in [ + "w13_weight_list", "w2_weight_list", + "w13_weight_scale_fp32_list", + "w2_weight_scale_list" + ]: + per_expert_param.append( + self.param_dict["model.layers." + str(layer_idx) + + ".mlp.experts." + + name][local_expert_id]) + else: + per_expert_param.append( + self.param_dict["model.layers." + str(layer_idx) + + ".mlp.experts." + + name][0].data[local_expert_id]) + self.expert_param_per_layer[layer_idx].append(per_expert_param) def get_rank_expert_workload(self) -> torch.Tensor: self.moe_load = self.model.get_all_moe_loads() diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index c48ce1a49be..802dbe5d66e 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -83,8 +83,8 @@ def finalize(self, def fused_experts( self, hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, + w1: torch.Tensor | list[torch.Tensor], + w2: torch.Tensor | list[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", @@ -93,8 +93,8 @@ def fused_experts( use_int4_w4a8: bool = False, global_num_experts: Optional[int] = None, expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, + w1_scale: Optional[list[torch.Tensor]] = None, + w2_scale: Optional[list[torch.Tensor]] = None, w1_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None, # For TorchAir graph diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py index 07ba732f199..13e1efc0acd 100644 --- a/vllm_ascend/ops/fused_moe/moe_mlp.py +++ b/vllm_ascend/ops/fused_moe/moe_mlp.py @@ -23,7 +23,11 @@ from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.utils import (AscendDeviceType, dispose_tensor, - get_ascend_device_type) + enable_custom_op, get_ascend_device_type) + + +def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): + return fusion and dynamic_eplb and enable_custom_op() def cumsum_group_list(group_list: torch.Tensor, @@ -55,10 +59,10 @@ def cumsum_group_list(group_list: torch.Tensor, def quant_apply_mlp(hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, + w1: list[torch.Tensor], + w1_scale: list[torch.Tensor], + w2: list[torch.Tensor], + w2_scale: list[torch.Tensor], group_list: torch.Tensor, group_list_type: int = 1, dynamic_scale: torch.Tensor = None, @@ -79,7 +83,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor, quantized_hidden_states = hidden_states bias1, bias2 = None, None - _output_dtype = w2_scale.dtype + _output_dtype = w2_scale[0].dtype weight_prefetch_method = get_forward_context().weight_prefetch_method if weight_prefetch_method: @@ -87,23 +91,34 @@ def quant_apply_mlp(hidden_states: torch.Tensor, hidden_states) is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2 if w1_scale_bias is None and is_mc2: - if fusion and not dynamic_eplb: + if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): + # gmm1: gate_up_proj & act_fn: swiglu + hidden_states, swiglu_out_scale, _ = ( + torch.ops._C_ascend. + grouped_matmul_swiglu_quant_weight_nz_tensor_list( + x=hidden_states, + weight=w1, + weight_scale=w1_scale, + x_scale=pertoken_scale, + group_list=cumsum_group_list(group_list, group_list_type), + )) + elif fusion and not dynamic_eplb: # gmm1: gate_up_proj & act_fn: swiglu hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( x=hidden_states, - weight=w1, + weight=w1[0], group_list=cumsum_group_list(group_list, group_list_type), - weight_scale=w1_scale, + weight_scale=w1_scale[0], x_scale=pertoken_scale) if quantized_hidden_states is not None: dispose_tensor(quantized_hidden_states) else: - if w1_scale.dtype != torch.float32: - w1_scale = w1_scale.to(torch.float32) + if w1_scale[0].dtype != torch.float32: + w1_scale[0] = w1_scale[0].to(torch.float32) # gmm1: gate_up_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], - weight=[w1], + weight=w1, split_item=3, group_list_type=group_list_type, group_type=0, @@ -126,14 +141,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor, # gmm2: down_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], - weight=[w2], - scale=[w2_scale], + weight=w2, + scale=w2_scale, per_token_scale=[swiglu_out_scale], split_item=2, group_list_type=group_list_type, group_type=0, group_list=group_list, - output_dtype=w2_scale.dtype)[0] + output_dtype=w2_scale[0].dtype)[0] else: if w1_scale_bias is not None: if group_list_type == 0: @@ -146,23 +161,36 @@ def quant_apply_mlp(hidden_states: torch.Tensor, # TODO w4a8 scene: dynamic acquisition of dtype in the future _output_dtype = torch.bfloat16 - if fusion and not dynamic_eplb: + if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): + # gmm1: gate_up_proj & act_fn: swiglu + hidden_states, swiglu_out_scale, _ = ( + torch.ops._C_ascend. + grouped_matmul_swiglu_quant_weight_nz_tensor_list( + x=hidden_states, + weight=w1, + weight_scale=w1_scale, + x_scale=pertoken_scale, + group_list=cumsum_group_list(group_list, group_list_type), + bias=bias1, + )) + elif fusion and not dynamic_eplb: # gmm1: gate_up_proj & act_fn: swiglu hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( x=hidden_states, - weight=w1, + weight=w1[0], bias=bias1, group_list=cumsum_group_list(group_list, group_list_type), - weight_scale=w1_scale, + weight_scale=w1_scale[0], x_scale=pertoken_scale) if quantized_hidden_states is not None: dispose_tensor(quantized_hidden_states) else: + w1_scale[0] = w1_scale[0].to(w2_scale[0].dtype) # gmm1: gate_up_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], - weight=[w1], - scale=[w1_scale.to(w2_scale.dtype)], + weight=w1, + scale=w1_scale, bias=bias1, per_token_scale=[pertoken_scale], split_item=2, @@ -179,8 +207,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor, # gmm2: down_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], - weight=[w2], - scale=[w2_scale], + weight=w2, + scale=w2_scale, bias=bias2, per_token_scale=[swiglu_out_scale], split_item=2, @@ -232,11 +260,11 @@ def unquant_apply_mlp(hidden_states: torch.Tensor, def unified_apply_mlp(hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, + w1: torch.Tensor | list[torch.Tensor], + w2: torch.Tensor | list[torch.Tensor], group_list: torch.Tensor, + w1_scale: Optional[list[torch.Tensor]] = None, + w2_scale: Optional[list[torch.Tensor]] = None, dynamic_scale: torch.Tensor = None, group_list_type: int = 1, w1_scale_bias: torch.Tensor = None, @@ -247,6 +275,7 @@ def unified_apply_mlp(hidden_states: torch.Tensor, need_trans: bool = True, dynamic_eplb: bool = False) -> torch.Tensor: if with_quant: + assert w1_scale is not None and w2_scale is not None return quant_apply_mlp(hidden_states=hidden_states, w1=w1, w1_scale=w1_scale, diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py index c7f1dfabb86..a73050c3123 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -379,10 +379,10 @@ def apply( moe_comm_method = get_forward_context().moe_comm_method return moe_comm_method.fused_experts( hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, + w1=[layer.w13_weight], + w2=[layer.w2_weight], + w1_scale=[layer.w13_weight_scale], + w2_scale=[layer.w2_weight_scale], w1_scale_bias=layer.w13_scale_bias, w2_scale_bias=layer.w2_scale_bias, topk_weights=topk_weights, diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 6b7d6b0875c..386c856f9ef 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -234,13 +234,24 @@ def apply( topk_weights = topk_weights.to(self.in_dtype) moe_comm_method = get_forward_context().moe_comm_method + if self.dynamic_eplb: + w1 = layer.w13_weight_list + w1_scale = layer.w13_weight_scale_fp32_list + w2 = layer.w2_weight_list + w2_scale = layer.w2_weight_scale_list + else: + w1 = [layer.w13_weight] + w1_scale = [layer.w13_weight_scale_fp32] + w2 = [layer.w2_weight] + w2_scale = [layer.w2_weight_scale] + return moe_comm_method.fused_experts( hidden_states=x, pertoken_scale=pertoken_scale, - w1=layer.w13_weight, - w1_scale=layer.w13_weight_scale_fp32, - w2=layer.w2_weight, - w2_scale=layer.w2_weight_scale, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, topk_weights=topk_weights, topk_ids=topk_ids, use_int8_w8a8=True, @@ -272,3 +283,25 @@ def process_weights_after_loading(self, layer): layer.w2_weight_scale.data.shape[0], -1) layer.w2_weight_offset.data = layer.w2_weight_offset.data.view( layer.w2_weight_offset.data.shape[0], -1) + if self.dynamic_eplb: + layer.w13_weight_list = [ + weight.clone() + for weight in layer.w13_weight.data.unbind(dim=0) + ] + layer.w2_weight_list = [ + weight.clone() for weight in layer.w2_weight.data.unbind(dim=0) + ] + layer.w13_weight_scale_fp32_list = [ + weight.clone() + for weight in layer.w13_weight_scale.data.unbind(dim=0) + ] + layer.w2_weight_scale_list = [ + weight.clone() + for weight in layer.w2_weight_scale.data.unbind(dim=0) + ] + del layer.w13_weight + del layer.w2_weight + del layer.w13_weight_scale + del layer.w13_weight_scale_fp32 + del layer.w2_weight_scale + torch.npu.empty_cache()