
Optimized Paged Attention for HPU #4

Closed
kzawora-intel wants to merge 1 commit into habana_main from private/kzawora/paged_attention_optimization

Conversation

@kzawora-intel

This PR adds a Paged Attention implementation for Gaudi2. In a standalone benchmark (B=32, Mkv=1024, Hkv=32, K=128), it showed a 200x performance improvement over the current implementation:

Current (naive) version:

INFO:root:[REQ:0][B:32, Mq:1, Mkv:996, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 106.500s
INFO:root:[REQ:1][B:32, Mq:1, Mkv:1006, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 46.476s
INFO:root:[REQ:2][B:32, Mq:1, Mkv:1015, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 21.430s
INFO:root:[REQ:3][B:32, Mq:1, Mkv:953, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 12.854s
INFO:root:[REQ:4][B:32, Mq:1, Mkv:1008, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 9.996s
INFO:root:[REQ:5][B:32, Mq:1, Mkv:980, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 5.928s
INFO:root:[REQ:6][B:32, Mq:1, Mkv:1024, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 8.595s
INFO:root:[REQ:7][B:32, Mq:1, Mkv:1015, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 5.835s
INFO:root:[REQ:8][B:32, Mq:1, Mkv:995, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 6.296s
INFO:root:[REQ:9][B:32, Mq:1, Mkv:1023, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 5.833s
INFO:root:[ALL:10][B:(32-32), Mq:1, Mkv:(953-1024), Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 229.758s

Optimized version:

INFO:root:[REQ:0][B:32, Mq:1, Mkv:996, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 0.959s
INFO:root:[REQ:1][B:32, Mq:1, Mkv:1006, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 0.022s
INFO:root:[REQ:2][B:32, Mq:1, Mkv:1015, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 0.020s
INFO:root:[REQ:3][B:32, Mq:1, Mkv:953, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 0.020s
INFO:root:[REQ:4][B:32, Mq:1, Mkv:1008, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 0.020s
INFO:root:[REQ:5][B:32, Mq:1, Mkv:980, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 0.020s
INFO:root:[REQ:6][B:32, Mq:1, Mkv:1024, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 0.020s
INFO:root:[REQ:7][B:32, Mq:1, Mkv:1015, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 0.020s
INFO:root:[REQ:8][B:32, Mq:1, Mkv:995, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 0.021s
INFO:root:[REQ:9][B:32, Mq:1, Mkv:1023, Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 0.020s
INFO:root:[ALL:10][B:(32-32), Mq:1, Mkv:(953-1024), Hq:32, Hkv:32, K:128] paged_attention_op_time (hpu): 1.149s
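For context, a minimal CPU sketch of the block-wise decode pattern that this kind of optimization typically relies on — the function name, the fixed-size block lists, and the shapes are illustrative assumptions, not the PR's actual code. Scores are computed per cache block, softmax-normalized jointly, and the value products accumulated block by block, which matches attention over the concatenated cache:

```python
import torch

def paged_attention_decode(query, key_blocks, value_blocks, block_size):
    """Decode-time attention over a block-paged KV cache (illustrative sketch).

    query:        [B, H, 1, K]
    key_blocks:   list of [B, H, block_size, K] tensors
    value_blocks: list of [B, H, block_size, K] tensors
    """
    scale = query.shape[-1] ** -0.5
    # Per-block raw attention scores; each entry is [B, H, 1, block_size].
    scores = [torch.matmul(query, k.transpose(-1, -2)) * scale for k in key_blocks]
    # Joint softmax across all blocks (concatenated along the KV axis).
    probs = torch.softmax(torch.cat(scores, dim=-1), dim=-1)
    probs_per_block = probs.split(block_size, dim=-1)
    # Weighted sum of values, accumulated block by block.
    out = torch.zeros_like(query)
    for p, v in zip(probs_per_block, value_blocks):
        out = out + torch.matmul(p, v)
    return out
```

The result is numerically equivalent to attention over the full cache; the gain on HPU comes from operating on fixed-shape blocks rather than dynamically shaped sequences.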

@kzawora-intel kzawora-intel changed the title Optimize Paged Attention for HPU Optimized Paged Attention for HPU Feb 20, 2024
Comment on lines +166 to +169
## hard override for filler. These blocks would contribute nothing to the output due to zero attention_probs and will clog up compute resources
# with torch.profiler.record_function('block_seq_len_check'):
# if (block_index - 2) * block_size > torch.max(context_lens):
# break


Please remove this block. It's unlikely to come back into use, even if a new check against overcomputing is introduced.

Comment on lines +162 to +163
with torch.profiler.record_function(f"block_loop"):
with torch.profiler.record_function("seq_index_fill"):


Do these profiler labels have any effect when not profiling this code?
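To the reviewer's question: when no profiler is attached, record_function is essentially a pass-through context manager — the label never appears anywhere and only a small constant overhead remains. A minimal check (the function and values here are illustrative):

```python
import torch
from torch.profiler import record_function

def scaled_sum(x):
    # The label only shows up in a trace when a profiler is active;
    # otherwise the context manager is a cheap pass-through.
    with record_function("scaled_sum"):
        return (x * 2.0).sum()

result = scaled_sum(torch.ones(4))  # computes identically with or without profiling
```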

# single block attn weight of shape [B, Hq, Mq(=1), block_size], equivalent to attn_weights_blocks[i]
attn_weights = attn_weights_blocks.index_select(0, block_index).squeeze(0)

with torch.profiler.record_function("fetch_block_table"):


The prevalence of the profiler labels reduces code readability, especially in this part of the code, where each record_function covers a single statement.

Comment on lines +139 to +141
query = query_in
key_cache = key_cache_in
value_cache = value_cache_in


The *_in args were used in a prior version to facilitate type casting, since FP16 was causing precision issues in softmax. If that is no longer the case, then the arg names should be changed and these lines should be removed.

htorch.core.mark_step()

# Cleanup out-of-bound weights and values
attn_weights_blocks_filler = torch.finfo(query.dtype).min


The "minimum value of query type" filler is also used elsewhere in this code (e.g. the attn_weights_blocks definition) and should probably be defined and commented on once, at the start of the function. There is also a question of type restrictions: the required behavior of this filler is to produce a zero under exp(), and that behavior is not observed in FP16.
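For FP32 the filler behaves as intended — after the max-subtraction inside softmax, exp() of a position filled with the dtype minimum underflows to exactly zero. A toy illustration of the masking pattern (the FP16 discrepancy the reviewer mentions is not reproduced here):

```python
import torch

# FP32: a score position filled with the dtype minimum gets exactly
# zero probability after softmax, so masked cache blocks contribute nothing.
filler = torch.finfo(torch.float32).min
scores = torch.tensor([0.5, filler, 1.0])
probs = torch.softmax(scores, dim=0)
# probs[1] underflows to exactly 0.0; the remaining mass sums to 1.
```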

attn_masks=None,
) -> None:
sanitize_values = True
device = query_in.device


Does this statement produce the device type ("hpu") or a device identifier ("hpu:0")? This is relevant to its use in further checks.
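The distinction matters because torch.device equality is index-sensitive: a device carrying an explicit index ("hpu:0") does not compare equal to the bare type ("hpu"), so a check like `device == "hpu"` can silently fail. Comparing `device.type` sidesteps the index. A CPU-based illustration (CPU is used only so the snippet runs anywhere):

```python
import torch

dev_plain = torch.device("cpu")      # type only, no index
dev_indexed = torch.device("cpu:0")  # explicit device index

# Devices with and without an explicit index do not compare equal,
# which is why `query_in.device == "hpu"` can miss "hpu:0".
same_device = dev_plain == dev_indexed        # False
# Comparing the .type attribute ignores the index and is robust.
same_type = dev_plain.type == dev_indexed.type  # True
```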

output.add_(out)
if device == "hpu":
htorch.core.mark_step()
return output.to(dtype=query_in.dtype)


There is no type casting of the query in the current code, so this cast of the output back to query_in.dtype is superfluous.

@madamczyk-intel

Since #14 was merged, I guess we can close this.

@kzawora-intel kzawora-intel added the habana Issues or PRs submitted by Habana Labs label Sep 20, 2024
@kzawora-intel kzawora-intel deleted the private/kzawora/paged_attention_optimization branch October 7, 2024 13:14
iboiko-habana added a commit that referenced this pull request Feb 28, 2025
tvoas referenced this pull request in tvoas/vllm-fork Mar 7, 2025
kzawora-intel pushed a commit that referenced this pull request Jul 10, 2025
Signed-off-by: Chendi Xue <chendi.xue@intel.com>