Commit e5d98f5

[Refactor](attention) Simplify forward and use configurable block size
Signed-off-by: Yizhou Liu <[email protected]>
1 parent 78c2eee commit e5d98f5
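
The refactor shown in the diff below boils down to two moves: the keyword arguments that the workspace query and both kernel invocations used to repeat are collected once into an attn_args dict and unpacked with **, and the hardcoded block size of 128 is replaced by self.block_size read from the cache config. The snippet that follows is a minimal, framework-free sketch of that pattern under assumed names (get_max_workspace and run_kernel are hypothetical stand-ins, not torch_npu APIs), not the vllm-ascend code itself:

# Sketch of the "build kwargs once, unpack everywhere" pattern.
# get_max_workspace / run_kernel are hypothetical placeholders that only
# mirror the call structure of the torch_npu calls in the diff.

def get_max_workspace(**kwargs) -> int:
    # Pretend the workspace size depends only on the sequence lengths.
    return sum(kwargs["actual_seq_lengths_kv"])

def run_kernel(*, workspace: int, out: list, **kwargs) -> None:
    # Pretend kernel: record which arguments it was invoked with.
    out.append((workspace, sorted(kwargs)))

# Shared arguments are assembled exactly once ...
attn_args = {
    "actual_seq_lengths_kv": [4, 8],
    "num_heads": 8,
    "scale": 0.125,
    "input_layout": "BSH",
    "num_key_value_heads": 1,
    "block_size": 64,  # configurable, no longer hardcoded to 128
}

# ... and reused for both the workspace query and the kernel call,
# so the eager and graph-capture paths cannot drift apart.
workspace = get_max_workspace(**attn_args)
results: list = []
run_kernel(workspace=workspace, out=results, **attn_args)
print(results)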

File tree

1 file changed: +44 -116 lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 44 additions & 116 deletions
@@ -256,6 +256,7 @@ def __init__(
 
         vllm_config = get_current_vllm_config()
         self.full_graph = vllm_config.compilation_config.full_cuda_graph
+        self.block_size = vllm_config.cache_config.block_size
 
     def forward(
         self,
@@ -268,21 +269,7 @@ def forward(
         output: Optional[torch.Tensor] = None,
         trace_flag: bool = True,
     ) -> torch.Tensor:
-        """Forward pass with Ascend attention.
-        Args:
-            query: shape = [batch_size, seq_len, num_heads * head_size]
-            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            kv_cache: shape = [2, num_blocks, block_size,
-                               num_kv_heads, head_size]
-                key_cache = [num_blocks, block_size,
-                             num_kv_heads, head_size]
-                value_cache = [num_blocks, block_size,
-                               num_kv_heads, head_size]
-            attn_metadata: Metadata for attention.
-        Returns:
-            shape = [batch_size * seq_len, num_heads, head_size]
-        """
+        """Forward pass with Ascend attention."""
         num_tokens = query.shape[0]
         if output is None:
             output = torch.empty(num_tokens,
@@ -365,136 +352,77 @@ def forward(
             if self.full_graph:
                 graph_params = get_graph_params()
                 q = query.view(num_tokens, -1, self.hidden_size)
-                k = self.key_cache.view(-1, 128,
+                k = self.key_cache.view(-1, self.block_size,
                                         self.num_kv_heads * self.head_size)
                 v = self.value_cache.view(
-                    -1, 128, self.num_kv_heads * self.head_size)
+                    -1, self.block_size,
+                    self.num_kv_heads * self.head_size)
                 actual_seq_lens = attn_metadata.seq_lens_list
+                attn_args = {
+                    "query": q,
+                    "key": k,
+                    "value": v,
+                    "actual_seq_lengths_kv": actual_seq_lens,
+                    "block_table": attn_metadata.block_tables,
+                    "num_heads": self.num_heads,
+                    "scale": self.scale,
+                    "input_layout": "BSH",
+                    "num_key_value_heads": self.num_kv_heads,
+                    "block_size": self.block_size,
+                }
+
+                # Prepare tensors for attention output
+                # TODO: Refactor this to step-level instead of layer-level
+                attn_output = torch.empty(num_tokens,
+                                          1,
+                                          self.hidden_size,
+                                          dtype=output.dtype,
+                                          device=output.device)
+                softmax_lse = torch.empty(num_tokens,
+                                          dtype=output.dtype,
+                                          device=output.device)
+
+                # Get workspace from cache or calculate it if not present.
+                workspace = graph_params.workspaces.get(num_tokens)
+                if workspace is None:
+                    workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                        **attn_args)
+                    graph_params.workspaces[num_tokens] = workspace
 
                 forward_context = get_forward_context()
                 if not forward_context.capturing:
-                    workspace = (
-                        torch_npu.
-                        _npu_fused_infer_attention_score_get_max_workspace(
-                            q,
-                            k,
-                            v,
-                            actual_seq_lengths_kv=actual_seq_lens,
-                            block_table=attn_metadata.block_tables,
-                            num_heads=self.num_heads,
-                            scale=self.scale,
-                            input_layout="BSH",
-                            num_key_value_heads=self.num_kv_heads,
-                            block_size=128,
-                        ))
-                    graph_params.workspaces[num_tokens] = workspace
-                    softmax_lse = torch.empty(num_tokens,
-                                              dtype=output.dtype,
-                                              device=output.device)
-                    attn_output = torch.empty(num_tokens,
-                                              1,
-                                              self.hidden_size,
-                                              dtype=output.dtype,
-                                              device=output.device)
+                    # Execute attention kernel directly in non-capturing mode
                     torch.ops.npu.npu_fused_infer_attention_score.out(
-                        q,
-                        k,
-                        v,
                         workspace=workspace,
-                        actual_seq_lengths_kv=actual_seq_lens,
-                        block_table=attn_metadata.block_tables,
-                        num_heads=self.num_heads,
-                        scale=self.scale,
-                        input_layout="BSH",
-                        num_key_value_heads=self.num_kv_heads,
-                        block_size=128,
                         out=[attn_output, softmax_lse],
-                    )
-                    output[:, :, :] = attn_output.view(
-                        num_tokens, self.num_heads,
-                        self.head_size)[:, :, :]
+                        **attn_args)
                 else:
+                    # Handle graph capturing mode
                     stream = torch_npu.npu.current_stream()
-                    workspace = graph_params.workspaces.get(num_tokens)
-                    if workspace is None:
-                        workspace = (
-                            torch_npu.
-                            _npu_fused_infer_attention_score_get_max_workspace(
-                                q,
-                                k,
-                                v,
-                                actual_seq_lengths_kv=actual_seq_lens,
-                                block_table=attn_metadata.block_tables,
-                                num_heads=self.num_heads,
-                                scale=self.scale,
-                                input_layout="BSH",
-                                num_key_value_heads=self.num_kv_heads,
-                                block_size=128,
-                            ))
-                        graph_params.workspaces[num_tokens] = workspace
-                    softmax_lse = torch.empty(num_tokens,
-                                              dtype=output.dtype,
-                                              device=output.device)
 
                     event = torch.npu.ExternalEvent()
                     event.wait(stream)
                     event.reset(stream)
                     graph_params.events[num_tokens].append(event)
 
-                    attn_output = torch.empty(num_tokens,
-                                              1,
-                                              self.hidden_size,
-                                              dtype=output.dtype,
-                                              device=output.device)
                     graph_params.attn_params[num_tokens].append(
                         (q, k, v, actual_seq_lens,
                          attn_metadata.block_tables, self.num_heads,
                          self.scale, self.num_kv_heads, attn_output,
                          softmax_lse))
+
                     torch.npu.graph_task_group_begin(stream)
                     torch.ops.npu.npu_fused_infer_attention_score.out(
-                        q,
-                        k,
-                        v,
                         workspace=workspace,
-                        actual_seq_lengths_kv=actual_seq_lens,
-                        block_table=attn_metadata.block_tables,
-                        num_heads=self.num_heads,
-                        scale=self.scale,
-                        input_layout="BSH",
-                        num_key_value_heads=self.num_kv_heads,
-                        block_size=128,
                         out=[attn_output, softmax_lse],
-                    )
+                        **attn_args)
                     handle = torch.npu.graph_task_group_end(stream)
-                    output[:, :, :] = attn_output.view(
-                        num_tokens, self.num_heads,
-                        self.head_size)[:, :, :]
                     graph_params.handles[num_tokens].append(handle)
 
-                # output = torch_npu.npu_incre_flash_attention(
-                #     q, k, v,
-                #     num_key_value_heads=self.num_kv_heads,
-                #     num_heads=self.num_heads,
-                #     actual_seq_lengths=attn_metadata.seq_lens,
-                #     scale_value=self.scale,
-                #     block_table=attn_metadata.block_tables,
-                #     input_layout="BSH",
-                #     block_size=128
-                # )
-                # attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
-                #     q,
-                #     k,
-                #     v,
-                #     actual_seq_lengths_kv=actual_seq_lens,
-                #     block_table=attn_metadata.block_tables,
-                #     num_heads=self.num_heads,
-                #     scale=self.scale,
-                #     input_layout="BSH",
-                #     num_key_value_heads=self.num_kv_heads,
-                #     block_size=128,
-                # )
-                # output[:, :, :] = attn_output.view(num_tokens, self.num_heads, self.head_size)[:, :, :]
+                # Reshape output to match the expected format
+                output.copy_(
+                    attn_output.view(num_tokens, self.num_heads,
                                     self.head_size))
             else:
                 torch_npu._npu_paged_attention(
                     query=query,
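
For context on the k = self.key_cache.view(-1, self.block_size, ...) change: the paged KV cache stores block_size tokens per block, so the cache tensor can be reinterpreted as (num_blocks, block_size, num_kv_heads * head_size) for whatever block size the cache config selects, rather than only for 128. Below is a minimal CPU-only sketch with made-up sizes, assuming the cache layout described in the removed docstring; it is an illustration, not the backend code:

import torch

# Hypothetical sizes; in the diff, block_size comes from the vLLM cache
# config (vllm_config.cache_config.block_size) instead of a hardcoded 128.
num_blocks, block_size = 4, 64
num_kv_heads, head_size = 2, 16

# Paged key cache laid out as in the (removed) docstring:
# [num_blocks, block_size, num_kv_heads, head_size].
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

# The same reshape the forward pass performs: heads and head_size are
# folded together so the kernel sees a "BSH"-style layout per cache block.
k = key_cache.view(-1, block_size, num_kv_heads * head_size)
print(k.shape)  # torch.Size([4, 64, 32])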
