10 changes: 5 additions & 5 deletions vllm/model_executor/layers/fla/ops/chunk.py
@@ -30,7 +30,7 @@ def chunk_gated_delta_rule_fwd(
     scale: float,
     initial_state: torch.Tensor,
     output_final_state: bool,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
 ):
     g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
     # obtain WY representation. u is actually the new v.
@@ -84,7 +84,7 @@ def forward(
         scale: float,
         initial_state: torch.Tensor,
         output_final_state: bool,
-        cu_seqlens: torch.LongTensor | None = None,
+        cu_seqlens: torch.Tensor | None = None,
         use_qk_l2norm_in_kernel: bool = False,
     ):
         if use_qk_l2norm_in_kernel:
@@ -117,7 +117,7 @@ def chunk_gated_delta_rule(
     scale: float = None,
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = False,
 ):
     r"""
@@ -141,7 +141,7 @@ def chunk_gated_delta_rule(
             Default: `None`.
         output_final_state (Optional[bool]):
             Whether to output the final state of shape `[N, H, V, K]`. Default: `False`.
-        cu_seqlens (torch.LongTensor):
+        cu_seqlens (torch.Tensor):
             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
             consistent with the FlashAttention API.
     Returns:
@@ -171,7 +171,7 @@ def chunk_gated_delta_rule(
         # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
         >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
         # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
-        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
         >>> o_var, ht_var = chunk_gated_delta_rule(
             q, k, v, g, beta,
             initial_state=h0,
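For reviewers trying the relaxed annotation locally, here is a minimal sketch of building an int32 `cu_seqlens` from per-sequence lengths (the `seq_lens` values are illustrative, mirroring the docstring example above; this snippet is not part of the PR):

```python
import torch

# Four sequences of 2048 tokens each, as in the docstring example.
seq_lens = torch.tensor([2048, 2048, 2048, 2048])

# FlashAttention-style cumulative boundaries: [0, 2048, 4096, 6144, 8192].
# An int32 tensor is not a torch.LongTensor, which is why the annotation
# is loosened to torch.Tensor here.
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)]
)
```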
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fla/ops/chunk_delta_h.py
@@ -288,7 +288,7 @@ def chunk_gated_delta_rule_fwd_h(
     output_final_state: bool = False,
     chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
     save_new_value: bool = True,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     # This kernel is slightly different from fla to support Q/K with different head numbers.
     # In fla, Q/K always have the same head number, so Hg is always equal to H.
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fla/ops/chunk_o.py
@@ -145,7 +145,7 @@ def chunk_fwd_o(
     h: torch.Tensor,
     g: torch.Tensor | None = None,  # cumsum of log decay
     scale: float | None = None,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     chunk_size: int = 64,
 ) -> torch.Tensor:
     B, T, Hg, K, V = *q.shape, v.shape[-1]
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
@@ -102,7 +102,7 @@ def chunk_scaled_dot_kkt_fwd(
     k: torch.Tensor,
     g: torch.Tensor | None = None,
     beta: torch.Tensor | None = None,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     chunk_size: int = 64,
     output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
@@ -116,7 +116,7 @@ def chunk_scaled_dot_kkt_fwd(
             The beta tensor of shape `[B, T, H]`.
         g (torch.Tensor):
             The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
-        cu_seqlens (torch.LongTensor):
+        cu_seqlens (torch.Tensor):
             The cumulative sequence lengths of the input tensor.
             Default: None
         chunk_size (int):
10 changes: 5 additions & 5 deletions vllm/model_executor/layers/fla/ops/fused_recurrent.py
@@ -184,7 +184,7 @@ def fused_recurrent_gated_delta_rule_fwd(
     scale: float,
     initial_state: torch.Tensor,
     inplace_final_state: bool = True,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     ssm_state_indices: torch.Tensor | None = None,
     num_accepted_tokens: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = False,
@@ -489,7 +489,7 @@ def forward(
         scale: float,
         initial_state: torch.Tensor,
         inplace_final_state: bool = True,
-        cu_seqlens: torch.LongTensor | None = None,
+        cu_seqlens: torch.Tensor | None = None,
         ssm_state_indices: torch.Tensor | None = None,
         num_accepted_tokens: torch.Tensor | None = None,
         use_qk_l2norm_in_kernel: bool = False,
@@ -521,7 +521,7 @@ def fused_recurrent_gated_delta_rule(
     scale: float = None,
     initial_state: torch.Tensor = None,
     inplace_final_state: bool = True,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     ssm_state_indices: torch.Tensor | None = None,
     num_accepted_tokens: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = False,
@@ -549,7 +549,7 @@ def fused_recurrent_gated_delta_rule(
         inplace_final_state: bool:
             Whether to store the final state in-place to save memory.
             Default: `True`.
-        cu_seqlens (torch.LongTensor):
+        cu_seqlens (torch.Tensor):
             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
             consistent with the FlashAttention API.
         ssm_state_indices (Optional[torch.Tensor]):
@@ -583,7 +583,7 @@ def fused_recurrent_gated_delta_rule(
         # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
         >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
         # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
-        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
         >>> o_var, ht_var = fused_gated_recurrent_delta_rule(
             q, k, v, g, beta,
             initial_state=h0,
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py
@@ -191,7 +191,7 @@ def fused_sigmoid_gating_delta_rule_update(
     scale: float = None,
     initial_state: torch.Tensor = None,
     inplace_final_state: bool = True,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     ssm_state_indices: torch.Tensor | None = None,
     num_accepted_tokens: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = False,
10 changes: 3 additions & 7 deletions vllm/model_executor/layers/fla/ops/index.py
@@ -15,14 +15,12 @@
 
 
 @tensor_cache
-def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
+def prepare_lens(cu_seqlens: torch.Tensor) -> torch.Tensor:
     return cu_seqlens[1:] - cu_seqlens[:-1]
 
 
 @tensor_cache
-def prepare_chunk_indices(
-    cu_seqlens: torch.LongTensor, chunk_size: int
-) -> torch.LongTensor:
+def prepare_chunk_indices(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
     indices = torch.cat(
         [
             torch.arange(n)
@@ -33,9 +31,7 @@ def prepare_chunk_indices(
 
 
 @tensor_cache
-def prepare_chunk_offsets(
-    cu_seqlens: torch.LongTensor, chunk_size: int
-) -> torch.LongTensor:
+def prepare_chunk_offsets(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
     return torch.cat(
         [cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]
     ).cumsum(-1)
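To sanity-check the helpers above, here is a small worked example that mirrors their logic in plain torch (the lengths are illustrative, and `triton.cdiv` is replaced by an equivalent ceiling division so the snippet runs without a GPU):

```python
import torch

# Two sequences of lengths 5 and 12, chunk size 4.
cu_seqlens = torch.tensor([0, 5, 17], dtype=torch.int32)
chunk_size = 4

lens = cu_seqlens[1:] - cu_seqlens[:-1]            # prepare_lens -> [5, 12]
n_chunks = (lens + chunk_size - 1) // chunk_size   # triton.cdiv -> [2, 3]
offsets = torch.cat([cu_seqlens.new_tensor([0]), n_chunks]).cumsum(-1)
print(offsets)   # prepare_chunk_offsets -> [0, 2, 5]

# prepare_chunk_indices enumerates the chunks within each sequence:
indices = torch.cat([torch.arange(n) for n in n_chunks.tolist()])
print(indices)   # -> [0, 1, 0, 1, 2]
```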
16 changes: 8 additions & 8 deletions vllm/model_executor/layers/fla/ops/kda.py
@@ -38,7 +38,7 @@ def fused_recurrent_kda_fwd(
     scale: float,
     initial_state: torch.Tensor,
     inplace_final_state: bool = True,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     ssm_state_indices: torch.Tensor | None = None,
     num_accepted_tokens: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = False,
@@ -116,7 +116,7 @@ def fused_recurrent_kda(
     initial_state: torch.Tensor = None,
     inplace_final_state: bool = True,
     use_qk_l2norm_in_kernel: bool = True,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     ssm_state_indices: torch.LongTensor | None = None,
     **kwargs,
 ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -720,7 +720,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
     gk: torch.Tensor | None = None,
     beta: torch.Tensor | None = None,
     scale: float | None = None,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     chunk_size: int = 64,
     output_dtype: torch.dtype = torch.float32,
 ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -734,7 +734,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
             The beta tensor of shape `[B, T, H]`.
         gk (torch.Tensor):
             The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
-        cu_seqlens (torch.LongTensor):
+        cu_seqlens (torch.Tensor):
             The cumulative sequence lengths of the input tensor.
             Default: None
         chunk_size (int):
@@ -964,7 +964,7 @@ def recompute_w_u_fwd(
     A: torch.Tensor,
     q: torch.Tensor | None = None,
     gk: torch.Tensor | None = None,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     B, T, H, K, V = *k.shape, v.shape[-1]
     BT = A.shape[-1]
@@ -1132,7 +1132,7 @@ def chunk_gla_fwd_o_gk(
     h: torch.Tensor,
     o: torch.Tensor,
     scale: float,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     chunk_size: int = 64,
 ):
     B, T, H, K, V = *q.shape, v.shape[-1]
@@ -1176,7 +1176,7 @@ def chunk_kda_fwd(
     scale: float,
     initial_state: torch.Tensor,
     output_final_state: bool,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
 ):
     chunk_size = 64
     g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
@@ -1236,7 +1236,7 @@ def chunk_kda(
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
     use_qk_l2norm_in_kernel: bool = False,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     **kwargs,
 ):
     if scale is None:
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fla/ops/wy_fast.py
@@ -122,7 +122,7 @@ def recompute_w_u_fwd(
     beta: torch.Tensor,
     g_cumsum: torch.Tensor,
     A: torch.Tensor,
-    cu_seqlens: torch.LongTensor | None,
+    cu_seqlens: torch.Tensor | None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     B, T, Hg, K, V = *k.shape, v.shape[-1]
     H = v.shape[-2]
21 changes: 13 additions & 8 deletions vllm/model_executor/models/qwen3_next.py
@@ -116,7 +116,7 @@ def fi_chunk_gated_delta_rule(
     beta: torch.Tensor,
     initial_state: torch.Tensor,
     output_final_state: bool,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = True,
 ):
     from flashinfer.gdn_prefill import (
@@ -178,7 +178,7 @@ def forward_cuda(
         beta: torch.Tensor,
         initial_state: torch.Tensor,
         output_final_state: bool,
-        cu_seqlens: torch.LongTensor | None = None,
+        cu_seqlens: torch.Tensor | None = None,
         use_qk_l2norm_in_kernel: bool = True,
     ):
         return fi_chunk_gated_delta_rule(
Expand All @@ -202,7 +202,7 @@ def forward_native(
beta: torch.Tensor,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = True,
):
return fla_chunk_gated_delta_rule(
@@ -707,8 +707,13 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
         v = torch.randn(
             1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype
         )
-        g = torch.randn(1, T, num_v_heads, device=device, dtype=dtype)
-        beta = torch.randn(1, T, num_v_heads, device=device, dtype=dtype)
+        # NOTE: g and beta must have the same dtypes as during
+        # inference, so we construct them with the same function
+        # (fused_gdn_gating). dummy_a and dummy_b are throwaway
+        # inputs required by that function.
+        dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype)
+        dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype)
+        g, beta = fused_gdn_gating(self.A_log, dummy_a, dummy_b, self.dt_bias)
         state = torch.zeros(
             1,
             num_v_heads,
@@ -717,7 +722,7 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
             device=device,
             dtype=state_dtype,
         )
-        cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.long)
+        cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32)
 
         try:
             self.chunk_gated_delta_rule(
@@ -727,7 +732,7 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
                 g=g,
                 beta=beta,
                 initial_state=state,
-                output_final_state=False,
+                output_final_state=True,
                 cu_seqlens=cu_seqlens,
                 use_qk_l2norm_in_kernel=True,
             )
@@ -747,7 +752,7 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
                 self.prefix,
             )
         finally:
-            del q, k, v, g, beta, state, cu_seqlens
+            del q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens
 
         torch.accelerator.empty_cache()
 
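A note on why the warmup now mirrors inference more closely: Triton caches one compiled specialization per combination of input dtypes and constexpr flags, so warming up with a `g`/`beta` dtype or an `output_final_state` value that real prefill never uses leaves the variant inference needs uncompiled. A minimal sketch of that caching behavior (the `kernel` helper and the float32 gating dtype are assumptions for illustration, not vLLM code):

```python
import torch

compiled: set[tuple] = set()

def kernel(g: torch.Tensor, output_final_state: bool) -> None:
    # Stand-in for a Triton kernel: each (dtype, flag) pair is a
    # separate specialization compiled on first use.
    key = (g.dtype, output_final_state)
    if key not in compiled:
        compiled.add(key)  # imagine an expensive JIT compile here

# Old warmup: randn in the model dtype with output_final_state=False.
kernel(torch.randn(4, dtype=torch.bfloat16), False)

# Inference produces g via fused_gdn_gating (float32 assumed here) and
# passes output_final_state=True, so it missed the warmed variant.
assert (torch.float32, True) not in compiled
```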