[Kernel] Chunk-aligned mamba2 #24683
Conversation
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
[Collapsed sections: server/client benchmark commands and results comparing the `main` and `tpa-mamba-aligned` branches.]
tlrmchlsmth left a comment:
PR looks great at first pass. Love to see more red than green.
In the figure in the PR description, why does A1.a fall at the beginning of the chunk rather than the end? I thought A0 should be ahead of it rather than behind.
@tlrmchlsmth A0 isn't actually added to the chunk; it has already been prefilled and doesn't need to be computed again. We just need to partition A1 in such a way that the intermediate states land exactly on the chunk boundaries.
Do the padded regions get loaded at all?
Makes sense. So then the A0-sized padded region could overlap with another chunk, or it could fall off the end of the KV cache tensor, right? Do we mask off the loads of the padded region as well?
No, padding is maybe the wrong word. There isn't any actual padding of tensors in memory here; masking would probably be a better word. If we have 5 chunks like in the above example, we would launch a Triton kernel with a grid size of 5. We are basically trading off a bit of extra compute in order to get intermediate states exactly where we want them within each sequence. It turns out it isn't really a trade-off: since it strips out so much complexity, it is a net win.
Yes, if we don't introduce the padding/masking it will lead to (a) having multiple sequences within the same chunk and (b) needing this whole mapping between "logical" and "physical" chunks to track where everything is.
Yes, we mask off the loads exactly (example: https://github.com/tdoublep/vllm/blob/tpa-aligned-mamba/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py#L231).
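To make the masked-load idea concrete, here is a small NumPy analogue of what a masked `tl.load` does inside the kernel: lanes of a chunk that fall past the end of the sequence are masked off, so they contribute nothing and never read out-of-range memory. The function and variable names here are illustrative, not taken from the actual kernel.

```python
import numpy as np

def load_chunk_masked(x, chunk_start, chunk_size, seq_end):
    """Load one fixed-size chunk of tokens from x, masking positions
    past the end of the sequence (analogous to tl.load with a mask)."""
    idx = chunk_start + np.arange(chunk_size)
    mask = idx < seq_end                 # out-of-range lanes are masked off
    safe_idx = np.where(mask, idx, 0)    # avoid OOB indexing; values discarded
    vals = x[safe_idx]
    return np.where(mask[:, None], vals, 0.0)

# 10 tokens of dimension 4; the last chunk starts at token 8 and only
# 2 of its 4 slots are valid, so only 2 rows of ones survive the mask.
x = np.ones((10, 4))
out = load_chunk_masked(x, chunk_start=8, chunk_size=4, seq_end=10)
print(out.sum())  # -> 8.0
```

In the real kernel the same mask guards both the loads and the stores, so the "padded" region is never materialized anywhere.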
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
…fter chunked-aligned mamba is merged (PR vllm-project#24683) Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Purpose
This PR changes the way that the mamba2 kernels split the batch into "chunks". The change ensures that (a) no chunk ever contains more than one sequence, and (b) all intermediate states are computed at the chunk boundaries within each sequence.
This change is useful for three reasons:
The downside is that it introduces some "virtual" padding inside the chunks. We don't actually pad anything in GPU memory; we just potentially need to use a larger grid when launching kernels and may do some redundant compute. However, this padding is bounded to at most one chunk per sequence, and my initial experiments suggest the overhead is minimal. In fact, we actually see a significant speedup because we skip the call to the final "varlen" kernel. We follow a very similar approach for handling varlen batches in the Triton attention kernels, so this kind of technique is not without precedent.
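One way to see the bound on the redundant compute is to count the chunks a sequence occupies when every boundary must be a multiple of the chunk size measured from the sequence's own start (the helper name below is illustrative, not from the PR):

```python
import math

def num_aligned_chunks(prefilled, new_tokens, chunk_size):
    """Chunks spanned by the new tokens of one sequence when all chunk
    boundaries must be aligned to the start of the sequence."""
    end = prefilled + new_tokens
    return math.ceil(end / chunk_size) - prefilled // chunk_size

# chunk_size=8: a sequence with 5 prefilled and 20 new tokens needs 4
# chunks (3 + 8 + 8 + 1 tokens), vs ceil(20/8) = 3 if the tokens were
# packed without alignment - at most one extra chunk for this sequence.
print(num_aligned_chunks(5, 20, 8))   # -> 4
print(num_aligned_chunks(0, 16, 8))   # -> 2 (already aligned, no overhead)
```

The extra chunk only appears when the prefilled length is not itself chunk-aligned, which is why the overhead is bounded per sequence rather than growing with sequence length.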
TODO:

- `seq_idx` can be made simpler - we just need to keep track of the `seq_idx` per chunk
- `chunk_indices` and `chunk_offsets`

A simple example for two sequences A and B is shown below. A0 and B0 represent the chunks that were prefilled at the previous step, and A1 and B1 are the new chunks we want to prefill in this iteration.
The idea is that for sequence A, we first take enough tokens from the new part (A1) to ensure that, when taken together with the precomputed part (A0), the state is chunk-aligned. Then we fill chunks with new tokens (from A1) until we run out, at which point we pad to the chunk boundary. Then we repeat for B.
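The partitioning described above can be sketched on the host side as follows. This is a simplified illustration, not the actual kernel metadata: it splits each sequence's new tokens so that every boundary is chunk-aligned relative to that sequence's start, with the last chunk of each sequence virtually padded.

```python
def build_chunk_layout(seqs, chunk_size):
    """seqs: list of (prefilled_len, new_len) pairs, one per sequence.
    Returns (seq_id, start, end) offsets into each sequence's new tokens;
    one tuple per launched chunk, so len(result) is the grid size."""
    chunks = []
    for seq_id, (prefilled, new_len) in enumerate(seqs):
        pos = 0
        while pos < new_len:
            # room until the next chunk boundary of *this* sequence
            room = chunk_size - (prefilled + pos) % chunk_size
            take = min(room, new_len - pos)
            chunks.append((seq_id, pos, pos + take))
            pos += take
    return chunks

# Sequence A: 5 tokens prefilled (A0), 20 new (A1); sequence B: 10 new.
layout = build_chunk_layout([(5, 20), (0, 10)], chunk_size=8)
print(layout)
# A gets chunks of 3, 8, 8, 1 tokens (the first chunk tops A0 up to a
# boundary, the last is virtually padded); B gets 8, 2. Grid size = 6.
```

Note that no tuple ever spans two sequences, and every `start`/`end` offset plus the prefilled length is a multiple of `chunk_size` except the final (padded) chunk of each sequence.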
Test Plan
See correctness + benchmarking below.
Test Result
See correctness + benchmarking below.