diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py
index 73cba7f9035c..02e48921d419 100644
--- a/vllm/model_executor/layers/fla/ops/chunk.py
+++ b/vllm/model_executor/layers/fla/ops/chunk.py
@@ -16,7 +16,7 @@
 from .cumsum import chunk_local_cumsum
 from .l2norm import l2norm_fwd
 from .solve_tril import solve_tril
-from .utils import SUPPRESS_LEVEL, input_guard
+from .utils import FLA_CHUNK_SIZE, SUPPRESS_LEVEL, input_guard
 from .wy_fast import recompute_w_u_fwd
@@ -30,13 +30,24 @@ def chunk_gated_delta_rule_fwd(
     initial_state: torch.Tensor,
     output_final_state: bool,
     cu_seqlens: torch.Tensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
+    chunk_offsets: torch.Tensor | None = None,
 ):
-    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
+    g = chunk_local_cumsum(
+        g, chunk_size=FLA_CHUNK_SIZE, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices
+    )
     # obtain WY representation. u is actually the new v.
     A = chunk_scaled_dot_kkt_fwd(
-        k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
+        k=k,
+        beta=beta,
+        g=g,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        output_dtype=torch.float32,
+    )
+    A = solve_tril(
+        A=A, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, output_dtype=k.dtype
     )
-    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
     w, u = recompute_w_u_fwd(
         k=k,
         v=v,
@@ -44,6 +55,7 @@ def chunk_gated_delta_rule_fwd(
         A=A,
         g_cumsum=g,
         cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
     )
     h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
         k=k,
@@ -53,6 +65,8 @@ def chunk_gated_delta_rule_fwd(
         initial_state=initial_state,
         output_final_state=output_final_state,
         cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
     )
     o = chunk_fwd_o(
         q=q,
@@ -62,6 +76,7 @@ def chunk_gated_delta_rule_fwd(
         g=g,
         scale=scale,
         cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
     )
     if SUPPRESS_LEVEL < 3:
         return g, o, A, final_state, None, None, None
@@ -84,6 +99,8 @@ def forward(
     initial_state: torch.Tensor,
     output_final_state: bool,
     cu_seqlens: torch.Tensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
+    chunk_offsets: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = False,
 ):
     if use_qk_l2norm_in_kernel:
@@ -100,6 +117,8 @@ def forward(
         initial_state=initial_state,
         output_final_state=output_final_state,
         cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
     )
     ctx.scale = scale
     ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
@@ -117,6 +136,8 @@ def chunk_gated_delta_rule(
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
     cu_seqlens: torch.Tensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
+    chunk_offsets: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = False,
 ):
     r"""
@@ -206,6 +227,8 @@ def chunk_gated_delta_rule(
         initial_state,
         output_final_state,
         cu_seqlens,
+        chunk_indices,
+        chunk_offsets,
         use_qk_l2norm_in_kernel,
     )
     return o, final_state
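Usage sketch for the new plumbing (not part of the diff): shapes, dtypes, and the packed batch below are illustrative, while the keyword arguments and FLA_CHUNK_SIZE come from this patch. The point is that the chunk metadata is built once per batch and reused by every kernel in the forward pass; previously chunk_local_cumsum, chunk_scaled_dot_kkt_fwd, solve_tril, recompute_w_u_fwd, chunk_gated_delta_rule_fwd_h, and chunk_fwd_o each rebuilt the same tables from cu_seqlens.

```python
import torch

from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule
from vllm.model_executor.layers.fla.ops.index import (
    prepare_chunk_indices,
    prepare_chunk_offsets,
)
from vllm.model_executor.layers.fla.ops.utils import FLA_CHUNK_SIZE

# Two packed sequences of lengths 100 and 128 (varlen layout: batch dim 1).
cu_seqlens = torch.tensor([0, 100, 228], dtype=torch.int32, device="cuda")
T, H, K, V = 228, 4, 64, 64
q = torch.randn(1, T, H, K, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, T, H, K, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, T, H, V, dtype=torch.bfloat16, device="cuda")
g = torch.randn(1, T, H, dtype=torch.float32, device="cuda").sigmoid().log()
beta = torch.rand(1, T, H, dtype=torch.bfloat16, device="cuda")

# Build the chunk metadata once per batch; every kernel below reuses it.
# (Built here from the GPU tensor for brevity; the gdn_attn builder at the
# end of this patch builds from the CPU copy to avoid the sync entirely.)
chunk_indices = prepare_chunk_indices(cu_seqlens, FLA_CHUNK_SIZE)
chunk_offsets = prepare_chunk_offsets(cu_seqlens, FLA_CHUNK_SIZE)

o, final_state = chunk_gated_delta_rule(
    q, k, v, g, beta,
    scale=K**-0.5,
    initial_state=None,
    output_final_state=True,
    cu_seqlens=cu_seqlens,
    chunk_indices=chunk_indices,
    chunk_offsets=chunk_offsets,
    use_qk_l2norm_in_kernel=True,
)
```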
diff --git a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py
index ce60ca46f6c9..574f6f25173f 100644
--- a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py
+++ b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py
@@ -14,7 +14,7 @@
 from .index import prepare_chunk_indices, prepare_chunk_offsets
 from .op import exp
-from .utils import use_cuda_graph
+from .utils import FLA_CHUNK_SIZE, use_cuda_graph
 
 NUM_WARPS = [2, 4, 8, 16]
@@ -286,9 +286,11 @@ def chunk_gated_delta_rule_fwd_h(
     gk: torch.Tensor | None = None,
     initial_state: torch.Tensor | None = None,
     output_final_state: bool = False,
-    chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
+    chunk_size: int = FLA_CHUNK_SIZE,
     save_new_value: bool = True,
     cu_seqlens: torch.Tensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
+    chunk_offsets: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     # This kernel is slightly different from fla to support Q/K with different head numbers.
     # In fla, Q/K always have the same head number, so Hg is always equal to H.
@@ -296,20 +298,15 @@ def chunk_gated_delta_rule_fwd_h(
     H = u.shape[-2]
     BT = chunk_size
-    chunk_indices = (
-        prepare_chunk_indices(cu_seqlens, chunk_size)
-        if cu_seqlens is not None
-        else None
-    )
+    if chunk_indices is None and cu_seqlens is not None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
     # N: the actual number of sequences in the batch with either equal or variable lengths
     if cu_seqlens is None:
         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
     else:
-        N, NT, chunk_offsets = (
-            len(cu_seqlens) - 1,
-            len(chunk_indices),
-            prepare_chunk_offsets(cu_seqlens, BT),
-        )
+        N, NT = len(cu_seqlens) - 1, len(chunk_indices)
+        if chunk_offsets is None:
+            chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)
 
     assert K <= 256, "current kernel does not support head dimension larger than 256."
     h = k.new_empty(B, NT, H, V, K)
diff --git a/vllm/model_executor/layers/fla/ops/chunk_o.py b/vllm/model_executor/layers/fla/ops/chunk_o.py
index aab1ee006d4d..d812ec433720 100644
--- a/vllm/model_executor/layers/fla/ops/chunk_o.py
+++ b/vllm/model_executor/layers/fla/ops/chunk_o.py
@@ -146,14 +146,14 @@ def chunk_fwd_o(
     g: torch.Tensor | None = None,  # cumsum of log decay
     scale: float | None = None,
     cu_seqlens: torch.Tensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
     chunk_size: int = FLA_CHUNK_SIZE,
 ) -> torch.Tensor:
     B, T, Hg, K, V = *q.shape, v.shape[-1]
     H = v.shape[-2]
     BT = chunk_size
-    chunk_indices = (
-        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
-    )
+    if chunk_indices is None and cu_seqlens is not None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
     if scale is None:
         scale = k.shape[-1] ** -0.5
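Every kernel entry point in this patch adopts the same fallback idiom, so legacy callers that pass only cu_seqlens keep working. A distilled, runnable sketch (the stub helper is hypothetical; only the control flow mirrors the diff):

```python
import torch


def prepare_chunk_indices_stub(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # One (sequence, local-chunk) pair per chunk in the packed batch.
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    nchunks = (seqlens + chunk_size - 1) // chunk_size
    pairs = [(n, c) for n, nc in enumerate(nchunks.tolist()) for c in range(nc)]
    return torch.tensor(pairs, dtype=torch.int32)


def kernel_entry(cu_seqlens=None, chunk_indices=None, chunk_size=64):
    # Callers that precomputed the table skip the rebuild (and the
    # .tolist() sync hiding inside it); legacy callers are unchanged.
    if chunk_indices is None and cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices_stub(cu_seqlens, chunk_size)
    return chunk_indices


cu = torch.tensor([0, 100, 228])
table = kernel_entry(cu_seqlens=cu)                # built on demand: 4 chunks
assert kernel_entry(cu_seqlens=cu, chunk_indices=table) is table  # reused as-is
```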
diff --git a/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
index 31bd489ebd87..3f7628487d69 100644
--- a/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
+++ b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
@@ -14,6 +14,7 @@
 from .index import prepare_chunk_indices
 from .op import exp
+from .utils import FLA_CHUNK_SIZE
 
 
 @triton.heuristics(
@@ -103,7 +104,8 @@ def chunk_scaled_dot_kkt_fwd(
     g: torch.Tensor | None = None,
     beta: torch.Tensor | None = None,
     cu_seqlens: torch.Tensor | None = None,
-    chunk_size: int = 64,
+    chunk_indices: torch.Tensor | None = None,
+    chunk_size: int = FLA_CHUNK_SIZE,
     output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     r"""
@@ -119,6 +121,9 @@ def chunk_scaled_dot_kkt_fwd(
         cu_seqlens (torch.Tensor):
             The cumulative sequence lengths of the input tensor.
             Default: None
+        chunk_indices (torch.Tensor):
+            Pre-computed chunk indices. If None and cu_seqlens is provided,
+            computed internally. Default: None
         chunk_size (int):
             The chunk size. Default: 64.
         output_dtype (torch.dtype):
@@ -132,9 +137,8 @@ def chunk_scaled_dot_kkt_fwd(
     B, T, Hg, K = k.shape
     H = beta.shape[-1]
     BT = chunk_size
-    chunk_indices = (
-        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
-    )
+    if chunk_indices is None and cu_seqlens is not None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
     A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
diff --git a/vllm/model_executor/layers/fla/ops/cumsum.py b/vllm/model_executor/layers/fla/ops/cumsum.py
index 13238020cbd9..b0820104b1a1 100644
--- a/vllm/model_executor/layers/fla/ops/cumsum.py
+++ b/vllm/model_executor/layers/fla/ops/cumsum.py
@@ -162,6 +162,7 @@ def chunk_local_cumsum_scalar(
     chunk_size: int,
     reverse: bool = False,
     cu_seqlens: torch.Tensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
     head_first: bool = False,
     output_dtype: torch.dtype | None = torch.float,
 ) -> torch.Tensor:
@@ -172,10 +173,9 @@ def chunk_local_cumsum_scalar(
     assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
         "chunk_size must be a power of 2"
     )
+    if chunk_indices is None and cu_seqlens is not None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
     BT = chunk_size
-    chunk_indices = (
-        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
-    )
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
     g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
     grid = (NT, B * H)
@@ -199,6 +199,7 @@ def chunk_local_cumsum_vector(
     chunk_size: int,
     reverse: bool = False,
     cu_seqlens: torch.Tensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
     head_first: bool = False,
     output_dtype: torch.dtype | None = torch.float,
 ) -> torch.Tensor:
@@ -206,16 +207,13 @@ def chunk_local_cumsum_vector(
         B, H, T, S = g.shape
     else:
         B, T, H, S = g.shape
-    BT = chunk_size
-    chunk_indices = (
-        prepare_chunk_indices(cu_seqlens, chunk_size)
-        if cu_seqlens is not None
-        else None
-    )
-    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
     assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
         "chunk_size must be a power of 2"
     )
+    if chunk_indices is None and cu_seqlens is not None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
+    BT = chunk_size
+    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
 
     g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
@@ -247,6 +245,7 @@ def chunk_local_cumsum(
     chunk_size: int,
     reverse: bool = False,
     cu_seqlens: torch.Tensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
     head_first: bool = False,
     output_dtype: torch.dtype | None = torch.float,
     **kwargs,
@@ -257,11 +256,23 @@ def chunk_local_cumsum(
     )
     if len(g.shape) == 3:
         return chunk_local_cumsum_scalar(
-            g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
+            g,
+            chunk_size,
+            reverse,
+            cu_seqlens,
+            chunk_indices,
+            head_first,
+            output_dtype,
         )
     elif len(g.shape) == 4:
         return chunk_local_cumsum_vector(
-            g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
+            g,
+            chunk_size,
+            reverse,
+            cu_seqlens,
+            chunk_indices,
+            head_first,
+            output_dtype,
         )
     else:
         raise ValueError(
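For intuition, this is roughly what chunk_local_cumsum computes on the equal-length path, as a dense PyTorch reference (a hypothetical helper for illustration, not the library code): the cumulative sum restarts at every chunk boundary, and chunk_indices only tells the varlen kernel where those boundaries fall inside the packed batch.

```python
import torch


def ref_chunk_local_cumsum(g: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # g: [B, T, H] scalar gates; the cumsum restarts at each chunk boundary.
    out = torch.empty_like(g, dtype=torch.float)
    for start in range(0, g.shape[1], chunk_size):
        blk = g[:, start : start + chunk_size].float()
        out[:, start : start + chunk_size] = blk.cumsum(dim=1)
    return out


g = torch.randn(1, 128, 4)
ref = ref_chunk_local_cumsum(g, 64)
assert torch.allclose(ref[:, 64], g[:, 64].float())  # second chunk restarts
```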
diff --git a/vllm/model_executor/layers/fla/ops/kda.py b/vllm/model_executor/layers/fla/ops/kda.py
index b8c07d1dc896..67cd0231d6e9 100644
--- a/vllm/model_executor/layers/fla/ops/kda.py
+++ b/vllm/model_executor/layers/fla/ops/kda.py
@@ -23,7 +23,7 @@
 from .l2norm import l2norm_fwd
 from .op import exp, log
 from .solve_tril import solve_tril
-from .utils import is_amd
+from .utils import FLA_CHUNK_SIZE, is_amd
 
 BT_LIST_AUTOTUNE = [32, 64, 128]
 NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [4, 8, 16, 32]
@@ -721,7 +721,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
     beta: torch.Tensor | None = None,
     scale: float | None = None,
     cu_seqlens: torch.Tensor | None = None,
-    chunk_size: int = 64,
+    chunk_size: int = FLA_CHUNK_SIZE,
     output_dtype: torch.dtype = torch.float32,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     r"""
@@ -1178,7 +1178,7 @@ def chunk_kda_fwd(
     output_final_state: bool,
     cu_seqlens: torch.Tensor | None = None,
 ):
-    chunk_size = 64
+    chunk_size = FLA_CHUNK_SIZE
     g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
     # the intra Aqk is kept in fp32
     # the computation has very marginal effect on the entire throughput
@@ -1189,6 +1189,7 @@ def chunk_kda_fwd(
         beta=beta,
         scale=scale,
         cu_seqlens=cu_seqlens,
+        chunk_size=chunk_size,
         output_dtype=torch.float32,
     )
     A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
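The kda.py changes replace the literal 64 with FLA_CHUNK_SIZE and pin chunk_kda_scaled_dot_kkt_fwd to the same chunk_size that drove the cumsum. That matters because precomputed chunk tables are only valid for one chunk size; a toy check (assuming the ceil-div chunk counting used by prepare_chunk_indices):

```python
import torch


def num_chunks(cu_seqlens: torch.Tensor, chunk_size: int) -> int:
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    return int(((seqlens + chunk_size - 1) // chunk_size).sum())


cu = torch.tensor([0, 100, 228])
assert num_chunks(cu, 64) == 4   # what the 64-sized tables describe
assert num_chunks(cu, 32) == 8   # a 32-sized kernel would expect 8 entries
```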
diff --git a/vllm/model_executor/layers/fla/ops/solve_tril.py b/vllm/model_executor/layers/fla/ops/solve_tril.py
index da85aab19207..8d3811ca4c17 100644
--- a/vllm/model_executor/layers/fla/ops/solve_tril.py
+++ b/vllm/model_executor/layers/fla/ops/solve_tril.py
@@ -507,6 +507,7 @@ def merge_16x16_to_64x64_inverse_kernel(
 def solve_tril(
     A: torch.Tensor,
     cu_seqlens: torch.Tensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
     output_dtype: torch.dtype = torch.float,
 ) -> torch.Tensor:
     """
@@ -518,6 +519,8 @@ def solve_tril(
             [B, T, H, BT], where BT should only be 16, 32, or 64.
         cu_seqlens (torch.Tensor):
             The cumulative sequence lengths of the input tensor. Default: `None`.
+        chunk_indices (torch.Tensor):
+            Pre-computed chunk indices. Default: `None`.
         output_dtype (torch.dtype):
             The dtype of the output tensor. Default: `torch.float`.
             If `None`, the output dtype will be the same as the input dtype.
@@ -529,9 +532,8 @@ def solve_tril(
     output_dtype = A.dtype if output_dtype is None else output_dtype
     B, T, H, BT = A.shape
-    chunk_indices = (
-        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
-    )
+    if chunk_indices is None and cu_seqlens is not None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
     NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
     Ai = torch.zeros_like(A, dtype=output_dtype)
diff --git a/vllm/model_executor/layers/fla/ops/wy_fast.py b/vllm/model_executor/layers/fla/ops/wy_fast.py
index 6baa08ab4996..52d2b28195a8 100644
--- a/vllm/model_executor/layers/fla/ops/wy_fast.py
+++ b/vllm/model_executor/layers/fla/ops/wy_fast.py
@@ -123,14 +123,14 @@ def recompute_w_u_fwd(
     g_cumsum: torch.Tensor,
     A: torch.Tensor,
     cu_seqlens: torch.Tensor | None,
+    chunk_indices: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     B, T, Hg, K, V = *k.shape, v.shape[-1]
     H = v.shape[-2]
     BT = A.shape[-1]
-    chunk_indices = (
-        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
-    )
+    if chunk_indices is None and cu_seqlens is not None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
     BK = 64
     BV = 64
diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
index 2b952e10e6ed..7b878fb91e7e 100644
--- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py
+++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
@@ -161,6 +161,8 @@ def forward_cuda(
     initial_state: torch.Tensor,
     output_final_state: bool,
     cu_seqlens: torch.Tensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
+    chunk_offsets: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = True,
 ):
     return fi_chunk_gated_delta_rule(
@@ -185,6 +187,8 @@ def forward_native(
     initial_state: torch.Tensor,
     output_final_state: bool,
     cu_seqlens: torch.Tensor | None = None,
+    chunk_indices: torch.Tensor | None = None,
+    chunk_offsets: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = True,
 ):
     return fla_chunk_gated_delta_rule(
@@ -196,6 +200,8 @@ def forward_native(
         initial_state=initial_state,
         output_final_state=output_final_state,
         cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
         use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
     )
@@ -832,6 +838,8 @@ def _forward_core(
         initial_state=initial_state,
         output_final_state=True,
         cu_seqlens=non_spec_query_start_loc,
+        chunk_indices=attn_metadata.chunk_indices,
+        chunk_offsets=attn_metadata.chunk_offsets,
         use_qk_l2norm_in_kernel=True,
     )
     # Init cache
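A note on the async-copy idiom the metadata builder below relies on: `.to(non_blocking=True)` only overlaps with compute when the source tensor lives in pinned host memory; from ordinary pageable memory it silently degrades to a synchronous copy. A minimal sketch (assuming the builder's `*_cpu` buffers are pinned, as vLLM's host-side buffers generally are):

```python
import torch

cpu_buf = torch.tensor([0, 100, 228], dtype=torch.int32).pin_memory()
gpu_copy = cpu_buf.to(device="cuda", non_blocking=True)
# The copy is queued on the current stream, so any kernel launched
# afterwards on the same stream sees the fully transferred data; the
# CPU thread itself never blocks.
```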
diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py
index 41c69deb43a4..308debb31cd0 100644
--- a/vllm/v1/attention/backends/gdn_attn.py
+++ b/vllm/v1/attention/backends/gdn_attn.py
@@ -7,6 +7,11 @@
 import torch
 
 from vllm.config import VllmConfig
+from vllm.model_executor.layers.fla.ops.index import (
+    prepare_chunk_indices,
+    prepare_chunk_offsets,
+)
+from vllm.model_executor.layers.fla.ops.utils import FLA_CHUNK_SIZE
 from vllm.v1.attention.backend import (
     AttentionBackend,
     AttentionCGSupport,
@@ -63,6 +68,10 @@ class GDNAttentionMetadata:
     num_accepted_tokens: torch.Tensor | None = None  # shape: [batch,]
 
+    # Pre-computed FLA chunk metadata (avoids GPU->CPU sync in prepare_chunk_indices)
+    chunk_indices: torch.Tensor | None = None
+    chunk_offsets: torch.Tensor | None = None
+
     # The following attributes are for triton implementation of causal_conv1d
     nums_dict: dict | None = None
     batch_ptr: torch.Tensor | None = None
@@ -305,6 +314,20 @@ def build(  # type: ignore[override]
             assert num_accepted_tokens is not None
             num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
 
+        chunk_indices: torch.Tensor | None = None
+        chunk_offsets: torch.Tensor | None = None
+        if num_prefills > 0:
+            # Only prefill batches use FLA chunk ops.
+            # Pre-compute on CPU and async-copy to GPU to avoid
+            # GPU→CPU sync (.tolist()) in prepare_chunk_indices.
+            gpu_device = query_start_loc.device
+            chunk_indices = prepare_chunk_indices(
+                non_spec_query_start_loc_cpu, FLA_CHUNK_SIZE
+            ).to(device=gpu_device, non_blocking=True)
+            chunk_offsets = prepare_chunk_offsets(
+                non_spec_query_start_loc_cpu, FLA_CHUNK_SIZE
+            ).to(device=gpu_device, non_blocking=True)
+
         if num_prefills > 0:
             has_initial_state = context_lens_tensor > 0
             if spec_sequence_masks is not None:
@@ -405,6 +428,8 @@ def build(  # type: ignore[override]
             num_spec_decode_tokens=num_spec_decode_tokens,
             num_actual_tokens=m.num_actual_tokens,
             has_initial_state=has_initial_state,
+            chunk_indices=chunk_indices,
+            chunk_offsets=chunk_offsets,
            spec_query_start_loc=spec_query_start_loc,
             non_spec_query_start_loc=non_spec_query_start_loc,
             spec_state_indices_tensor=spec_state_indices_tensor,
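For completeness, the assumed semantics of prepare_chunk_offsets, reconstructed from how chunk_gated_delta_rule_fwd_h consumes it (a reference sketch, not the library implementation): a running chunk count with one entry per sequence boundary. Both tables hold only a handful of integers per sequence, so the host-to-device copy in the builder above is cheap.

```python
import torch


def ref_chunk_offsets(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # Cumulative number of chunks before each sequence in the packed batch.
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    nchunks = (seqlens + chunk_size - 1) // chunk_size
    return torch.cat([nchunks.new_zeros(1), nchunks.cumsum(0)])


print(ref_chunk_offsets(torch.tensor([0, 100, 228]), 64))  # tensor([0, 2, 4])
```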