@@ -256,6 +256,7 @@ def __init__(
 
         vllm_config = get_current_vllm_config()
         self.full_graph = vllm_config.compilation_config.full_cuda_graph
+        self.block_size = vllm_config.cache_config.block_size
 
     def forward(
         self,
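Note on the hunk above: `self.block_size` is read once from the engine's cache config at construction time and replaces the `128` that was previously hardcoded in the cache views and kernel arguments further down. A minimal, framework-free sketch of the pattern, where `CacheConfig`/`VllmConfig` are stand-ins for the real vLLM classes:

```python
# Minimal sketch, assuming a vLLM-style config object; these dataclasses
# are stand-ins for vLLM's real CacheConfig / VllmConfig.
from dataclasses import dataclass

@dataclass
class CacheConfig:
    block_size: int = 128  # previously assumed to always be 128

@dataclass
class VllmConfig:
    cache_config: CacheConfig

cfg = VllmConfig(cache_config=CacheConfig(block_size=64))
block_size = cfg.cache_config.block_size  # later shapes the k/v cache views
assert block_size == 64  # any configured block size now works, not just 128
```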
@@ -268,21 +269,7 @@ def forward(
         output: Optional[torch.Tensor] = None,
         trace_flag: bool = True,
     ) -> torch.Tensor:
-        """Forward pass with Ascend attention.
-        Args:
-            query: shape = [batch_size, seq_len, num_heads * head_size]
-            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            kv_cache: shape = [2, num_blocks, block_size,
-                               num_kv_heads, head_size]
-                      key_cache = [num_blocks, block_size,
-                                   num_kv_heads, head_size]
-                      value_cache = [num_blocks, block_size,
-                                     num_kv_heads, head_size]
-            attn_metadata: Metadata for attention.
-        Returns:
-            shape = [batch_size * seq_len, num_heads, head_size]
-        """
+        """Forward pass with Ascend attention."""
         num_tokens = query.shape[0]
         if output is None:
             output = torch.empty(num_tokens,
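For reference, the shapes documented in the removed docstring still describe the cache layout: `key_cache` and `value_cache` are [num_blocks, block_size, num_kv_heads, head_size], and the decode path in the next hunk flattens them into the "BSH" layout the fused kernel expects. A self-contained illustration of that reshape (illustrative sizes only):

```python
import torch

# Illustrative sizes; the real values come from the model and cache config.
num_blocks, block_size, num_kv_heads, head_size = 4, 128, 2, 64

key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
# The same view as in the next hunk: fold the head dims into one hidden dim.
k = key_cache.view(-1, block_size, num_kv_heads * head_size)
assert k.shape == (num_blocks, block_size, num_kv_heads * head_size)
```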
@@ -365,136 +352,77 @@ def forward(
             if self.full_graph:
                 graph_params = get_graph_params()
                 q = query.view(num_tokens, -1, self.hidden_size)
-                k = self.key_cache.view(-1, 128,
+                k = self.key_cache.view(-1, self.block_size,
                                         self.num_kv_heads * self.head_size)
                 v = self.value_cache.view(
-                    -1, 128, self.num_kv_heads * self.head_size)
+                    -1, self.block_size,
+                    self.num_kv_heads * self.head_size)
                 actual_seq_lens = attn_metadata.seq_lens_list
+                attn_args = {
+                    "query": q,
+                    "key": k,
+                    "value": v,
+                    "actual_seq_lengths_kv": actual_seq_lens,
+                    "block_table": attn_metadata.block_tables,
+                    "num_heads": self.num_heads,
+                    "scale": self.scale,
+                    "input_layout": "BSH",
+                    "num_key_value_heads": self.num_kv_heads,
+                    "block_size": self.block_size,
+                }
+
+                # Prepare tensors for attention output
+                # TODO: Refactor this to step-level instead of layer-level
+                attn_output = torch.empty(num_tokens,
+                                          1,
+                                          self.hidden_size,
+                                          dtype=output.dtype,
+                                          device=output.device)
+                softmax_lse = torch.empty(num_tokens,
+                                          dtype=output.dtype,
+                                          device=output.device)
+
+                # Get workspace from cache or calculate it if not present.
+                workspace = graph_params.workspaces.get(num_tokens)
+                if workspace is None:
+                    workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                        **attn_args)
+                    graph_params.workspaces[num_tokens] = workspace
 
                 forward_context = get_forward_context()
                 if not forward_context.capturing:
-                    workspace = (
-                        torch_npu.
-                        _npu_fused_infer_attention_score_get_max_workspace(
-                            q,
-                            k,
-                            v,
-                            actual_seq_lengths_kv=actual_seq_lens,
-                            block_table=attn_metadata.block_tables,
-                            num_heads=self.num_heads,
-                            scale=self.scale,
-                            input_layout="BSH",
-                            num_key_value_heads=self.num_kv_heads,
-                            block_size=128,
-                        ))
-                    graph_params.workspaces[num_tokens] = workspace
-                    softmax_lse = torch.empty(num_tokens,
-                                              dtype=output.dtype,
-                                              device=output.device)
-                    attn_output = torch.empty(num_tokens,
-                                              1,
-                                              self.hidden_size,
-                                              dtype=output.dtype,
-                                              device=output.device)
+                    # Execute attention kernel directly in non-capturing mode
                     torch.ops.npu.npu_fused_infer_attention_score.out(
-                        q,
-                        k,
-                        v,
                         workspace=workspace,
-                        actual_seq_lengths_kv=actual_seq_lens,
-                        block_table=attn_metadata.block_tables,
-                        num_heads=self.num_heads,
-                        scale=self.scale,
-                        input_layout="BSH",
-                        num_key_value_heads=self.num_kv_heads,
-                        block_size=128,
                         out=[attn_output, softmax_lse],
-                    )
-                    output[:, :, :] = attn_output.view(
-                        num_tokens, self.num_heads,
-                        self.head_size)[:, :, :]
+                        **attn_args)
                 else:
+                    # Handle graph capturing mode
                     stream = torch_npu.npu.current_stream()
-                    workspace = graph_params.workspaces.get(num_tokens)
-                    if workspace is None:
-                        workspace = (
-                            torch_npu.
-                            _npu_fused_infer_attention_score_get_max_workspace(
-                                q,
-                                k,
-                                v,
-                                actual_seq_lengths_kv=actual_seq_lens,
-                                block_table=attn_metadata.block_tables,
-                                num_heads=self.num_heads,
-                                scale=self.scale,
-                                input_layout="BSH",
-                                num_key_value_heads=self.num_kv_heads,
-                                block_size=128,
-                            ))
-                        graph_params.workspaces[num_tokens] = workspace
-                    softmax_lse = torch.empty(num_tokens,
-                                              dtype=output.dtype,
-                                              device=output.device)
 
                     event = torch.npu.ExternalEvent()
                     event.wait(stream)
                     event.reset(stream)
                     graph_params.events[num_tokens].append(event)
 
-                    attn_output = torch.empty(num_tokens,
-                                              1,
-                                              self.hidden_size,
-                                              dtype=output.dtype,
-                                              device=output.device)
                     graph_params.attn_params[num_tokens].append(
                         (q, k, v, actual_seq_lens,
                          attn_metadata.block_tables, self.num_heads,
                          self.scale, self.num_kv_heads, attn_output,
                          softmax_lse))
+
                     torch.npu.graph_task_group_begin(stream)
                     torch.ops.npu.npu_fused_infer_attention_score.out(
-                        q,
-                        k,
-                        v,
                         workspace=workspace,
-                        actual_seq_lengths_kv=actual_seq_lens,
-                        block_table=attn_metadata.block_tables,
-                        num_heads=self.num_heads,
-                        scale=self.scale,
-                        input_layout="BSH",
-                        num_key_value_heads=self.num_kv_heads,
-                        block_size=128,
                         out=[attn_output, softmax_lse],
-                    )
+                        **attn_args)
                     handle = torch.npu.graph_task_group_end(stream)
-                    output[:, :, :] = attn_output.view(
-                        num_tokens, self.num_heads,
-                        self.head_size)[:, :, :]
                     graph_params.handles[num_tokens].append(handle)
 
-                # output = torch_npu.npu_incre_flash_attention(
-                #     q, k, v,
-                #     num_key_value_heads=self.num_kv_heads,
-                #     num_heads=self.num_heads,
-                #     actual_seq_lengths=attn_metadata.seq_lens,
-                #     scale_value=self.scale,
-                #     block_table=attn_metadata.block_tables,
-                #     input_layout="BSH",
-                #     block_size=128
-                # )
-                # attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
-                #     q,
-                #     k,
-                #     v,
-                #     actual_seq_lengths_kv=actual_seq_lens,
-                #     block_table=attn_metadata.block_tables,
-                #     num_heads=self.num_heads,
-                #     scale=self.scale,
-                #     input_layout="BSH",
-                #     num_key_value_heads=self.num_kv_heads,
-                #     block_size=128,
-                # )
-                # output[:, :, :] = attn_output.view(num_tokens, self.num_heads, self.head_size)[:, :, :]
+                # Reshape output to match the expected format
+                output.copy_(
+                    attn_output.view(num_tokens, self.num_heads,
+                                     self.head_size))
             else:
                 torch_npu._npu_paged_attention(
                     query=query,
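The heart of this last hunk is deduplication: the kernel arguments are assembled once in `attn_args` and splatted into both call sites, and the expensive `_npu_fused_infer_attention_score_get_max_workspace` query is hoisted ahead of the capturing branch so it runs at most once per distinct `num_tokens`, cached in `graph_params.workspaces`. Because the lookup now happens before the branch, a graph capture presumably reuses whatever workspace an earlier eager step already computed for that size. A framework-free sketch of the caching pattern (illustrative names, not the vLLM or torch_npu API):

```python
from typing import Any, Callable, Dict

class WorkspaceCache:
    """Sketch of graph_params.workspaces: one workspace per token count."""

    def __init__(self, compute: Callable[[int], Any]):
        self._compute = compute           # stand-in for the *_get_max_workspace call
        self._cache: Dict[int, Any] = {}  # keyed by num_tokens

    def get(self, num_tokens: int) -> Any:
        workspace = self._cache.get(num_tokens)
        if workspace is None:             # first request at this size: compute once
            workspace = self._compute(num_tokens)
            self._cache[num_tokens] = workspace
        return workspace

# Both the eager path and the graph-capture path go through the same lookup.
cache = WorkspaceCache(compute=lambda n: bytearray(n * 1024))
assert cache.get(8) is cache.get(8)  # the second call reuses the cached buffer
```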