Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions vllm_ascend/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)

if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)

if (self.tp_size > 1 and self.enable_mc2 and not is_prefill):
chunks = torch.chunk(hidden_states,
get_tp_group().world_size,
Expand All @@ -248,8 +245,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
else:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)

if shared_output is not None:
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
final_hidden_states = final_hidden_states + shared_output

return final_hidden_states.view(num_tokens, hidden_dim)
Expand Down
86 changes: 49 additions & 37 deletions vllm_ascend/quantization/w8a8_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#

import os
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional

import torch
import torch_npu
Expand All @@ -25,7 +25,7 @@
from vllm_ascend.ops.fused_moe import select_experts


def apply_mlp(x: torch.Tensor,
def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
Expand All @@ -37,7 +37,7 @@ def apply_mlp(x: torch.Tensor,
apply MLP: gate_up_proj -> swiglu -> down_proj

Args:
x: input hidden states with shape (num_tokens, hidden_size).
hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
w1: expert weights1 with shape
(num_experts, hidden_size, intermediate_size * 2)
w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
Expand All @@ -56,23 +56,27 @@ def apply_mlp(x: torch.Tensor,
hidden_states: output hidden states after MLP.
"""

assert len(hidden_states_wrapper) == 1
hidden_states = hidden_states_wrapper.pop()
if dynamic_scale is None:
h, pertoken_scale = torch_npu.npu_dynamic_quant(x)
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
hidden_states)
else:
h = x
pertoken_scale = dynamic_scale

# gmm1: gate_up_proj
gate_up_out = torch_npu.npu_grouped_matmul(x=[h],
weight=[w1],
split_item=3,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32)[0]

swiglu_out, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=gate_up_out,
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=3,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32)[0]

# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale,
activation_scale=pertoken_scale,
bias=None,
Expand All @@ -83,17 +87,18 @@ def apply_mlp(x: torch.Tensor,
quant_mode=1,
)

# down_proj
down_out = torch_npu.npu_grouped_matmul(x=[swiglu_out],
weight=[w2],
scale=[w2_scale],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=w2_scale.dtype)[0]
return down_out
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
scale=[w2_scale],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=w2_scale.dtype)[0]
return hidden_states


def fused_experts_with_mc2(
Expand Down Expand Up @@ -152,7 +157,11 @@ def fused_experts_with_mc2(
if quant_mode == 0:
dynamic_scale = None

down_out_list = apply_mlp(expand_x,
# place hidden_states in a list to transfer its ownership into the `apply_mlp` function
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with this quick fix. you have describe the case more clear later. For example, this copy to list action aims to deal with the lifecycle for the tensor so that the memory usage can be improved.

hidden_states_wrapper = [expand_x]
del expand_x
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There still a lot of del action everywhere. We can create a new func like release_tenstor to make the del aciton more clear for developer


down_out_list = apply_mlp(hidden_states_wrapper,
w1,
w1_scale,
w2,
Expand Down Expand Up @@ -246,7 +255,7 @@ def fused_experts(hidden_states: torch.Tensor,
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
expert_tokens = token_counts[:num_experts]
# Rearrange hidden_states
sorted_hidden_states = hidden_states[sorted_token_indices]
hidden_states = hidden_states[sorted_token_indices]
group_list_type = 1
else:
row_idx_len = num_tokens * top_k
Expand All @@ -255,19 +264,22 @@ def fused_experts(hidden_states: torch.Tensor,
dtype=torch.int32,
device=topk_weights.device).view(
top_k, -1).permute(1, 0).contiguous()
sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)
del hidden_states

expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
expanded_expert_idx, num_experts)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 0

down_out_list = apply_mlp(sorted_hidden_states,
# place hidden_states in a list to transfer its ownership into the `apply_mlp` function
hidden_states_wrapper = [hidden_states]
del hidden_states

hidden_states = apply_mlp(hidden_states_wrapper,
w1,
w1_scale,
w2,
Expand All @@ -276,31 +288,31 @@ def fused_experts(hidden_states: torch.Tensor,
group_list_type=group_list_type)

if expert_map is not None:
down_out_list.mul_(sorted_weights.unsqueeze(1))
hidden_states.mul_(sorted_weights.unsqueeze(1))
final_hidden_states = torch.zeros(*original_shape,
device=hidden_states.device,
device=device,
dtype=dtype)

num_valid_tokens = mask.sum()
valid_token_mask = torch.arange(
0, sorted_token_indices.shape[0],
device=device).unsqueeze(1) < num_valid_tokens
down_out_list = down_out_list.masked_fill_(~valid_token_mask,
hidden_states = hidden_states.masked_fill_(~valid_token_mask,
0).to(dtype)
final_hidden_states.index_add_(0, sorted_token_indices, down_out_list)
final_hidden_states.index_add_(0, sorted_token_indices, hidden_states)
else:
# TODO: Reorder device memory 2 times here, replace the current
# implementation here when suitable operators become available.
final_hidden_states = torch_npu.npu_moe_finalize_routing(
down_out_list,
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
del down_out_list

if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)
return final_hidden_states
Expand Down