Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions vllm_ascend/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor

VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2

Expand Down Expand Up @@ -518,8 +519,14 @@ def forward(
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
previous_hidden_states, previous_residual = hidden_states, residual
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
# Dispose hidden_states and residual from the previous layer
# to save NPU memory because they're no longer used.
dispose_tensor(previous_hidden_states)
dispose_tensor(previous_residual)

hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
Expand Down
33 changes: 14 additions & 19 deletions vllm_ascend/quantization/w8a8_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
#

from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, Optional

import torch
import torch.distributed as dist
Expand All @@ -25,11 +25,12 @@
import vllm_ascend.envs as envs_ascend
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import select_experts
from vllm_ascend.utils import dispose_tensor

VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2


def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
def apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
Expand All @@ -41,7 +42,7 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
apply MLP: gate_up_proj -> swiglu -> down_proj

Args:
hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
hidden_states: input hidden states with shape (num_tokens, hidden_size).
w1: expert weights1 with shape
(num_experts, hidden_size, intermediate_size * 2)
w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
Expand All @@ -60,11 +61,13 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
hidden_states: output hidden states after MLP.
"""

assert len(hidden_states_wrapper) == 1
hidden_states = hidden_states_wrapper.pop()
if dynamic_scale is None:
unquantized_hidden_states = hidden_states
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
hidden_states)
# Dispose the original unquantized hidden states
# to save NPU memory because they're no longer used.
dispose_tensor(unquantized_hidden_states)
else:
pertoken_scale = dynamic_scale

Expand Down Expand Up @@ -155,11 +158,8 @@ def fused_experts_with_mc2(
if quant_mode == 0:
dynamic_scale = None

# place hidden_states in a list to transfer its ownership into the `apply_mlp` function
hidden_states_wrapper = [expand_x]
del expand_x

down_out_list = apply_mlp(hidden_states_wrapper,
# `expand_x` is disposed inside `apply_mlp` after it is quantized there (only when `dynamic_scale` is None)
down_out_list = apply_mlp(expand_x,
w1,
w1_scale,
w2,
Expand Down Expand Up @@ -281,10 +281,8 @@ def fused_experts_with_all2all(
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 0

hidden_states_wrapper = [hidden_states]
del hidden_states

hidden_states = apply_mlp(hidden_states_wrapper,
# `hidden_states` will be disposed in the `apply_mlp` function
hidden_states = apply_mlp(hidden_states,
w1,
w1_scale,
w2,
Expand Down Expand Up @@ -399,11 +397,8 @@ def fused_experts(hidden_states: torch.Tensor,
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 0

# place hidden_states in a list to transfer its ownership into the `apply_mlp` function
hidden_states_wrapper = [hidden_states]
del hidden_states

hidden_states = apply_mlp(hidden_states_wrapper,
# `hidden_states` will be disposed in the `apply_mlp` function
hidden_states = apply_mlp(hidden_states,
w1,
w1_scale,
w2,
Expand Down
4 changes: 4 additions & 0 deletions vllm_ascend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
"No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
vllm_config.model_config.architectures[0], num_hidden_layers,
len(original_sizes))


def dispose_tensor(x: torch.Tensor) -> None:
    """Shrink ``x`` in place to an empty tensor so its old data can be freed.

    ``x`` is re-pointed at a freshly created zero-element tensor that keeps
    the original device and dtype. Every reference to ``x`` therefore sees a
    tensor of shape ``(0,)`` afterwards. The previous storage is released
    only once nothing else (for example a view of ``x``) still refers to it.

    Args:
        x: Tensor whose contents are no longer needed. It is modified in
            place and must not be read again after this call.
    """
    # Build the placeholder on the same device/dtype so `set_` only swaps
    # the underlying storage and does not change the tensor's type.
    placeholder = torch.empty((0, ), device=x.device, dtype=x.dtype)
    x.set_(placeholder)
9 changes: 5 additions & 4 deletions vllm_ascend/worker/model_runner_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,11 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
device="cpu",
pin_memory=True)

self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=self.device)
if self.is_multimodal_model:
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=self.device)

# OPTIMIZATION: Cache the tensors rather than creating them every step.
self.arange_np: npt.NDArray[np.int32] = np.arange(max(
Expand Down