Conversation
    ## hard override for filler. These blocks would contribute nothing to the output due to zero attention_probs and will clog up compute resources
    # with torch.profiler.record_function('block_seq_len_check'):
    #     if (block_index - 2) * block_size > torch.max(context_lens):
    #         break
Please remove this block. It's unlikely to come back into use, even if a new check against overcomputing is introduced.
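For context, the commented-out guard implemented an early exit of roughly the following shape. This is a hypothetical, simplified reconstruction (the names `block_size` and `context_lens` come from the quoted diff; the function wrapper and the exact bound are mine, and the original used a conservative `block_index - 2` offset):

```python
# Hypothetical reconstruction of the early-exit guard under discussion.
# Assumptions (not from the original code): blocks are iterated in order,
# block_size is the number of tokens per KV-cache block, and context_lens
# holds the number of valid tokens per sequence in the batch.

def blocks_to_process(num_blocks: int, block_size: int, context_lens: list[int]) -> list[int]:
    """Return the block indices that can contain valid tokens.

    Blocks lying entirely past the longest context contribute nothing to
    the attention output (their attention probabilities are zero), so the
    loop can stop early instead of burning compute on filler.
    """
    max_len = max(context_lens)
    processed = []
    for block_index in range(num_blocks):
        # a block starting at or beyond max_len holds only padding
        if block_index * block_size >= max_len:
            break
        processed.append(block_index)
    return processed

# With 8 blocks of 16 tokens and a longest context of 40 tokens,
# only blocks 0..2 (covering tokens 0..47) can hold valid entries.
```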
    with torch.profiler.record_function(f"block_loop"):
    with torch.profiler.record_function("seq_index_fill"):
Do these profiler labels have any effect when not profiling this code?
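One aspect of the question: `record_function` is used as a context manager, so even if no profiler is attached each labelled region still pays Python context-manager entry/exit cost on every iteration. A pure-Python stand-in (not the torch implementation, just an analogue of the cost model) illustrates this:

```python
import time
from contextlib import contextmanager

PROFILING_ENABLED = False  # stands in for "a profiler session is active"

@contextmanager
def record_function(name: str):
    """Pure-Python analogue of a profiler label: cheap, but not free, when off."""
    if PROFILING_ENABLED:
        start = time.perf_counter()
        try:
            yield
        finally:
            print(f"{name}: {time.perf_counter() - start:.6f}s")
    else:
        yield  # no measurement, but the context-manager machinery still runs

def work():
    total = 0
    for i in range(1000):
        with record_function("inner_step"):  # overhead paid 1000x even when off
            total += i
    return total
```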
    # single block attn weight of shape [B, Hq, Mq(=1), block_size], equivalent to attn_weights_blocks[i]
    attn_weights = attn_weights_blocks.index_select(0, block_index).squeeze(0)

    with torch.profiler.record_function("fetch_block_table"):
The prevalence of the profiler labels is reducing code readability, especially in this part of the code where each record_function covers a single statement.
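One way to reduce the clutter (a sketch, not a prescription): label whole phases rather than single statements, and gate labelling behind a module-level flag so the hot path reads cleanly. The helper name `profile_region` and the flag are mine, not from the PR:

```python
from contextlib import contextmanager, nullcontext

ENABLE_PROFILING_LABELS = False  # flip on only for profiling runs

@contextmanager
def _labelled(name):
    # In the real code this branch would yield from
    # torch.profiler.record_function(name); a no-op stands in here
    # so the sketch is self-contained.
    yield

def profile_region(name: str):
    """Return a labelled region when profiling, else a no-op context."""
    if ENABLE_PROFILING_LABELS:
        return _labelled(name)
    return nullcontext()

def attention_block(values):
    # one label for the whole phase instead of one per statement
    with profile_region("block_loop"):
        acc = 0
        for v in values:
            acc += v
    return acc
```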
    query = query_in
    key_cache = key_cache_in
    value_cache = value_cache_in
The *_in args were used in a prior version to facilitate type casting, since FP16 was causing precision issues in softmax. If that is no longer the case, then the arg names should be changed and these lines should be removed.
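For reference, the casting pattern the `_in` suffix presumably supported looks like this. This is a generic plain-Python sketch (Python's `float` stands in for the higher-precision type the kernel would upcast to; the function name is mine):

```python
import math

def softmax_upcast(logits_fp16):
    """Sketch of the 'cast up for softmax' pattern the *_in args supported.

    logits_fp16 stands in for half-precision attention scores; the math is
    done in Python floats (double precision here, FP32 in the real kernel),
    and the caller would cast the result back down afterwards.
    """
    logits = [float(x) for x in logits_fp16]   # analogue of query_in.to(fp32)
    m = max(logits)                            # max subtraction for stability
    exps = [math.exp(x - m) for x in logits]
    denom = sum(exps)
    return [e / denom for e in exps]
```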
    htorch.core.mark_step()

    # Cleanup out-of-bound weights and values
    attn_weights_blocks_filler = torch.finfo(query.dtype).min
The "minimum value of query type" filler is also used elsewhere in this code (e.g. the attn_weights_blocks definition) and should probably be defined and commented on at the start of the function. The question also arises as to the type restrictions, since the required behavior for this filler is to produce a zero under exp(), and that behavior is not observed on FP16.
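To make the FP16 concern concrete, here is a small plain-Python sketch using the literal `finfo` min values for the two dtypes. The `exp()` results are exact-arithmetic statements; the headroom argument is one plausible failure mode (softmax-style max subtraction pushing the FP16 sentinel out of the representable range), not a confirmed diagnosis of the observed behavior:

```python
import math

# Analogues of torch.finfo(dtype).min (exact IEEE-754 values)
FP32_MIN = -3.4028234663852886e38
FP16_MIN = -65504.0

# Required behavior of the filler: it must map to exactly 0 under exp().
# In exact arithmetic both sentinels satisfy this:
assert math.exp(FP32_MIN) == 0.0
assert math.exp(FP16_MIN) == 0.0

# But the FP16 sentinel sits at the very edge of its format: any further
# arithmetic, e.g. the max subtraction done for softmax stability, leaves
# the representable half-precision range (|x| > 65504), where FP16
# hardware saturates to inf/-inf or produces NaN instead of a value that
# exponentiates cleanly. FP32 has ~38 orders of magnitude of slack.
row_max = 100.0
assert abs(FP16_MIN - row_max) > 65504.0   # out of FP16 range
assert abs(FP32_MIN - row_max) < 3.5e38    # still comfortably in FP32 range
```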
    attn_masks=None,
    ) -> None:
    sanitize_values = True
    device = query_in.device
Does this statement produce the device type ("hpu") or a device identifier ("hpu:0")? This is relevant to its use in further checks.
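The distinction matters for the later `device == "hpu"` check: in PyTorch, a tensor's `.device` typically carries an index (e.g. `hpu:0`, assuming HPU follows the CUDA convention), and a device with an index compares unequal to the bare type string, so the check can silently skip the `mark_step`. A minimal stand-in sketch (not `torch.device` itself) of the safer `.type` comparison:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class Device:
    """Minimal stand-in for torch.device: a type string plus optional index."""
    type: str
    index: Optional[int] = None

# A tensor allocated on the first accelerator plausibly reports an index...
dev = Device("hpu", 0)

# ...so comparing the whole device against the bare type fails:
assert dev != Device("hpu")

# Comparing the .type field is robust to the index:
assert dev.type == "hpu"
```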
    output.add_(out)
    if device == "hpu":
        htorch.core.mark_step()
    return output.to(dtype=query_in.dtype)
There's no type casting of the query in the current code, so this casting of the output is superfluous.
Since #14 was merged, I guess we can close this.
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
This PR adds a Paged Attention implementation for Gaudi2. In a standalone benchmark (B=32, Mkv=1024, Hkv=32, K=128) it showed a 200x performance improvement over the current implementation:
Current (naive) version:
Optimized version: