Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions vllm_ascend/attention/attention_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,7 @@ def full_graph_pa(
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
update_graph_params_workspaces(num_tokens,
weak_ref_tensors(workspace))
update_graph_params_workspaces(num_tokens, workspace)

# Handle graph capturing mode
stream = torch_npu.npu.current_stream()
Expand Down
19 changes: 18 additions & 1 deletion vllm_ascend/compilation/acl_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@ def __call__(self, *args, **kwargs):
# any other acl graph.
output = weak_ref_tensors(output)

# here we always use weak ref for the workspaces
# to save memory
global _graph_params
global _draft_graph_params
weak_ref_workspaces(_graph_params)
weak_ref_workspaces(_draft_graph_params)

# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
Expand Down Expand Up @@ -195,6 +202,16 @@ def __call__(self, *args, **kwargs):
return entry.output


def weak_ref_workspaces(params):
    """Downgrade every cached workspace tensor in *params* to a weak
    reference (via ``weak_ref_tensors``) so the underlying storage can be
    reclaimed; silently does nothing when *params* is ``None``.
    """
    if params is None:
        return
    workspaces = params.workspaces
    # Values are rewritten in place; the key set never changes, so
    # iterating .items() while assigning is safe.
    for num_tokens, workspace in workspaces.items():
        if workspace is not None:
            workspaces[num_tokens] = weak_ref_tensors(workspace)


def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
Expand Down Expand Up @@ -523,7 +540,7 @@ def set_graph_params(aclgraph_capture_sizes: list[int]):
def update_graph_params_workspaces(num_tokens: int, workspace: torch.Tensor) -> None:
    """Record *workspace* as the cached workspace for *num_tokens* in the
    module-global graph params; no-op if graph params are not initialized.
    """
    global _graph_params
    if _graph_params is None:
        # Graph capture has not been set up yet — nothing to record.
        return
    _graph_params.workspaces[num_tokens] = workspace


def get_graph_params():
Expand Down
Loading