[Kernel] Triton implementation of causal-conv1d for Mamba-based models #18218
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,12 +5,14 @@ | |
| import pytest | ||
| import torch | ||
| import torch.nn.functional as F | ||
| from einops import rearrange | ||
|
|
||
| from tests.kernels.utils import opcheck | ||
| from vllm import _custom_ops as ops # noqa: F401 | ||
| from vllm.attention.backends.utils import PAD_SLOT_ID | ||
| from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( | ||
| causal_conv1d_fn, causal_conv1d_update) | ||
| causal_conv1d_fn, causal_conv1d_fn_triton, causal_conv1d_update, | ||
| causal_conv1d_update_triton) | ||
| from vllm.platforms import current_platform | ||
|
|
||
|
|
||
|
|
@@ -435,3 +437,237 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, | |
| causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(), | ||
| padded_state_indices, has_initial_states, | ||
| final_states, activation) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("itype", | ||
| [torch.float32, torch.float16, torch.bfloat16]) | ||
| @pytest.mark.parametrize("silu_activation", [False, True]) | ||
| @pytest.mark.parametrize("has_bias", [False, True]) | ||
| @pytest.mark.parametrize("seqlen", [1]) | ||
| @pytest.mark.parametrize("width", [2, 3, 4]) | ||
| @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) | ||
 | # tests correctness in case a subset of the sequences is padded | ||
| @pytest.mark.parametrize("with_padding", [True, False]) | ||
| @pytest.mark.parametrize("batch_size", [3]) | ||
| def test_causal_conv1d_update_with_batch_gather_vllm(batch_size, with_padding, | ||
|
||
| dim, width, seqlen, | ||
| has_bias, silu_activation, | ||
| itype): | ||
| device = "cuda" | ||
| rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) | ||
| if itype == torch.bfloat16: | ||
| rtol, atol = 1e-2, 5e-2 | ||
|
|
||
| # set seed | ||
| current_platform.seed_everything(0) | ||
|
|
||
| padding = 5 if with_padding else 0 | ||
| padded_batch_size = batch_size + padding | ||
 | # total_entries = number of cache lines | ||
| total_entries = 10 * batch_size | ||
|
|
||
| channel_last = True | ||
| if not channel_last: | ||
| x = torch.randn(padded_batch_size, | ||
| dim, | ||
| seqlen, | ||
| device=device, | ||
| dtype=itype) | ||
| else: | ||
|
||
 | # x will be (batch, dim, seqlen), contiguous along the dim axis | ||
| x = torch.randn(padded_batch_size, | ||
| seqlen, | ||
| dim, | ||
| device=device, | ||
| dtype=itype).transpose(1, 2) | ||
|
|
||
| x_ref = x.clone() | ||
|
|
||
| conv_state_indices = torch.randperm(total_entries)[:batch_size].to( | ||
| dtype=torch.int32, device=device) | ||
| unused_states_bool = torch.ones(total_entries, | ||
| dtype=torch.bool, | ||
| device=device) | ||
| unused_states_bool[conv_state_indices] = False | ||
| padded_state_indices = torch.concat([ | ||
| conv_state_indices, | ||
| torch.as_tensor( | ||
| [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) | ||
| ], | ||
| dim=0) | ||
|
|
||
| if not channel_last: | ||
| conv_state = torch.randn(total_entries, | ||
| dim, | ||
| width - 1, | ||
| device=device, | ||
| dtype=itype) | ||
| else: | ||
|
||
 | # conv_state will be (cache_lines, dim, state_len), | ||
 | # contiguous along the dim axis | ||
| conv_state = torch.randn(total_entries, | ||
| width - 1, | ||
| dim, | ||
| device=device, | ||
| dtype=itype).transpose(1, 2) | ||
|
|
||
| conv_state_for_padding_test = conv_state.clone() | ||
|
|
||
| weight = torch.randn(dim, width, device=device, dtype=itype) | ||
| bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None | ||
| conv_state_ref = conv_state[conv_state_indices, :].detach().clone() | ||
| activation = None if not silu_activation else "silu" | ||
|
|
||
| out = causal_conv1d_update_triton(x, | ||
| conv_state, | ||
| weight, | ||
| bias, | ||
| activation=activation, | ||
| conv_state_indices=padded_state_indices, | ||
| pad_slot_id=PAD_SLOT_ID) | ||
| out_ref = causal_conv1d_update_ref(x_ref[:batch_size], | ||
| conv_state_ref, | ||
| weight, | ||
| bias, | ||
| activation=activation) | ||
|
|
||
| assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) | ||
| assert torch.equal(conv_state[unused_states_bool], | ||
| conv_state_for_padding_test[unused_states_bool]) | ||
| assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("itype", [torch.bfloat16]) | ||
| @pytest.mark.parametrize("silu_activation", [True]) | ||
| @pytest.mark.parametrize("has_bias", [True]) | ||
| @pytest.mark.parametrize("width", [4]) | ||
| @pytest.mark.parametrize('seqlen', [8, 16, 784, 1024, 2048, 2049, 4096]) | ||
| @pytest.mark.parametrize('dim', [64, 4096]) | ||
| @pytest.mark.parametrize('with_padding', [True, False]) | ||
| @pytest.mark.parametrize('batch', [4]) | ||
| def test_causal_conv1d_varlen_vllm(batch, with_padding, dim, seqlen, width, | ||
| has_bias, silu_activation, itype): | ||
|
||
| device = "cuda" | ||
| torch.cuda.empty_cache() | ||
| rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) | ||
| if itype == torch.bfloat16: | ||
| rtol, atol = 1e-2, 5e-2 | ||
| # set seed | ||
| current_platform.seed_everything(0) | ||
| seqlens = [] | ||
| batch_size = batch | ||
| padding = 3 if with_padding else 0 | ||
| padded_batch_size = batch_size + padding | ||
| nsplits = padded_batch_size - 1 | ||
|
|
||
| eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values | ||
|
|
||
| seqlens.append( | ||
| torch.diff( | ||
| torch.cat( | ||
| [torch.tensor([-1]), eos_pos, | ||
| torch.tensor([seqlen - 1])])).tolist()) | ||
| assert sum(seqlens[-1]) == seqlen | ||
| assert all(s > 0 for s in seqlens[-1]) | ||
|
|
||
| total_entries = batch_size * 10 | ||
| cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) | ||
| cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], | ||
| dim=0) | ||
| channel_last = True | ||
| if not channel_last: | ||
| x = torch.randn(1, 4096 + dim + 64, seqlen, device=device, | ||
| dtype=itype)[:, 4096:4096 + dim, :] | ||
| else: | ||
|
||
| x = rearrange( | ||
| torch.randn(1, seqlen, 4096 + dim + 64, device=device, | ||
| dtype=itype), "b s d -> b d s")[:, 4096:4096 + dim, :] | ||
|
|
||
| weight = torch.randn(dim, width, device=device, dtype=itype) | ||
|
|
||
| bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None | ||
| x_ref = x.clone() | ||
| weight_ref = weight.clone() | ||
| bias_ref = bias.clone() if bias is not None else None | ||
| activation = None if not silu_activation else "silu" | ||
| if not channel_last: | ||
| final_states = torch.randn(total_entries, | ||
| dim, | ||
| width - 1, | ||
| device=x.device, | ||
| dtype=x.dtype) | ||
| else: | ||
|
||
| final_states = torch.randn(total_entries, | ||
| width - 1, | ||
| dim, | ||
| device=x.device, | ||
| dtype=x.dtype).transpose(1, 2) | ||
| final_states_ref = final_states.clone() | ||
| has_initial_states = torch.randint(0, | ||
| 2, (cumsum.shape[0] - 1, ), | ||
| dtype=torch.bool, | ||
| device=x.device) | ||
| state_indices = torch.randperm(total_entries, | ||
| dtype=torch.int32, | ||
| device=x.device)[:batch_size] | ||
| padded_state_indices = torch.concat([ | ||
| state_indices, | ||
| torch.as_tensor( | ||
| [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), | ||
| ], | ||
| dim=-1) | ||
| out = causal_conv1d_fn_triton(x.squeeze(0), | ||
| weight, | ||
| bias=bias, | ||
| conv_states=final_states, | ||
| query_start_loc=cumsum.cuda(), | ||
| cache_indices=padded_state_indices, | ||
| has_initial_states=has_initial_states, | ||
| activation=activation, | ||
| pad_slot_id=PAD_SLOT_ID) | ||
|
|
||
| out_ref = [] | ||
| out_ref_b = [] | ||
|
|
||
| splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)] | ||
| for i in range(len(seqlens[0])): | ||
| x_s = [v[i].unsqueeze(0) for v in splits][0] | ||
| if padded_state_indices[i] == PAD_SLOT_ID: | ||
| continue | ||
| out_ref_b.append( | ||
| causal_conv1d_ref( | ||
| x_s, | ||
| weight_ref, | ||
| bias_ref, | ||
| activation=activation, | ||
| return_final_states=True, | ||
| final_states_out=final_states_ref[ | ||
| padded_state_indices[i]].unsqueeze(0), | ||
| initial_states=final_states_ref[padded_state_indices[i]]. | ||
| unsqueeze(0) if has_initial_states[i] else None)) | ||
| out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) | ||
| out_ref_tensor = torch.cat(out_ref, dim=0) | ||
|
|
||
| try: | ||
| assert torch.allclose(final_states[state_indices], | ||
| final_states_ref[state_indices], | ||
| rtol=rtol, | ||
| atol=atol) | ||
| print("Passed conv_state") | ||
| except Exception as e: | ||
| print("FAILED conv_state") | ||
| raise e | ||
| unpadded_out = out[:, :out_ref_tensor.shape[-1]] | ||
| try: | ||
| assert torch.allclose(unpadded_out, | ||
| out_ref_tensor, | ||
| rtol=rtol, | ||
| atol=atol) | ||
| except Exception as e: | ||
 | print("Passed conv_state, but failed output") | ||
|
||
|
|
||
| nz = out_ref_tensor.squeeze(0) - unpadded_out | ||
| non_zero_indices = torch.nonzero(nz) | ||
| print('nonzero indices :', non_zero_indices) | ||
| raise e | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,9 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| import math | ||
| from dataclasses import dataclass | ||
| from typing import Optional | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from vllm.attention.backends.abstract import AttentionMetadata | ||
|
|
@@ -22,6 +24,31 @@ class Mamba2Metadata: | |
| chunk_indices: torch.Tensor | ||
| chunk_offsets: torch.Tensor | ||
|
|
||
| num_cache_lines: Optional[int] = None | ||
| stride_istate_seq: Optional[int] = None | ||
| stride_istate_dim: Optional[int] = None | ||
| stride_istate_token: Optional[int] = None | ||
| seqlens: Optional[np.ndarray] = None | ||
| padded_batch: Optional[int] = None | ||
| nums_dict: Optional[dict] = None | ||
|
Member
What is …
Contributor Author
In a batch of requests, a prefill request can be processed in parallel, where each Triton program handles BLOCK_M tokens. Depending on the choice of BLOCK_M, the values in … I added the documentation accordingly.
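For context, here is a minimal sketch of the chunking scheme described in the comment above. The helper name and array layout are assumptions for illustration, not the kernel's actual grid-mapping code: each prefill sequence is cut into BLOCK_M-token chunks, and each Triton program is assigned one (sequence, chunk) pair via two flat arrays, so the contents of those arrays depend on the chosen BLOCK_M.

```python
# Sketch only: how a (batch_ptr, token_chunk_offset_ptr) mapping could be built
# from per-sequence lengths and a BLOCK_M tile size. Hypothetical helper.
import math


def build_program_mapping(seqlens: list[int], BLOCK_M: int):
    batch_ptr = []               # which sequence each Triton program works on
    token_chunk_offset_ptr = []  # which BLOCK_M-sized chunk within that sequence
    for seq_idx, seqlen in enumerate(seqlens):
        n_chunks = math.ceil(seqlen / BLOCK_M)
        for chunk_idx in range(n_chunks):
            batch_ptr.append(seq_idx)
            token_chunk_offset_ptr.append(chunk_idx)
    return batch_ptr, token_chunk_offset_ptr


# Example: three prefill requests of lengths 5, 12, and 3 with BLOCK_M = 8
# -> programs (0,0), (1,0), (1,1), (2,0)
print(build_program_mapping([5, 12, 3], BLOCK_M=8))
```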
||
| is_channel_last: bool = True | ||
| stride_w_dim: Optional[int] = None | ||
| stride_w_width: Optional[int] = None | ||
| width: Optional[int] = None | ||
| np2_statelen: Optional[int] = None | ||
| stride_x_seq: Optional[int] = 0 | ||
| stride_x_dim: Optional[int] = None | ||
| stride_x_token: Optional[int] = None | ||
| dim: Optional[int] = None | ||
| cu_seqlen: Optional[int] = None | ||
| out: Optional[torch.Tensor] = None | ||
|
||
| stride_o_seq: Optional[int] = 0 | ||
| stride_o_dim: Optional[int] = None | ||
| stride_o_token: Optional[int] = None | ||
| MAX_NUM_PROGRAMS: int = 1024 | ||
| batch_ptr: Optional[torch.tensor] = None | ||
| token_chunk_offset_ptr: Optional[torch.tensor] = None | ||
|
||
|
|
||
|
|
||
| def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, | ||
| chunk_size: int, | ||
|
|
@@ -62,7 +89,9 @@ def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, | |
| def prepare_mamba2_metadata( | ||
| chunk_size: int, | ||
| attn_metadata: AttentionMetadata, | ||
| mamba2_metadata=None, | ||
| ) -> Mamba2Metadata: | ||
| # ruff: noqa: E501 | ||
|
||
|
|
||
| # compute number of prefill and decode requests | ||
| # NOTE: in V0 we assume prefills are before decodes | ||
|
|
@@ -78,6 +107,12 @@ def prepare_mamba2_metadata( | |
|
|
||
| # Compute seq_idx, chunk_indices and chunk_offsets for prefill only | ||
| if num_prefills > 0: | ||
| # NOTE: currently it is assumed prefill requests come before decode requests -> we can use ':num_prefills' slicing | ||
| # TODO: maybe revert back to the original code (below) if above no longer holds | ||
| # has_initial_states = attn_metadata.context_lens_tensor > 0 | ||
| # zero_init_indices = mamba_cache_params.state_indices_tensor[~has_initial_states] | ||
| # mamba_cache_params.ssm_state[zero_init_indices] = 0 | ||
| # initial_states = mamba_cache_params.ssm_state[mamba_cache_params.state_indices_tensor] | ||
|
||
| if (isinstance(attn_metadata, | ||
| (FlashAttentionMetadata, XFormersMetadata, | ||
| PlaceholderAttentionMetadata)) | ||
|
|
@@ -103,6 +138,21 @@ def prepare_mamba2_metadata( | |
| _query_start_loc_to_chunk_indices_offsets( | ||
| query_start_loc, chunk_size, num_prefill_tokens) | ||
|
|
||
| if mamba2_metadata is not None: | ||
| mamba2_metadata.has_initial_states = has_initial_states | ||
| mamba2_metadata.prep_initial_states = prep_initial_states | ||
| mamba2_metadata.chunk_size = chunk_size | ||
| mamba2_metadata.seq_idx = seq_idx | ||
| mamba2_metadata.chunk_indices = chunk_indices | ||
| mamba2_metadata.chunk_offsets = chunk_offsets | ||
 | # We use two reset flags: | ||
 | # * mamba2_metadata.width is None # update model-level config on the first run (it never changes for a given model within a session) | ||
 | # (becomes available at the first layer, e.g. from conv_weights) | ||
 | # * mamba2_metadata.cu_seqlen is None # update config specific to each input | ||
 | # (becomes available at the first layer, e.g. from conv_weights) | ||
 | mamba2_metadata.cu_seqlen = None # supposed to be updated for each input | ||
|
|
||
| return mamba2_metadata | ||
| return Mamba2Metadata(has_initial_states=has_initial_states, | ||
| prep_initial_states=prep_initial_states, | ||
| chunk_size=chunk_size, | ||
|
|
||
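A minimal sketch of the two-flag reset pattern described in the comments above, under stated assumptions: the stub dataclass, helper name, and assigned values are hypothetical, not code from this PR. It only illustrates how `width is None` can gate one-time, model-level setup while `cu_seqlen is None` gates per-input setup that the first layer of each forward pass recomputes.

```python
# Sketch only: one-time vs. per-input configuration gated by two reset flags.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class _MetadataStub:  # stand-in for Mamba2Metadata, reduced to the relevant fields
    width: Optional[int] = None      # model-level: set once per session
    dim: Optional[int] = None
    cu_seqlen: Optional[int] = None  # input-level: reset to None for every batch


def configure_in_first_layer(meta: _MetadataStub, conv_weights: torch.Tensor,
                             x: torch.Tensor) -> _MetadataStub:
    if meta.width is None:
        # One-time config, available once the first layer sees conv_weights.
        meta.width = conv_weights.shape[-1]
        meta.dim = conv_weights.shape[0]
    if meta.cu_seqlen is None:
        # Per-input config: prepare_mamba2_metadata resets cu_seqlen to None,
        # so the first layer of each forward pass fills it in again.
        meta.cu_seqlen = x.shape[-1]
    return meta


meta = _MetadataStub()
configure_in_first_layer(meta, torch.randn(64, 4), torch.randn(64, 128))
print(meta.width, meta.dim, meta.cu_seqlen)  # 4 64 128
```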
Should we test more than just seqlen 1?

ah yes, I can add more to the test code

done
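For illustration, the decode/update test's parametrization could be broadened along these lines. This is a sketch only: the test name, the placeholder body, and the concrete seqlen/dim values are assumptions, not the exact parametrization that landed in the PR.

```python
import pytest
import torch


# Illustrative sweep beyond seqlen=1 for the update path.
@pytest.mark.parametrize("seqlen", [1, 2, 3, 4])
@pytest.mark.parametrize("dim", [64, 2048])
def test_update_seqlen_sweep(seqlen, dim):
    # Placeholder body: the real test would call causal_conv1d_update_triton
    # with a (batch, dim, seqlen)-shaped input, as in the update test above.
    x = torch.randn(2, dim, seqlen)
    assert x.shape == (2, dim, seqlen)
```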