Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions vllm/model_executor/layers/fla/ops/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .cumsum import chunk_local_cumsum
from .l2norm import l2norm_fwd
from .solve_tril import solve_tril
from .utils import SUPPRESS_LEVEL, input_guard
from .utils import FLA_CHUNK_SIZE, SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd


Expand All @@ -30,20 +30,32 @@ def chunk_gated_delta_rule_fwd(
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
g = chunk_local_cumsum(
g, chunk_size=FLA_CHUNK_SIZE, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices
)
# obtain WY representation. u is actually the new v.
A = chunk_scaled_dot_kkt_fwd(
k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
k=k,
beta=beta,
g=g,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
output_dtype=torch.float32,
)
A = solve_tril(
A=A, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, output_dtype=k.dtype
)
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
w, u = recompute_w_u_fwd(
k=k,
v=v,
beta=beta,
A=A,
g_cumsum=g,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
)
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
k=k,
Expand All @@ -53,6 +65,8 @@ def chunk_gated_delta_rule_fwd(
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
)
o = chunk_fwd_o(
q=q,
Expand All @@ -62,6 +76,7 @@ def chunk_gated_delta_rule_fwd(
g=g,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
)
if SUPPRESS_LEVEL < 3:
return g, o, A, final_state, None, None, None
Expand All @@ -84,6 +99,8 @@ def forward(
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
):
if use_qk_l2norm_in_kernel:
Expand All @@ -100,6 +117,8 @@ def forward(
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
)
ctx.scale = scale
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
Expand All @@ -117,6 +136,8 @@ def chunk_gated_delta_rule(
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
):
r"""
Expand Down Expand Up @@ -206,6 +227,8 @@ def chunk_gated_delta_rule(
initial_state,
output_final_state,
cu_seqlens,
chunk_indices,
chunk_offsets,
use_qk_l2norm_in_kernel,
)
return o, final_state
21 changes: 9 additions & 12 deletions vllm/model_executor/layers/fla/ops/chunk_delta_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from .index import prepare_chunk_indices, prepare_chunk_offsets
from .op import exp
from .utils import use_cuda_graph
from .utils import FLA_CHUNK_SIZE, use_cuda_graph

NUM_WARPS = [2, 4, 8, 16]

Expand Down Expand Up @@ -286,30 +286,27 @@ def chunk_gated_delta_rule_fwd_h(
gk: torch.Tensor | None = None,
initial_state: torch.Tensor | None = None,
output_final_state: bool = False,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
chunk_size: int = FLA_CHUNK_SIZE,
save_new_value: bool = True,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# This kernel is slightly different from fla to support Q/K with different head numbers.
# In fla, Q/K always have the same head number, so Hg is always equal to H.
B, T, Hg, K, V = *k.shape, u.shape[-1]
H = u.shape[-2]
BT = chunk_size

chunk_indices = (
prepare_chunk_indices(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
# N: the actual number of sequences in the batch with either equal or variable lengths
if cu_seqlens is None:
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
else:
N, NT, chunk_offsets = (
len(cu_seqlens) - 1,
len(chunk_indices),
prepare_chunk_offsets(cu_seqlens, BT),
)
N, NT = len(cu_seqlens) - 1, len(chunk_indices)
if chunk_offsets is None:
chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)
assert K <= 256, "current kernel does not support head dimension larger than 256."

h = k.new_empty(B, NT, H, V, K)
Expand Down
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/fla/ops/chunk_o.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,14 @@ def chunk_fwd_o(
g: torch.Tensor | None = None, # cumsum of log decay
scale: float | None = None,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_size: int = FLA_CHUNK_SIZE,
) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]
H = v.shape[-2]
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
if scale is None:
scale = k.shape[-1] ** -0.5
Expand Down
12 changes: 8 additions & 4 deletions vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from .index import prepare_chunk_indices
from .op import exp
from .utils import FLA_CHUNK_SIZE


@triton.heuristics(
Expand Down Expand Up @@ -103,7 +104,8 @@ def chunk_scaled_dot_kkt_fwd(
g: torch.Tensor | None = None,
beta: torch.Tensor | None = None,
cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64,
chunk_indices: torch.Tensor | None = None,
chunk_size: int = FLA_CHUNK_SIZE,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
r"""
Expand All @@ -119,6 +121,9 @@ def chunk_scaled_dot_kkt_fwd(
cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_indices (torch.Tensor):
Pre-computed chunk indices. If None and cu_seqlens is provided,
computed internally. Default: None
chunk_size (int):
The chunk size. Default: 64.
output_dtype (torch.dtype):
Expand All @@ -132,9 +137,8 @@ def chunk_scaled_dot_kkt_fwd(
B, T, Hg, K = k.shape
H = beta.shape[-1]
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
Expand Down
35 changes: 23 additions & 12 deletions vllm/model_executor/layers/fla/ops/cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def chunk_local_cumsum_scalar(
chunk_size: int,
reverse: bool = False,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
head_first: bool = False,
output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
Expand All @@ -172,10 +173,9 @@ def chunk_local_cumsum_scalar(
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
"chunk_size must be a power of 2"
)
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
grid = (NT, B * H)
Expand All @@ -199,23 +199,21 @@ def chunk_local_cumsum_vector(
chunk_size: int,
reverse: bool = False,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
head_first: bool = False,
output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
if head_first:
B, H, T, S = g.shape
else:
B, T, H, S = g.shape
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
"chunk_size must be a power of 2"
)
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
BT = chunk_size
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)

Expand Down Expand Up @@ -247,6 +245,7 @@ def chunk_local_cumsum(
chunk_size: int,
reverse: bool = False,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
head_first: bool = False,
output_dtype: torch.dtype | None = torch.float,
**kwargs,
Expand All @@ -257,11 +256,23 @@ def chunk_local_cumsum(
)
if len(g.shape) == 3:
return chunk_local_cumsum_scalar(
g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
g,
chunk_size,
reverse,
cu_seqlens,
chunk_indices,
head_first,
output_dtype,
)
elif len(g.shape) == 4:
return chunk_local_cumsum_vector(
g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
g,
chunk_size,
reverse,
cu_seqlens,
chunk_indices,
head_first,
output_dtype,
)
else:
raise ValueError(
Expand Down
7 changes: 4 additions & 3 deletions vllm/model_executor/layers/fla/ops/kda.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .l2norm import l2norm_fwd
from .op import exp, log
from .solve_tril import solve_tril
from .utils import is_amd
from .utils import FLA_CHUNK_SIZE, is_amd

BT_LIST_AUTOTUNE = [32, 64, 128]
NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [4, 8, 16, 32]
Expand Down Expand Up @@ -721,7 +721,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
beta: torch.Tensor | None = None,
scale: float | None = None,
cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64,
chunk_size: int = FLA_CHUNK_SIZE,
output_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
r"""
Expand Down Expand Up @@ -1178,7 +1178,7 @@ def chunk_kda_fwd(
output_final_state: bool,
cu_seqlens: torch.Tensor | None = None,
):
chunk_size = 64
chunk_size = FLA_CHUNK_SIZE
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
# the intra Aqk is kept in fp32
# the computation has very marginal effect on the entire throughput
Expand All @@ -1189,6 +1189,7 @@ def chunk_kda_fwd(
beta=beta,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
output_dtype=torch.float32,
)
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
Expand Down
8 changes: 5 additions & 3 deletions vllm/model_executor/layers/fla/ops/solve_tril.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,7 @@ def merge_16x16_to_64x64_inverse_kernel(
def solve_tril(
A: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float,
) -> torch.Tensor:
"""
Expand All @@ -518,6 +519,8 @@ def solve_tril(
[B, T, H, BT], where BT should only be 16, 32, or 64.
cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor. Default: `None`.
chunk_indices (torch.Tensor):
Pre-computed chunk indices. Default: `None`.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float`.
If `None`, the output dtype will be the same as the input dtype.
Expand All @@ -529,9 +532,8 @@ def solve_tril(
output_dtype = A.dtype if output_dtype is None else output_dtype

B, T, H, BT = A.shape
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)

Ai = torch.zeros_like(A, dtype=output_dtype)
Expand Down
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/fla/ops/wy_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,14 @@ def recompute_w_u_fwd(
g_cumsum: torch.Tensor,
A: torch.Tensor,
cu_seqlens: torch.Tensor | None,
chunk_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
BT = A.shape[-1]

chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 64
BV = 64
Expand Down
Loading
Loading