@@ -1,5 +1,6 @@
 from typing import List, Optional

+from flash_attn.flash_attention import FlashAttention
 import torch
 import torch.nn as nn

@@ -14,20 +15,7 @@ def __init__(self, scale: float) -> None:
         super().__init__()
         self.scale = float(scale)

-    def _masked_attention(
-        self,
-        query: torch.Tensor,                        # [num_queries, num_heads, head_size]
-        key: torch.Tensor,                          # [num_keys, num_heads, head_size]
-        value: torch.Tensor,                        # [num_keys, num_heads, head_size]
-        attn_mask: Optional[torch.Tensor] = None,   # [num_queries, num_keys]
-    ) -> torch.Tensor:                              # [num_queries, num_heads, head_size]
-        query = query * self.scale
-        attn = torch.einsum('qhd,khd->hqk', query, key)
-        if attn_mask is not None:
-            attn = attn + attn_mask
-        attn = torch.softmax(attn, dim=-1)
-        out = torch.einsum('hqk,khd->qhd', attn, value)
-        return out
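+        # The softmax scaling is applied inside FlashAttention via softmax_scale,
+        # so the query no longer needs to be scaled manually.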
+        self.flash_attn = FlashAttention(softmax_scale=self.scale)

     def multi_query_kv_attention(
         self,
@@ -37,21 +25,31 @@ def multi_query_kv_attention(
         value: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
         prompt_lens: List[int],
     ) -> None:
-        # FIXME(woosuk): Replace the following with a custom op.
-        start_idx = 0
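+        # Input checks: the FlashAttention kernels do not support fp32 inputs
+        # or head sizes larger than 128.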
+        if query.dtype == torch.float:
+            raise ValueError('The float data type is not supported by '
+                             'FlashAttention. Use the half data type instead.')
+        head_size = query.shape[2]
+        if head_size > 128:
+            raise ValueError('FlashAttention does not support head_size > 128.')
+
+        device = query.device
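+        # Cumulative prompt lengths: FlashAttention's packed interface uses these
+        # offsets to locate each prompt's tokens in the flattened input.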
+        prefix_sum = [0]
         for prompt_len in prompt_lens:
-            out = output[start_idx:start_idx + prompt_len]
-            q = query[start_idx:start_idx + prompt_len]
-            k = key[start_idx:start_idx + prompt_len]
-            v = value[start_idx:start_idx + prompt_len]
-
-            attention_mask = torch.triu(
-                torch.ones(q.shape[0], k.shape[0]), diagonal=1) * -1e5
-            attention_mask = attention_mask.to(dtype=q.dtype, device=q.device)
-            attention_out = self._masked_attention(q, k, v, attention_mask)
-            out.copy_(attention_out, non_blocking=True)
-
-            start_idx += prompt_len
+            prefix_sum.append(prefix_sum[-1] + prompt_len)
+        prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
+        max_prompt_len = max(prompt_lens)
+
+        # FIXME(woosuk): Unnecessary copy. Optimize this.
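+        # Pack Q, K, and V into a single [num_prompt_tokens, 3, num_heads, head_size]
+        # tensor, the packed QKV layout that FlashAttention consumes.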
+        qkv = torch.stack([query, key, value], dim=1)
+        out = self.flash_attn(
+            qkv,
+            cu_seqlens=prefix_sum,
+            max_s=max_prompt_len,
+            causal=True,
+        )[0]
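+        # self.flash_attn returns a tuple whose first element is the attention
+        # output; copy it into the prompt-token rows of the output buffer.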
+        num_tokens = prefix_sum[-1]
+        # FIXME(woosuk): Unnecessary copy. Optimize this.
+        output[:num_tokens].copy_(out, non_blocking=True)

     def single_query_cached_kv_attention(
         self,
@@ -61,6 +59,14 @@ def single_query_cached_kv_attention(
         value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
         input_metadata: InputMetadata,
     ) -> None:
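+        # The single_query_cached_kv_attention kernel only supports a fixed set of
+        # head sizes; fail fast with a clear error for anything else.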
+        head_size = value_cache.shape[2]
+        supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
+        if head_size not in supported_head_sizes:
+            raise ValueError(f'head_size ({head_size}) is not supported by '
+                             'the single_query_cached_kv_attention kernel. '
+                             'Use one of the following head sizes: '
+                             f'{supported_head_sizes}.')
+
         block_size = value_cache.shape[3]
         attention_ops.single_query_cached_kv_attention(
             output,
@@ -101,8 +107,9 @@ def forward(
         output = output.view(-1, num_heads, head_size)

         # Compute the attention op for prompts.
-        self.multi_query_kv_attention(
-            output, query, key, value, input_metadata.prompt_lens)
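+        # Skip the prompt attention when the batch contains no prompts
+        # (decoding-only batches).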
+        if input_metadata.num_prompts > 0:
+            self.multi_query_kv_attention(
+                output, query, key, value, input_metadata.prompt_lens)

         # Wait until the cache op is done.
         if cache_event is not None: