
Commit e3c7f71

[Perf] Refactor tensor disposal logic to reduce memory usage (#966)
### What this PR does / why we need it?

1. In previous PRs #580 and #784, I saved NPU memory by promptly deleting tensors that were no longer needed. For tensors passed in from upper-layer functions, I used a list container to transfer ownership of the parameter, then popped the tensor from the list inside the inner function so it could be deleted. Recently, I discovered a better implementation in sglang, the `dispose_tensor` function, and I recommend adopting this approach.
2. Dispose of `hidden_states` and `residual` from the previous layer once they are no longer used.
3. Avoid allocating `self.inputs_embeds` in `ModelRunnerV1` in non-multimodal scenarios.

With these optimizations, running the DeepSeek-R1-W8A8 model with `TP=16` and `max-model-len=32768` saves 1.3 GB of NPU memory.

**Reference**: sgl-project/sglang#6147

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

---------

Signed-off-by: ApsarasX <[email protected]>
1 parent 6eddbd2 commit e3c7f71
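To make the refactor concrete, here is a minimal runnable sketch of the two ownership patterns the description contrasts. `apply_old`, `apply_new`, and the `* 2` compute are illustrative placeholders; only `dispose_tensor` matches the actual helper added in `vllm_ascend/utils.py` below.

```python
import torch

def dispose_tensor(x: torch.Tensor):
    """Point x at empty storage so the original backing memory can be
    freed even while other names still reference the tensor object."""
    x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))

# Old pattern (#580 / #784): hand over the only reference inside a list
# so the callee can pop it and delete the last reference itself.
def apply_old(wrapper: list) -> torch.Tensor:
    hidden_states = wrapper.pop()   # take the reference out of the list
    out = hidden_states * 2         # stand-in for the real computation
    del hidden_states               # last reference gone -> memory freed
    return out

# New pattern (from sgl-project/sglang#6147): pass the tensor directly
# and empty its storage in place once it has been consumed.
def apply_new(hidden_states: torch.Tensor) -> torch.Tensor:
    out = hidden_states * 2         # stand-in for the real computation
    dispose_tensor(hidden_states)   # frees storage even if the caller kept a name bound
    return out

x = torch.randn(4, 8)
wrapper = [x]
del x                               # the caller had to drop its own name too
out1 = apply_old(wrapper)

y = torch.randn(4, 8)
out2 = apply_new(y)                 # y now has zero elements
```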

File tree

- vllm_ascend/models/deepseek_v2.py
- vllm_ascend/quantization/w8a8_dynamic.py
- vllm_ascend/utils.py
- vllm_ascend/worker/model_runner_v1.py

4 files changed: +30, -23 lines

vllm_ascend/models/deepseek_v2.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -68,6 +68,7 @@
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
+from vllm_ascend.utils import dispose_tensor
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
@@ -518,8 +519,14 @@ def forward(
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
         else:
+            previous_hidden_states, previous_residual = hidden_states, residual
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
+            # Dispose hidden_states and residual from the previous layer
+            # to save npu memory because they're no longer used.
+            dispose_tensor(previous_hidden_states)
+            dispose_tensor(previous_residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
```
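Note: the disposal above is safe on the assumption, implicit in this diff, that the fused `input_layernorm` returns fresh output tensors rather than views of its inputs. Because `dispose_tensor` empties the tensor's storage in place, the previous layer's buffers are released even if other names elsewhere are still bound to those tensor objects.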

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 14 additions & 19 deletions
```diff
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, Optional
 
 import torch
 import torch.distributed as dist
@@ -25,11 +25,12 @@
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import select_experts
+from vllm_ascend.utils import dispose_tensor
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
 
-def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
+def apply_mlp(hidden_states: torch.Tensor,
               w1: torch.Tensor,
               w1_scale: torch.Tensor,
               w2: torch.Tensor,
@@ -41,7 +42,7 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
     apply MLP: gate_up_proj -> swiglu -> down_proj
 
     Args:
-        hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
+        hidden_states: input hidden states with shape (num_tokens, hidden_size).
         w1: expert weights1 with shape
             (num_experts, hidden_size, intermediate_size * 2)
         w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
@@ -60,11 +61,13 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
         hidden_states: output hidden states after MLP.
     """
 
-    assert len(hidden_states_wrapper) == 1
-    hidden_states = hidden_states_wrapper.pop()
     if dynamic_scale is None:
+        unquantized_hidden_states = hidden_states
         hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
             hidden_states)
+        # Dispose the original unquantized hidden states
+        # to save npu memory because they're no longer used.
+        dispose_tensor(unquantized_hidden_states)
     else:
         pertoken_scale = dynamic_scale
 
@@ -155,11 +158,8 @@ def fused_experts_with_mc2(
     if quant_mode == 0:
         dynamic_scale = None
 
-    # place hidden_states in a list to transfer its ownership into the `apply_mlp` function
-    hidden_states_wrapper = [expand_x]
-    del expand_x
-
-    down_out_list = apply_mlp(hidden_states_wrapper,
+    # `expand_x` will be disposed in the `apply_mlp` function
+    down_out_list = apply_mlp(expand_x,
                               w1,
                               w1_scale,
                               w2,
@@ -281,10 +281,8 @@ def fused_experts_with_all2all(
     expert_tokens = expert_tokens.to(torch.int64)
     group_list_type = 0
 
-    hidden_states_wrapper = [hidden_states]
-    del hidden_states
-
-    hidden_states = apply_mlp(hidden_states_wrapper,
+    # `hidden_states` will be disposed in the `apply_mlp` function
+    hidden_states = apply_mlp(hidden_states,
                               w1,
                               w1_scale,
                               w2,
@@ -399,11 +397,8 @@ def fused_experts(hidden_states: torch.Tensor,
     expert_tokens = expert_tokens.to(torch.int64)
     group_list_type = 0
 
-    # place hidden_states in a list to transfer its ownership into the `apply_mlp` function
-    hidden_states_wrapper = [hidden_states]
-    del hidden_states
-
-    hidden_states = apply_mlp(hidden_states_wrapper,
+    # `hidden_states` will be disposed in the `apply_mlp` function
+    hidden_states = apply_mlp(hidden_states,
                               w1,
                               w1_scale,
                               w2,
```

vllm_ascend/utils.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -169,3 +169,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
         "No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
         vllm_config.model_config.architectures[0], num_hidden_layers,
         len(original_sizes))
+
+
+def dispose_tensor(x: torch.Tensor):
+    x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))
```
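A quick illustration of why the helper uses `Tensor.set_` instead of plain `del` (a sketch, runnable on CPU; the sizes are arbitrary):

```python
import torch

def dispose_tensor(x: torch.Tensor):
    x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))

t = torch.randn(1024, 7168)   # ~28 MB of float32
alias = t                     # e.g. a caller still holding the same tensor object
dispose_tensor(t)
print(t.shape, alias.shape)   # torch.Size([0]) torch.Size([0])
# Both names now point at the emptied tensor, so the ~28 MB storage can be
# reclaimed at once; `del t` alone would free nothing while `alias` is alive.
# Caveat: views created earlier (e.g. t.view(-1)) would keep the storage alive.
```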

vllm_ascend/worker/model_runner_v1.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -240,10 +240,11 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             device="cpu",
             pin_memory=True)
 
-        self.inputs_embeds = torch.zeros(
-            (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
-            device=self.device)
+        if self.is_multimodal_model:
+            self.inputs_embeds = torch.zeros(
+                (self.max_num_tokens, self.hidden_size),
+                dtype=self.dtype,
+                device=self.device)
 
         # OPTIMIZATION: Cache the tensors rather than creating them every step.
         self.arange_np: npt.NDArray[np.int32] = np.arange(max(
```
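For a rough sense of what the `inputs_embeds` change alone avoids preallocating, a back-of-the-envelope estimate; the token count and hidden size here are assumptions for illustration, not values measured in this PR:

```python
# Hypothetical configuration: bf16 activations with a DeepSeek-R1-like hidden size.
max_num_tokens, hidden_size, bytes_per_elem = 32768, 7168, 2
print(max_num_tokens * hidden_size * bytes_per_elem / 2**20)  # 448.0 MiB per rank
```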
