Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm_ascend/quantization/w8a8.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,6 @@ def process_weights_after_loading(self, layer):
requires_grad=False)
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
94 changes: 86 additions & 8 deletions vllm_ascend/quantization/w8a8_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,82 @@
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2


def apply_mlp_decode(hidden_states_wrapper: List[torch.Tensor],
                     w1: torch.Tensor,
                     w1_scale: torch.Tensor,
                     w2: torch.Tensor,
                     w2_scale: torch.Tensor,
                     group_list: torch.Tensor,
                     dynamic_scale: torch.Tensor = None,
                     group_list_type: int = 1) -> torch.Tensor:
    """
    apply MLP: gate_up_proj -> swiglu -> down_proj

    Args:
        hidden_states_wrapper: single-element list wrapping the input hidden
            states with shape (num_tokens, hidden_size). The tensor is popped
            from the list, so the caller's list is empty afterwards.
        w1: expert weights1 with shape
            (num_experts, hidden_size, intermediate_size * 2)
        w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
        w2: expert weights2 with shape
            (num_experts, intermediate_size, hidden_size)
        w2_scale: weights2 scale with shape (num_experts, hidden_size)
        group_list: per-expert token grouping with shape (num_experts). How
            its values are interpreted is selected by group_list_type.
            NOTE(review): the earlier docstring said "cumsum mode" while the
            default group_list_type is 1 -- confirm which encoding the
            callers pass against the torch_npu documentation.
        dynamic_scale: optional per-token scale. When None, hidden_states is
            quantized here with npu_dynamic_quant and the scale it returns is
            used; otherwise hidden_states is used as-is (presumably already
            quantized by the caller) together with this scale.
        group_list_type: forwarded unchanged to both npu_grouped_matmul
            calls.

    Note:
        w1 and w2 are expected in the transposed layout listed above:
            w1: (num_experts, intermediate_size * 2, hidden_size) ->
                (num_experts, hidden_size, intermediate_size * 2)
            w2: (num_experts, hidden_size, intermediate_size) ->
                (num_experts, intermediate_size, hidden_size)

    Returns:
        hidden_states: output hidden states after MLP.
    """

    # The wrapper must hold exactly one tensor; pop() removes the wrapper's
    # reference to it.
    assert len(hidden_states_wrapper) == 1
    hidden_states = hidden_states_wrapper.pop()
    if dynamic_scale is None:
        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
            hidden_states)
    else:
        pertoken_scale = dynamic_scale

    # gmm1: gate_up_proj
    # No scale is passed here: the int32 result is dequantized by the fused
    # swiglu operator below.
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w1],
        split_item=3,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
        output_dtype=torch.int32)[0]

    # act_fn: swiglu
    # NOTE(review): a reviewer reported that this operator can make the
    # process hang when graph mode is enabled (PD separation scenario, decode
    # node, TP2 DP16), either during compilation or execution; the cause is
    # unknown. The author reported that it works with RC1 CANN & PTA.
    hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
        x=hidden_states,
        weight_scale=w1_scale,
        activation_scale=pertoken_scale,
        bias=None,
        quant_scale=None,
        quant_offset=None,
        group_index=group_list,
        activate_left=True,
        quant_mode=1,
    )

    # gmm2: down_proj
    # The per-token scale produced by the swiglu step is applied here, and
    # the output dtype follows w2_scale.
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w2],
        scale=[w2_scale],
        per_token_scale=[swiglu_out_scale],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
        output_dtype=w2_scale.dtype)[0]
    return hidden_states


def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
w1: torch.Tensor,
w1_scale: torch.Tensor,
Expand Down Expand Up @@ -159,13 +235,13 @@ def fused_experts_with_mc2(
hidden_states_wrapper = [expand_x]
del expand_x

down_out_list = apply_mlp(hidden_states_wrapper,
w1,
w1_scale,
w2,
w2_scale,
expert_token_nums,
dynamic_scale=dynamic_scale)
down_out_list = apply_mlp_decode(hidden_states_wrapper,
w1,
w1_scale,
w2,
w2_scale,
expert_token_nums,
dynamic_scale=dynamic_scale)

# moeCombine
kwargs = {
Expand Down Expand Up @@ -628,7 +704,7 @@ def apply(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale,
w1_scale=layer.w13_weight_scale_fp32,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
Expand Down Expand Up @@ -665,6 +741,8 @@ def process_weights_after_loading(self, layer):
1, 2).contiguous()
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
layer.w13_weight_scale.data.shape[0], -1)
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
torch.float32)
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
layer.w13_weight_offset.data.shape[0], -1)
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
Expand Down