-
Notifications
You must be signed in to change notification settings - Fork 126
Optimized Paged Attention for HPU #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,72 +25,247 @@ def gelu_new(output, input): | |
| def gelu_fast(output, input): | ||
| raise NotImplementedError | ||
|
|
||
| def paged_attention_v1(query_in, key_cache_in, value_cache_in, head_mapping, scale, block_tables, context_lens, block_size, max_context_len, alibi_slopes, attn_masks=None) -> None: | ||
| query = query_in.bfloat16() | ||
| key_cache = key_cache_in.bfloat16() | ||
| value_cache = value_cache_in.bfloat16() | ||
| def _paged_attention_masked_fill( | ||
| attn_weights_blocks, | ||
| value_blocks, | ||
| context_lens, | ||
| attn_weights_blocks_filler, | ||
| value_blocks_filler, | ||
| block_size, | ||
| max_num_blocks_per_seq, | ||
| num_seqs, | ||
| sanitize_values, | ||
| device, | ||
| ): | ||
| # NOTE (kzawora): this code performs unconditional out-of-bound cleanup on attention weights. | ||
| # It was pretty insane to write and is probably hard to read, but it allows us to avoid | ||
| # recompilations and D2H-H2D copies on Gaudi2, making it very efficient. | ||
|
|
||
| # First, we're filling full out-of-bound blocks. We want to create 2D mask [num_seqs, max_num_blocks_per_seq] | ||
| # indicating which blocks need to be cleaned | ||
|
|
||
| # Create [num_seqs, max_num_blocks_per_seq] tensor of block indices per each sequence, | ||
| # which we'll then transform into a boolean tensor with mask | ||
| block_indices = torch.arange(max_num_blocks_per_seq, dtype=torch.int64, device=device).view(1, -1) | ||
| block_indices = block_indices.expand(num_seqs, block_indices.size(1)) | ||
|
|
||
| # Create mask with 1s for all blocks that are fully out of bound, and 0s for the rest. | ||
| # In order to broadcast the mask across all dimensions, we need to transpose it and | ||
| # view it as 5D tensor with ones in broadcasted dimensions (max_num_blocks_per_seq, num_seqs, 1, 1, 1) | ||
| attn_weights_blocks_mask = (block_indices >= (torch.ceil(context_lens / block_size)).unsqueeze(-1)).T.view( | ||
| max_num_blocks_per_seq, num_seqs, 1, 1, 1 | ||
| ) | ||
|
|
||
| # Apply block mask to attention weights | ||
| attn_weights_blocks.masked_fill_(attn_weights_blocks_mask, attn_weights_blocks_filler) | ||
|
|
||
| # We're done with filling full OoB blocks. Now, we need to fill out-of-bound values within last blocks | ||
| # The problem here is that now, we'll need to fetch all last blocks of each sequence, and fill | ||
| # the out-of-bound activation in the last dimension (block_size). This is pretty hard to do without | ||
| # loops and conditions. | ||
|
|
||
| # Collect last block indices. This will include blocks that are both partially, and fully filled. | ||
| # We expect this index to be in bounds (< max_blocks_per_seq). | ||
| last_block_indices = (torch.ceil((context_lens / block_size)) - 1).to(torch.int64) | ||
|
|
||
| # Gather indices of last blocks. We will collect plenty of superfluous blocks, | ||
| # as we'll fetch all (num_seq) indices per each sequence. This will result in | ||
| # (num_seq, num_seq, num_query_heads, 1, block_size) tensor. | ||
| last_blocks = attn_weights_blocks.index_select(0, last_block_indices) | ||
|
|
||
| # Extract only relevant blocks. Since dim0 and dim1 are the same, and we passed last_block_indices in order, | ||
| # we can reduce these dimensions by extracting the diagonal value. torch.diagonal returns the extracted value | ||
| # as the last dimension, so we'll need to permute the tensor to get it back to the first one. | ||
| # We expect to transform the source (num_seq, num_seq, num_query_heads, 1, block_size) tensor into | ||
| # (num_seq, num_query_heads, 1, block_size) tensor, with the first dimension containing each sequence's last block. | ||
| last_blocks_diag = torch.diagonal(last_blocks, dim1=0, dim2=1, offset=0).permute((3, 0, 1, 2)) | ||
|
|
||
| # Similarly to block mask, we'll create a 2D tensor of token indices per each block, | ||
| # which we'll then transform into a boolean tensor with mask | ||
| seq_indices = torch.arange(block_size, dtype=torch.int64, device=device).view(1, -1) | ||
| seq_indices = seq_indices.expand(num_seqs, seq_indices.size(1)) | ||
|
|
||
| # Create mask with 1s for all tokens that are fully out of bound, and 0s for the rest. | ||
| # We apply a bias of block_size for sequences that have context length divisible by block_size, | ||
| # as we don't want to clear anything within their last block - it is fully filled | ||
| last_block_offsets = (context_lens % block_size + block_size * (context_lens % block_size == 0)).view(-1, 1) | ||
| seq_mask = seq_indices >= last_block_offsets | ||
|
|
||
| # Apply block mask to weights to diagonal (num_seq, num_query_heads, 1, block_size) tensor. | ||
| last_blocks_diag.masked_fill_(seq_mask.view(num_seqs, 1, 1, block_size), attn_weights_blocks_filler) | ||
|
|
||
| # Scatter the (num_seq, num_query_heads, 1, block_size) tensor back into attn_weights_blocks using | ||
| # the same indices as we did in gathering. Each "row" will be stored at (last_block_index[i],i), | ||
| # where i is sequence index. | ||
| seq_idx = torch.arange(num_seqs, dtype=torch.int64, device=device) | ||
| edge_indices = (last_block_indices, seq_idx) | ||
| attn_weights_blocks.index_put_(edge_indices, last_blocks_diag) | ||
|
|
||
| # NOTE(kzawora): For whatever reason, we also need to cleanup values, otherwise vLLM inference will return garbage. | ||
| # I'll need to figure out why. For now, we'll follow exactly the same steps as we did with weight cleanup. | ||
| if sanitize_values: | ||
| # For filling fully OoB value blocks, we can use the same mask as we did previously. | ||
| value_blocks.masked_fill_(attn_weights_blocks_mask, value_blocks_filler) | ||
|
|
||
| # Same thing for filling partial blocks - obtain all last blocks for all sequences | ||
| last_values = value_blocks.index_select(0, last_block_indices) # [num_seqs, num_seqs, num_kv_heads, block_size, head_size] | ||
|
|
||
| # Reduce the tensor to have only a single last block per each sequence | ||
| last_values_diag = torch.diagonal(last_values, dim1=0, dim2=1, offset=0).permute((3, 0, 1, 2)) | ||
|
|
||
| # Mask blocks | ||
| last_values_diag.masked_fill_(seq_mask.view(num_seqs, 1, block_size, 1), value_blocks_filler) | ||
|
|
||
| # Put masked blocks back in their place | ||
| value_blocks.index_put_(edge_indices, last_values_diag) | ||
|
|
||
| # Phew, that was a lot. | ||
|
|
||
| def paged_attention_v1( | ||
| query_in, | ||
| key_cache_in, | ||
| value_cache_in, | ||
| head_mapping, | ||
| scale, | ||
| block_tables, | ||
| context_lens, | ||
| block_size, | ||
| max_context_len, | ||
| alibi_slopes, | ||
| attn_masks=None, | ||
| ) -> None: | ||
| sanitize_values = True | ||
| device = query_in.device | ||
| query = query_in | ||
| key_cache = key_cache_in | ||
| value_cache = value_cache_in | ||
|
Comment on lines
+139
to
+141
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| num_kv_heads = value_cache[0].shape[0] | ||
| head_size = value_cache[0].shape[1] | ||
| block_size = value_cache[0].shape[2] | ||
| num_seqs = query.shape[0] | ||
| num_query_heads = query.shape[1] | ||
| max_num_blocks_per_seq = block_tables.shape[1] | ||
|
|
||
| if alibi_slopes or num_query_heads != num_kv_heads: #or attn_masks is None: | ||
| if alibi_slopes: | ||
| raise NotImplementedError | ||
|
|
||
| attn_weights_blocks = [] | ||
| value_blocks = [] | ||
| seq_index = torch.tensor([0], dtype=torch.int64, device="hpu") | ||
|
|
||
| for i in range(0, max_num_blocks_per_seq): | ||
| # FIXME: dynamic hard override for filler. These blocks would contribute nothing to the output due to zero attention_probs and | ||
| # will clog up compute resources. The override itself makes the code unsuitable for graph precompilation | ||
| if (i - 2) * block_size > torch.max(context_lens): | ||
| break | ||
| attn_weights = torch.full((num_seqs, num_query_heads, 1, block_size), torch.finfo(query.dtype).min, dtype=query.dtype, device="hpu") | ||
| values = torch.zeros((num_seqs, num_query_heads, head_size, block_size), dtype=query.dtype, device="hpu") | ||
| for seq_id in range(num_seqs): | ||
| seq_index.fill_(seq_id) | ||
| if i * block_size < context_lens[seq_id]: | ||
|
|
||
| q = torch.index_select(query, 0, seq_index).transpose(0, 1) | ||
| key = torch.index_select(key_cache, 0, block_tables[seq_id][i]).squeeze(0) | ||
| attn_weight = scale * torch.matmul(q, key) | ||
|
|
||
| if attn_masks is not None: | ||
| attn_mask = torch.index_select(attn_masks[i], 0, seq_index) | ||
| attn_weight = torch.masked_fill(attn_weight, ~(attn_mask.unsqueeze(0).to(torch.bool)), torch.finfo(attn_weight.dtype).min) | ||
|
|
||
| # FIXME: these dynamic checks serve to ensure the -inf default value is not overwritten with fillers that would cause errors | ||
| # in logsoftmax computation. A change to custom block multiplication code is required to avoid incurring extra costs here | ||
| if context_lens[seq_id] < (i + 1) * block_size: | ||
| if context_lens[seq_id] - i*block_size < 0: | ||
| attn_weight = torch.finfo(query.dtype).min | ||
| else: | ||
| attn_weight[:, :, context_lens[seq_id] - i*block_size:] = torch.finfo(query.dtype).min | ||
| attn_weights.index_copy_(0, seq_index, attn_weight.unsqueeze(0)) | ||
| value = torch.index_select(value_cache, 0, block_tables[seq_id][i]) | ||
| # FIXME: these checks concern filler values in the V cache and should be removed once the underlying issue is addressed | ||
| value = torch.nan_to_num(value) | ||
| value[value < -1.0e+30] = 0.0 | ||
| values.index_copy_(0, seq_index, value) | ||
| torch.hpu.synchronize() | ||
|
|
||
| attn_weights_blocks.append(attn_weights.reshape(num_seqs * num_query_heads, 1, block_size)) | ||
| value_blocks.append(values.reshape(num_seqs * num_query_heads, head_size, block_size).transpose(1, 2)) | ||
|
|
||
| exp_sum = torch.zeros((*attn_weights_blocks[0].shape[:2], 1), dtype=attn_weights_blocks[0].dtype, device="hpu") | ||
| for x in attn_weights_blocks: | ||
| exp_sum.add_(torch.exp(x).sum(dim=-1, keepdim=True)) | ||
| num_queries_per_kv = num_query_heads // num_kv_heads | ||
| attn_weights_blocks = torch.full( | ||
| (max_num_blocks_per_seq, num_seqs, num_query_heads, 1, block_size), torch.finfo(query.dtype).min, dtype=query.dtype, device=device | ||
| ) | ||
| value_blocks = torch.zeros((max_num_blocks_per_seq, num_seqs, num_kv_heads, block_size, head_size), dtype=query.dtype, device=device) | ||
| seq_index = torch.tensor([0], dtype=torch.int64, device=device) | ||
| block_index = torch.tensor([0], dtype=torch.int64, device=device) | ||
| seq_block_table_idx = torch.tensor([0], dtype=torch.int64, device=device) | ||
| seq_context_len = torch.tensor([0], dtype=torch.int64, device=device) | ||
| for i in range(0, max_num_blocks_per_seq): # can run in parallel | ||
| with torch.profiler.record_function(f"block_loop"): | ||
| with torch.profiler.record_function("seq_index_fill"): | ||
|
Comment on lines
+162
to
+163
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do these profiler labels have any effect when not profiling this code? |
||
| seq_index.fill_(0) | ||
|
|
||
| ## hard override for filler. These blocks would contribute nothing to the output due to zero attention_probs and will clog up compute resources | ||
| # with torch.profiler.record_function('block_seq_len_check'): | ||
| # if (block_index - 2) * block_size > torch.max(context_lens): | ||
| # break | ||
|
Comment on lines
+166
to
+169
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove this block. It's unlikely to come back into use, even if a new check against overcomputing is introduced. |
||
|
|
||
| with torch.profiler.record_function("slice_weights_values"): | ||
| # single block attn weight of shape [B, Hq, Mq(=1), block_size], equivalent to attn_weights_blocks[i] | ||
| attn_weights = attn_weights_blocks.index_select(0, block_index).squeeze(0) | ||
|
|
||
| with torch.profiler.record_function("fetch_block_table"): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The prevalence of the profiler labels is reducing code readability, especially in this part of the code where each |
||
| block_table = block_tables.index_select(1, block_index).squeeze(1) | ||
|
|
||
| with torch.profiler.record_function("fetch_block_keys"): | ||
| keys = torch.index_select(key_cache, 0, block_table) | ||
|
|
||
| with torch.profiler.record_function("fetch_block_values"): | ||
| values = torch.index_select(value_cache, 0, block_table) | ||
|
|
||
| with torch.profiler.record_function("store_block_values"): | ||
| value_blocks.index_copy_(0, block_index, values.permute((0, 1, 3, 2)).unsqueeze(0)) | ||
|
|
||
| with torch.profiler.record_function("repeat_interleave_seq_key_check"): | ||
| if num_queries_per_kv > 1: | ||
| with torch.profiler.record_function("repeat_interleave_seq_key_check_true"): | ||
| # Handle MQA and GQA | ||
| keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) | ||
| with torch.profiler.record_function("gemm_attn_weight"): | ||
| q_bmm = query.reshape(1, num_seqs * num_query_heads, head_size).transpose(0, 1) | ||
| k_bmm = keys.reshape(num_seqs * num_query_heads, head_size, block_size) | ||
| attn_weights_2 = scale * torch.matmul(q_bmm, k_bmm).reshape(num_seqs, num_query_heads, 1, block_size) | ||
| # FIXME: need to restore attention masks in a more sensible way | ||
| if attn_masks is not None: | ||
| for j in range(num_seqs): # can run in parallel | ||
| with torch.profiler.record_function(f"seq_loop"): | ||
| with torch.profiler.record_function("seq_context_len_fill"): | ||
| seq_context_len[0] = context_lens.index_select(0, seq_index)[0] | ||
| with torch.profiler.record_function("context_len_check"): | ||
| if (block_index.mul(block_size)).lt(context_lens.index_select(0, seq_index)): | ||
| with torch.profiler.record_function("context_len_check_true"): | ||
|
|
||
| attn_weight = torch.index_select(attn_weights_2, 0, seq_index).squeeze(0) | ||
| with torch.profiler.record_function("create_attn_mask_check"): | ||
| if attn_masks is not None: | ||
| with torch.profiler.record_function("create_attn_mask_check_true"): | ||
| attn_mask = torch.index_select(attn_masks.index_select(0, block_index), 0, seq_index) | ||
| attn_weight = torch.masked_fill( | ||
| attn_weight, ~(attn_mask.unsqueeze(0).to(torch.bool)), torch.finfo(attn_weight.dtype).min | ||
| ) | ||
|
|
||
| with torch.profiler.record_function("store_seq_weights"): | ||
| attn_weights.index_copy_(0, seq_index, attn_weight.unsqueeze(0)) | ||
|
|
||
| with torch.profiler.record_function("seq_index_inc"): | ||
| seq_index.add_(1) | ||
| else: | ||
| attn_weights = attn_weights_2 | ||
|
|
||
| with torch.profiler.record_function("store_block_weight"): | ||
| attn_weights_blocks.index_copy_(0, block_index, attn_weights.unsqueeze(0)) | ||
|
|
||
| with torch.profiler.record_function("block_index_inc"): | ||
| block_index.add_(1) | ||
| if device == "hpu": | ||
| htorch.core.mark_step() | ||
|
|
||
| # Cleanup out-of-bound weights and values | ||
| attn_weights_blocks_filler = torch.finfo(query.dtype).min | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The "minimum value of query type" filler is also used elsewhere in this code (e.g. the |
||
| value_blocks_filler = 0.0 | ||
| _paged_attention_masked_fill( | ||
| attn_weights_blocks, | ||
| value_blocks, | ||
| context_lens, | ||
| attn_weights_blocks_filler, | ||
| value_blocks_filler, | ||
| block_size, | ||
| max_num_blocks_per_seq, | ||
| num_seqs, | ||
| sanitize_values, | ||
| device, | ||
| ) | ||
|
|
||
| exp_sum = attn_weights_blocks.exp().sum(dim=-1, keepdim=True).sum(0) | ||
|
|
||
| output = torch.zeros_like(query) | ||
| for i in range(len(attn_weights_blocks)): | ||
| attention_probs = torch.exp(attn_weights_blocks[i]) / exp_sum | ||
| value = value_blocks[i] | ||
| out = torch.matmul(attention_probs.to(value.dtype), value).reshape(num_seqs, num_query_heads, head_size) | ||
| output.add_(out) | ||
| htorch.core.mark_step() | ||
| for i in range(max_num_blocks_per_seq): | ||
| with torch.profiler.record_function(f"output_block_loop_{i}"): | ||
| with torch.profiler.record_function("compute_block_softmax"): | ||
| attention_probs = torch.exp(attn_weights_blocks[i]) / exp_sum | ||
| value = value_blocks[i] | ||
| if num_queries_per_kv > 1: | ||
| with torch.profiler.record_function("repeat_interleave_block_value"): | ||
| # Handle MQA and GQA | ||
| value_4d_view = value.reshape(num_seqs, num_kv_heads, block_size, head_size) | ||
| value = torch.repeat_interleave(value_4d_view, num_queries_per_kv, dim=1) | ||
| with torch.profiler.record_function("reshape_for_gemm"): | ||
| attention_probs = attention_probs.to(value.dtype).reshape(num_seqs * num_query_heads, 1, block_size) | ||
| value = value.reshape(num_seqs * num_query_heads, block_size, head_size) | ||
| with torch.profiler.record_function("gemm_out"): | ||
| out = torch.matmul(attention_probs, value).reshape(num_seqs, num_query_heads, head_size) | ||
| with torch.profiler.record_function("gemm_accumulate"): | ||
| output.add_(out) | ||
| if device == "hpu": | ||
| htorch.core.mark_step() | ||
| return output.to(dtype=query_in.dtype) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's no type casting of the query in the current code, so this casting of the output is superfluous. |
||
|
|
||
| def rms_norm(out, hidden_states, weight, eps): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this statement produce the device type ("hpu") or a device identifier ("hpu:0")? This is relevant to its use in further checks.