Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 5 additions & 21 deletions vllm_ascend/compilation/acl_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,8 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens

# When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
# mode with GQA. This is triggered by getting workspace for _npu_paged_attention
# in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens
# might encounter a bigger workspace, while currently we use max_model_len to
# calculate max workspace in capturing. So additional get_workspace is added
# here to avoid such bugs.
# TODO(Angazenn): we will remove this once _npu_paged_attention is fully
# replaced by npu_fused_infer_attention_score which does not contain such bugs.
workspace = torch_npu._npu_paged_attention_get_workspace(
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
Expand All @@ -253,18 +246,9 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output)
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
workspace=workspace)
out=output,
workspace=graph_params.workspaces.get(runtime_shape),
)
torch.npu.graph_task_update_end(update_stream)

event.record(update_stream)
Expand Down
15 changes: 13 additions & 2 deletions vllm_ascend/worker/model_runner_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
using_paged_attention)
# yapf conflicts with isort for this block
# yapf: disable
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
Expand Down Expand Up @@ -127,6 +128,8 @@
if get_ascend_device_type() == AscendDeviceType._310P:
torch_npu.npu.set_compile_mode(jit_compile=False)

SEQ_LEN_WITH_MAX_PA_WORKSPACE = 6144


@dataclass
class GraphCaptureContext:
Expand Down Expand Up @@ -1926,6 +1929,7 @@ def _build_dummy_attn_metadata(
num_scheduled_tokens: np.ndarray,
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
force_attention: bool = False,
is_graph_capturing: bool = False,
) -> Optional[dict[str, Any]]:
attn_metadata: Optional[dict[str, Any]] = None

Expand All @@ -1935,7 +1939,13 @@ def _build_dummy_attn_metadata(

attn_metadata = {}

seq_lens = max_query_len
# The reason we use a fixed seq_len rather than max_query_len is that
# _npu_paged_attention_get_workspace only returns the maximum workspace for
# specific seq_lens. This seq_len is used only while capturing the graph;
# max_query_len is still used during inference. This will be removed once
# npu_fused_infer_attention_score outperforms _npu_paged_attention in all cases.
seq_lens = SEQ_LEN_WITH_MAX_PA_WORKSPACE if is_graph_capturing and using_paged_attention(
num_tokens, self.vllm_config) else max_query_len
self.seq_lens.np[:num_reqs] = seq_lens
self.seq_lens.np[num_reqs:] = 0
self.seq_lens.copy_to_gpu()
Expand Down Expand Up @@ -2186,6 +2196,7 @@ def _dummy_run(
max_query_len=max_query_len,
aclgraph_runtime_mode=cudagraph_runtime_mode,
force_attention=force_attention,
is_graph_capturing=is_graph_capturing,
num_scheduled_tokens=num_scheduled_tokens,
)

Expand Down