diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 80a481c39b8..848f1aa2545 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -437,8 +437,7 @@ def full_graph_pa(
             block_table=attn_metadata.block_tables,
             context_lens=attn_metadata.seq_lens,
             out=output)
-        update_graph_params_workspaces(num_tokens,
-                                       weak_ref_tensors(workspace))
+        update_graph_params_workspaces(num_tokens, workspace)
 
     # Handle graph capturing mode
     stream = torch_npu.npu.current_stream()
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index 3f28a3a3632..72cf925d568 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -162,6 +162,13 @@ def __call__(self, *args, **kwargs):
             # any other acl graph.
             output = weak_ref_tensors(output)
 
+            # here we always use weak ref for the workspaces
+            # to save memory
+            global _graph_params
+            global _draft_graph_params
+            weak_ref_workspaces(_graph_params)
+            weak_ref_workspaces(_draft_graph_params)
+
             # here we always use weak ref for the output
             # to save memory
             entry.output = weak_ref_tensors(output)
@@ -195,6 +202,16 @@ def __call__(self, *args, **kwargs):
         return entry.output
 
 
+def weak_ref_workspaces(params):
+    if params is None:
+        return
+    for num_tokens in params.workspaces:
+        if params.workspaces[num_tokens] is None:
+            continue
+        params.workspaces[num_tokens] = weak_ref_tensors(
+            params.workspaces[num_tokens])
+
+
 def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
@@ -523,7 +540,7 @@ def set_graph_params(aclgraph_capture_sizes: list[int]):
 def update_graph_params_workspaces(num_tokens: int, workspace: torch.Tensor):
     global _graph_params
     if _graph_params is not None:
-        _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
+        _graph_params.workspaces[num_tokens] = workspace
 
 
 def get_graph_params():