diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py
index ae9dd46fe2e..99de0bf8ce2 100644
--- a/vllm_ascend/quantization/w8a8.py
+++ b/vllm_ascend/quantization/w8a8.py
@@ -110,5 +110,6 @@ def process_weights_after_loading(self, layer):
             requires_grad=False)
         if self.transpose_weight:
             layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 5d2b442cf18..a5635d28f45 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -29,6 +29,82 @@
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
 
+def apply_mlp_decode(hidden_states_wrapper: List[torch.Tensor],
+                     w1: torch.Tensor,
+                     w1_scale: torch.Tensor,
+                     w2: torch.Tensor,
+                     w2_scale: torch.Tensor,
+                     group_list: torch.Tensor,
+                     dynamic_scale: torch.Tensor = None,
+                     group_list_type: int = 1) -> torch.Tensor:
+    """
+    apply MLP: gate_up_proj -> swiglu -> down_proj
+
+    Args:
+        hidden_states_wrapper: wrapper of input hidden states with shape
+            (num_tokens, hidden_size).
+        w1: expert weights1 with shape
+            (num_experts, hidden_size, intermediate_size * 2)
+        w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
+        w2: expert weights2 with shape
+            (num_experts, intermediate_size, hidden_size)
+        w2_scale: weights2 scale with shape (num_experts, hidden_size)
+        group_list: number of tokens for each expert, follow cumsum mode, and
+            with shape (num_experts).
+        transpose_weight:
+            w1: (num_experts, intermediate_size * 2, hidden_size) ->
+                (num_experts, hidden_size, intermediate_size * 2)
+            w2: (num_experts, hidden_size, intermediate_size) ->
+                (num_experts, intermediate_size, hidden_size)
+
+    Returns:
+        hidden_states: output hidden states after MLP.
+    """
+
+    assert len(hidden_states_wrapper) == 1
+    hidden_states = hidden_states_wrapper.pop()
+    if dynamic_scale is None:
+        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
+            hidden_states)
+    else:
+        pertoken_scale = dynamic_scale
+
+    # gmm1: gate_up_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w1],
+        split_item=3,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=torch.int32)[0]
+
+    # act_fn: swiglu
+    hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+        x=hidden_states,
+        weight_scale=w1_scale,
+        activation_scale=pertoken_scale,
+        bias=None,
+        quant_scale=None,
+        quant_offset=None,
+        group_index=group_list,
+        activate_left=True,
+        quant_mode=1,
+    )
+
+    # gmm2: down_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w2],
+        scale=[w2_scale],
+        per_token_scale=[swiglu_out_scale],
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=w2_scale.dtype)[0]
+    return hidden_states
+
+
 def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
               w1: torch.Tensor,
               w1_scale: torch.Tensor,
@@ -159,13 +235,13 @@ def fused_experts_with_mc2(
     hidden_states_wrapper = [expand_x]
     del expand_x
 
-    down_out_list = apply_mlp(hidden_states_wrapper,
-                              w1,
-                              w1_scale,
-                              w2,
-                              w2_scale,
-                              expert_token_nums,
-                              dynamic_scale=dynamic_scale)
+    down_out_list = apply_mlp_decode(hidden_states_wrapper,
+                                     w1,
+                                     w1_scale,
+                                     w2,
+                                     w2_scale,
+                                     expert_token_nums,
+                                     dynamic_scale=dynamic_scale)
 
     # moeCombine
     kwargs = {
@@ -628,7 +704,7 @@ def apply(
             hidden_states=x,
             w1=layer.w13_weight,
             w2=layer.w2_weight,
-            w1_scale=layer.w13_weight_scale,
+            w1_scale=layer.w13_weight_scale_fp32,
             w2_scale=layer.w2_weight_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
@@ -665,6 +741,8 @@ def process_weights_after_loading(self, layer):
                 1, 2).contiguous()
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
             layer.w13_weight_scale.data.shape[0], -1)
+        layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
+            torch.float32)
         layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
             layer.w13_weight_offset.data.shape[0], -1)
         layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(