Skip to content
19 changes: 19 additions & 0 deletions vllm/_aiter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,6 +1407,25 @@ def get_aiter_allreduce_max_size(cls) -> int | None:
# https://github.com/ROCm/aiter/blob/6a0e7b26ccf33164785531212cc2ec2cde0b9243/aiter/dist/device_communicators/custom_all_reduce.py#L272-L273
return int(cls._ALL_REDUCE_MAX_SIZE / 2)

@classmethod
@if_aiter_supported
def are_gdn_triton_kernels_available(cls) -> bool:
    """Check if AITER Triton kernels for GDN attention are importable.

    These are optional Triton kernels (conv1d fast-path, gated delta net)
    used by GatedDeltaNetAttention's decode fast-path. They may be absent
    in older aiter builds.

    Returns:
        True when AITER is enabled and both optional modules import
        successfully; False otherwise.
    """
    if not cls._AITER_ENABLED:
        return False
    try:
        # Probe both optional modules; either may be missing in older
        # aiter builds, in which case the caller must fall back to the
        # non-Triton path.
        import aiter.ops.triton.causal_conv1d_update_single_token  # noqa: F401
        import aiter.ops.triton.gated_delta_net  # noqa: F401
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so a single
        # ImportError clause covers both failure modes.
        return False
    return True

@staticmethod
@if_aiter_supported
def register_ops_once() -> None:
Expand Down
11 changes: 11 additions & 0 deletions vllm/model_executor/layers/fla/ops/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def chunk_gated_delta_rule_fwd(
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
core_attn_out: torch.Tensor | None = None,
):
g = chunk_local_cumsum(
g, chunk_size=FLA_CHUNK_SIZE, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices
Expand Down Expand Up @@ -77,6 +78,7 @@ def chunk_gated_delta_rule_fwd(
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
core_attn_out=core_attn_out,
)
if SUPPRESS_LEVEL < 3:
return g, o, A, final_state, None, None, None
Expand All @@ -102,6 +104,7 @@ def forward(
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
core_attn_out: torch.Tensor | None = None,
):
if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q)
Expand All @@ -119,9 +122,15 @@ def forward(
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
core_attn_out=core_attn_out,
)
ctx.scale = scale
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
if core_attn_out is not None:
assert not torch.is_grad_enabled(), (
"core_attn_out buffer reuse is only supported for inference"
)
assert q.dtype == o.dtype, "Incompatible dtype for inplace computation"
return o.to(q.dtype), final_state


Expand All @@ -139,6 +148,7 @@ def chunk_gated_delta_rule(
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
core_attn_out: torch.Tensor | None = None,
):
r"""
Args:
Expand Down Expand Up @@ -230,5 +240,6 @@ def chunk_gated_delta_rule(
chunk_indices,
chunk_offsets,
use_qk_l2norm_in_kernel,
core_attn_out,
)
return o, final_state
9 changes: 8 additions & 1 deletion vllm/model_executor/layers/fla/ops/chunk_o.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def chunk_fwd_o(
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_size: int = FLA_CHUNK_SIZE,
core_attn_out: torch.Tensor | None = None,
) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]
H = v.shape[-2]
Expand All @@ -158,7 +159,13 @@ def chunk_fwd_o(
if scale is None:
scale = k.shape[-1] ** -0.5

o = torch.empty_like(v)
if core_attn_out is not None:
assert core_attn_out.numel() >= v.numel(), (
f"core_attn_out too small: {core_attn_out.numel()} < {v.numel()}"
)
o = core_attn_out[: v.numel()].view(*v.shape)
else:
o = torch.empty_like(v)

def grid(meta):
return (triton.cdiv(V, meta["BV"]), NT, B * H)
Expand Down
Loading
Loading