FA3 variable length attention sort/swizzle #82
Conversation
Signed-off-by: Jay Shah <jayhshah@gmail.com>
…or virtual batch metadata Signed-off-by: Jay Shah <jayhshah@gmail.com>
```cpp
seqlen = params.seqlen;
if constexpr (Prepared) {
    return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1
        ? cute::ceil_div(params.prepare_seqlen_q_ptr[batch_idx], kBlockM) : 0;
```
right now vLLM is a bit annoying: we actually compute the attention metadata (and as a result the mha_fwd_get_scheduler_metadata) before knowing how many requests we will pad to; this means the scheduler metadata will be for a different batch size than what params.b is at runtime. This is normally fine since cu_seqlens is padded to make sure all requests up to the max batch size have seqlen_q == 0, so FA returns before touching any bad memory; however, if this reads garbage from prepare_seqlen_q_ptr, that might break? We can probably zero the metadata here: https://github.com/neuralmagic/vllm/blob/a75c6e034abf00603fba527625e44baab7b42f80/vllm/v1/attention/backends/flash_attn.py#L333-L338
(this is a historical artifact of thinking that piecewise cudagraphs would be enough in V1 and we wouldn't need attention to be in a cudagraph; so this may be re-architected in the near future)
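A minimal sketch of the zeroing idea on the vLLM side (the helper name is hypothetical, not actual vLLM code): pad the per-request metadata out to the max batch size with zeros, so any entry the kernel reads past the scheduler's batch yields seqlen_q == 0 rather than garbage.

```python
# Hypothetical illustration of zeroing the scheduler-metadata tail (not the
# actual vLLM code): pad the per-request seqlen_q list out to the maximum
# batch size so reads past the scheduler's batch see 0, not garbage.
def pad_seqlens_q(seqlen_q, max_num_batch):
    """Return seqlen_q extended with zeros up to max_num_batch entries."""
    assert len(seqlen_q) <= max_num_batch
    return list(seqlen_q) + [0] * (max_num_batch - len(seqlen_q))

# With seqlen_q == 0 for the padded tail, ceil_div(seqlen_q, kBlockM) is 0,
# so the kernel schedules no work for those (non-existent) requests.
```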
Actually we might have to do a more aggressive refactor on the vLLM side since I think an even bigger problem is that all of the offsets will be wrong:
```cpp
int sort_offset = b_rounded * (use_dynamic_split ? 2 : 1);
int head_swizzle_offset = b_rounded * (num_prepare_batch_vectors - 1);
int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors;
```
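To make the mismatch concrete, here is a hedged Python mirror of the offset computation above (names follow the C++ snippet; the round-up granularity of 32 is an assumption): if the scheduler computed metadata for fewer requests than the runtime batch, every offset shifts.

```python
# Python mirror of the struct-of-arrays offset computation above, purely for
# illustration; the round-up granularity (32) is an assumption.
def round_up(x, m=32):
    return (x + m - 1) // m * m

def soa_offsets(num_batch, use_dynamic_split, num_prepare_batch_vectors):
    b_rounded = round_up(num_batch)
    return {
        "sort": b_rounded * (2 if use_dynamic_split else 1),
        "head_swizzle": b_rounded * (num_prepare_batch_vectors - 1),
        "tile_count_semaphore": b_rounded * num_prepare_batch_vectors,
    }

# Scheduler saw 20 requests, runtime padded to 40: the layouts disagree.
scheduler_view = soa_offsets(20, True, 4)  # b_rounded = 32
runtime_view = soa_offsets(40, True, 4)    # b_rounded = 64
assert scheduler_view != runtime_view
```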
the other option would be to make the scheduler metadata an "Array of Structs" instead of a "Struct of Arrays"; then the offsets wouldn't depend on the batch size the scheduler used (and we could more easily just zero out the rest of the metadata)
how hard do you think this would be / how badly do you think this would hurt perf?
> the other option would be to make the scheduler metadata an "Array of Structs" instead of a "Struct of Arrays"; then the offsets wouldn't depend on the batch size the scheduler used (and we could more easily just zero out the rest of the metadata)
> how hard do you think this would be / how badly do you think this would hurt perf?
I could write it out as an int4 array instead, but then we wouldn't have coalesced accesses when reading it back in, so I'd like to avoid that if at all possible.
Can we pass in a max batch size to set the offsets correctly?
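A hedged sketch of that suggestion (names are illustrative, not the actual FA3/vLLM API): compute the offsets from a fixed max batch size rather than the scheduler's batch size, so the layout no longer depends on how many requests the scheduler actually saw.

```python
# Hypothetical fix: derive the SoA offsets from a fixed maximum batch size
# rather than the scheduler's batch size. The layout then matches at runtime
# regardless of the live request count; the unused tail can simply be zeroed.
# The round-up granularity (32) is an assumption.
def round_up(x, m=32):
    return (x + m - 1) // m * m

def soa_offsets_from_max(max_num_batch, use_dynamic_split,
                         num_prepare_batch_vectors):
    b_rounded = round_up(max_num_batch)  # independent of the live batch size
    return {
        "sort": b_rounded * (2 if use_dynamic_split else 1),
        "head_swizzle": b_rounded * (num_prepare_batch_vectors - 1),
        "tile_count_semaphore": b_rounded * num_prepare_batch_vectors,
    }

# Whether 20 or 40 requests are live, scheduler and runtime agree on the
# layout because both compute it from the same max_num_batch.
assert soa_offsets_from_max(256, True, 4) == soa_offsets_from_max(256, True, 4)
```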
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
CI now passing on vllm-project/vllm#23465
vLLM-side mirror of Dao-AILab#1823