diff --git a/benchmarks/bench_gdn_decode.py b/benchmarks/bench_gdn_decode.py index 5ecafac109..5f4fb16ac9 100644 --- a/benchmarks/bench_gdn_decode.py +++ b/benchmarks/bench_gdn_decode.py @@ -56,7 +56,7 @@ from flashinfer.gdn_decode import ( gated_delta_rule_decode_pretranspose, - gated_delta_rule_decode, + gated_delta_rule_decode_kv, gated_delta_rule_mtp, ) from flashinfer.testing import bench_gpu_time @@ -1111,7 +1111,7 @@ def bench_gdn_decode( if version == "pretranspose": decode_func = gated_delta_rule_decode_pretranspose elif version == "nontranspose": - decode_func = gated_delta_rule_decode + decode_func = gated_delta_rule_decode_kv else: raise ValueError(f"Unknown version: {version}") @@ -1334,7 +1334,7 @@ def bench_comparison( ) flashinfer_times = bench_gpu_time( - lambda: gated_delta_rule_decode( + lambda: gated_delta_rule_decode_kv( q, k, v, state_fi, A_log, a, dt_bias, b, scale, output_fi, use_qk_l2norm ), enable_cupti=True, @@ -1709,7 +1709,7 @@ def verify_correctness( output_fi = torch.empty( batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" ) - gated_delta_rule_decode( + gated_delta_rule_decode_kv( q, k, v, state_fi, A_log, a, dt_bias, b, scale, output_fi, use_qk_l2norm ) @@ -1983,7 +1983,7 @@ def bench_all_layouts( try: times = bench_gpu_time( - lambda: gated_delta_rule_decode( + lambda: gated_delta_rule_decode_kv( q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm ), enable_cupti=True, diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index 079dcb0c23..15fb916344 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -96,6 +96,9 @@ cute_dsl_fused_moe_nvfp4 as cute_dsl_fused_moe_nvfp4, CuteDslMoEWrapper as CuteDslMoEWrapper, ) +from .gdn_decode import ( + gated_delta_rule_decode as gated_delta_rule_decode, +) from .gdn_prefill import chunk_gated_delta_rule as chunk_gated_delta_rule from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper from .gemm import bmm_bf16 as bmm_bf16 diff --git a/flashinfer/gdn_decode.py b/flashinfer/gdn_decode.py index 0d3410548c..0ccc7657d6 100644 --- a/flashinfer/gdn_decode.py +++ b/flashinfer/gdn_decode.py @@ -27,6 +27,8 @@ - gated_delta_rule_mtp: Multi-token processing (T > 1) for speculative decoding """ +import functools +import warnings from typing import Optional, Tuple import torch @@ -107,7 +109,7 @@ def flashinfer_api(func): # type: ignore[misc] @flashinfer_api -def gated_delta_rule_decode_pretranspose( +def _gated_delta_rule_decode_pretranspose_impl( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -360,7 +362,7 @@ def gated_delta_rule_decode_pretranspose( @flashinfer_api -def gated_delta_rule_decode( +def _gated_delta_rule_decode_kv_impl( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -501,7 +503,7 @@ def gated_delta_rule_decode( @flashinfer_api -def gated_delta_rule_mtp( +def _gated_delta_rule_mtp_impl( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -687,3 +689,391 @@ def gated_delta_rule_mtp( output = output.to(target_dtype) return output, initial_state + + +# ============================================================================ +# Unified GDN Decode API (RFC 5.7) +# ============================================================================ + + +def _check_state_indices_bounds(state_indices: torch.Tensor, pool_size: int) -> None: + """Validate that all state_indices are in [0, pool_size). Raises ValueError if not.""" + if state_indices.numel() == 0: + return + bad = (state_indices < 0) | (state_indices >= pool_size) + if bad.any().item(): + first_bad = state_indices[bad].flatten()[0].item() + raise ValueError( + f"state_indices must be in [0, pool_size={pool_size}); got out-of-range value {first_bad}" + ) + + +@flashinfer_api +def gated_delta_rule_decode( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + state: torch.Tensor, + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + b: torch.Tensor, + state_layout: str = "VK", + state_indices: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + output: Optional[torch.Tensor] = None, + intermediate_states_buffer: Optional[torch.Tensor] = None, + disable_state_update: bool = False, + use_qk_l2norm: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Unified Gated Delta Rule Decode API. + + Single entry point for decode (T=1) and MTP (T>1). Dispatches to the VK/KV/MTP + backends based on state_layout, state dtype, and T. + + Args: + q (torch.Tensor): + Query of shape ``[B, T, H, K]``. Must be float16/bfloat16. + k (torch.Tensor): + Key of shape ``[B, T, H, K]``. Must be float16/bfloat16. + v (torch.Tensor): + Value of shape ``[B, T, HV, V]``. Must be float16/bfloat16. + state (torch.Tensor): + State ``[B_or_pool, HV, V, K]`` if state_layout="VK", else ``[B_or_pool, HV, K, V]``. + A_log (torch.Tensor): + Log decay of shape ``[HV]``. Must be float32. + a (torch.Tensor): + Input-dependent decay of shape ``[B, T, HV]``. + dt_bias (torch.Tensor): + Decay bias of shape ``[HV]``. Must be float32. + b (torch.Tensor): + Update gate of shape ``[B, T, HV]``. + state_layout (str): + "VK" (K-last) or "KV" (K-major). Default "VK". + state_indices (Optional[torch.Tensor]): + Optional ``[B]`` int32/int64; when set, state is a pool and indices map batch + to slot. All values must be in ``[0, pool_size)``; negative values (padding) + are not supported and will raise ValueError. + scale (Optional[float]): + Scale for queries; None => 1/sqrt(K). + output (Optional[torch.Tensor]): + Pre-allocated output ``[B, T, HV, V]`` or None. + intermediate_states_buffer (Optional[torch.Tensor]): + Optional ``[pool, T_cache, HV, V, K]`` for MTP rollback. + disable_state_update (bool): + If True, state is not updated (read-only). Only affects MTP. Default False. + use_qk_l2norm (bool): + Whether to L2-normalize q and k. Default True. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - output: Output tensor of shape ``[B, T, HV, V]`` + - state: Updated state (same tensor as input, modified in-place) + + Dispatch: + - state_layout="VK", state bf16, T in {1..4}, K=V=128 -> pretranspose bf16 (pool optional). + - state_layout="VK", state fp32, T=1 -> pretranspose fp32 (no pool). + - state_layout="VK", state fp32, T>1 -> MTP (pool required via state_indices). + - state_layout="KV", state fp32, T=1 -> KV decode (no pool). + - Other combinations raise with a clear error. + + Note: + - Requires SM90+ (Hopper, Blackwell, etc.). All backends are JIT-compiled and tested on SM90/100/110/120. + - State is updated in-place; with pool (state_indices), updates write into the state tensor. + """ + B, T, H, K = q.shape + _, _, HV, V = v.shape + + if state_layout not in ("VK", "KV"): + raise ValueError(f"state_layout must be 'VK' or 'KV', got {state_layout!r}") + + use_pool = state_indices is not None + if state_layout == "KV": + if use_pool: + raise NotImplementedError( + "state_indices (pool) is not supported for state_layout='KV' yet" + ) + if T != 1: + raise ValueError(f"state_layout='KV' only supports T=1, got T={T}") + if state.dtype != torch.float32: + raise ValueError( + f"state_layout='KV' requires float32 state, got {state.dtype}" + ) + # KV decode: state [B, HV, K, V] + if state.shape != (B, HV, K, V): + raise ValueError( + f"Expected state shape [B={B}, HV={HV}, K={K}, V={V}] for KV layout, got {state.shape}" + ) + return _gated_delta_rule_decode_kv_impl( + q=q, + k=k, + v=v, + state=state, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=scale, + output=output, + use_qk_l2norm=use_qk_l2norm, + ) + + # state_layout == "VK" + if state.dtype == torch.bfloat16: + if T not in (1, 2, 3, 4) or K != 128 or V != 128: + raise ValueError( + f"VK bf16 path requires T in {{1,2,3,4}} and K=V=128, got T={T}, K={K}, V={V}" + ) + if use_pool: + if disable_state_update or intermediate_states_buffer is not None: + raise NotImplementedError( + "VK bf16 path with state_indices (pool) does not support " + "disable_state_update or intermediate_states_buffer; use fp32 state for MTP." + ) + pool_size = state.shape[0] + if state.shape != (pool_size, HV, V, K): + raise ValueError( + f"Expected state [pool_size, HV, V, K], got {state.shape}" + ) + if state_indices.shape != (B,): + raise ValueError( + f"state_indices must be [B={B}], got {state_indices.shape}" + ) + _check_state_indices_bounds(state_indices, pool_size) + return _gated_delta_rule_decode_pretranspose_impl( + q=q, + k=k, + v=v, + state=None, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=scale, + output=output, + use_qk_l2norm=use_qk_l2norm, + initial_state=state, + initial_state_indices=state_indices, + ) + else: + if state.shape != (B, HV, V, K): + raise ValueError( + f"Expected state [B={B}, HV={HV}, V={V}, K={K}] for VK, got {state.shape}" + ) + return _gated_delta_rule_decode_pretranspose_impl( + q=q, + k=k, + v=v, + state=state, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=scale, + output=output, + use_qk_l2norm=use_qk_l2norm, + ) + + # state_layout == "VK", float32 + if state.dtype != torch.float32: + raise ValueError( + f"VK layout supports bfloat16 or float32 state, got {state.dtype}" + ) + if T == 1: + if use_pool: + raise NotImplementedError( + "VK fp32 T=1 with state_indices (pool) is not implemented yet" + ) + if state.shape != (B, HV, V, K): + raise ValueError( + f"Expected state [B={B}, HV={HV}, V={V}, K={K}] for VK, got {state.shape}" + ) + return _gated_delta_rule_decode_pretranspose_impl( + q=q, + k=k, + v=v, + state=state, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=scale, + output=output, + use_qk_l2norm=use_qk_l2norm, + ) + # T > 1: MTP path; requires pool + if not use_pool: + raise ValueError( + "VK fp32 MTP (T>1) requires state_indices and state as pool [pool_size, HV, V, K]" + ) + pool_size = state.shape[0] + if state.shape != (pool_size, HV, V, K): + raise ValueError( + f"Expected state [pool_size, HV, V, K] for VK MTP, got {state.shape}" + ) + if state_indices.shape != (B,): + raise ValueError(f"state_indices must be [B={B}], got {state_indices.shape}") + _check_state_indices_bounds(state_indices, pool_size) + return _gated_delta_rule_mtp_impl( + q=q, + k=k, + v=v, + initial_state=state, + initial_state_indices=state_indices, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=scale, + output=output, + intermediate_states_buffer=intermediate_states_buffer, + disable_state_update=disable_state_update, + use_qk_l2norm=use_qk_l2norm, + ) + + +# ============================================================================ +# Deprecation shims: legacy names delegate to unified API (RFC 5.7) +# ============================================================================ + + +@flashinfer_api +def gated_delta_rule_decode_pretranspose( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + state: Optional[torch.Tensor], + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + b: torch.Tensor, + scale: Optional[float] = None, + output: Optional[torch.Tensor] = None, + use_qk_l2norm: bool = True, + initial_state: Optional[torch.Tensor] = None, + initial_state_indices: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Deprecated: use gated_delta_rule_decode(..., state_layout=\"VK\") instead.""" + warnings.warn( + "gated_delta_rule_decode_pretranspose is deprecated and will be removed in a future " + "version. Use gated_delta_rule_decode(..., state_layout='VK') instead.", + DeprecationWarning, + stacklevel=2, + ) + use_pool = initial_state is not None + if use_pool != (initial_state_indices is not None): + raise ValueError( + "initial_state and initial_state_indices must be provided together" + ) + if state is None and initial_state is None: + raise ValueError("Either state or initial_state must be provided") + if initial_state is not None: + return gated_delta_rule_decode( + q=q, + k=k, + v=v, + state=initial_state, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + state_layout="VK", + state_indices=initial_state_indices, + scale=scale, + output=output, + use_qk_l2norm=use_qk_l2norm, + ) + return gated_delta_rule_decode( + q=q, + k=k, + v=v, + state=state, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + state_layout="VK", + scale=scale, + output=output, + use_qk_l2norm=use_qk_l2norm, + ) + + +@flashinfer_api +def gated_delta_rule_decode_kv( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + state: torch.Tensor, + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + b: torch.Tensor, + scale: Optional[float] = None, + output: Optional[torch.Tensor] = None, + use_qk_l2norm: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Deprecated: use gated_delta_rule_decode(..., state_layout=\"KV\") instead.""" + warnings.warn( + "gated_delta_rule_decode_kv is deprecated and will be removed in a future " + "version. Use gated_delta_rule_decode(..., state_layout='KV') instead.", + DeprecationWarning, + stacklevel=2, + ) + return gated_delta_rule_decode( + q=q, + k=k, + v=v, + state=state, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + state_layout="KV", + scale=scale, + output=output, + use_qk_l2norm=use_qk_l2norm, + ) + + +@flashinfer_api +def gated_delta_rule_mtp( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + initial_state: torch.Tensor, + initial_state_indices: torch.Tensor, + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + b: torch.Tensor, + scale: Optional[float] = None, + output: Optional[torch.Tensor] = None, + intermediate_states_buffer: Optional[torch.Tensor] = None, + disable_state_update: bool = True, + use_qk_l2norm: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Deprecated: use gated_delta_rule_decode(..., state_layout=\"VK\", state_indices=...) instead.""" + warnings.warn( + "gated_delta_rule_mtp is deprecated and will be removed in a future version. " + "Use gated_delta_rule_decode(..., state_layout='VK', state_indices=...) instead.", + DeprecationWarning, + stacklevel=2, + ) + return gated_delta_rule_decode( + q=q, + k=k, + v=v, + state=initial_state, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + state_layout="VK", + state_indices=initial_state_indices, + scale=scale, + output=output, + intermediate_states_buffer=intermediate_states_buffer, + disable_state_update=disable_state_update, + use_qk_l2norm=use_qk_l2norm, + ) diff --git a/tests/gdn/test_decode_delta_rule.py b/tests/gdn/test_decode_delta_rule.py index 1b43a0ddfe..545e9429fc 100644 --- a/tests/gdn/test_decode_delta_rule.py +++ b/tests/gdn/test_decode_delta_rule.py @@ -37,7 +37,7 @@ # Import the actual decode functions from flashinfer.gdn_decode import ( gated_delta_rule_decode_pretranspose, - gated_delta_rule_decode, + gated_delta_rule_decode_kv, gated_delta_rule_mtp, ) from flashinfer.utils import get_compute_capability @@ -371,7 +371,7 @@ def _test_decode_kernel_nontranspose( # Call kernel (nontranspose version uses K-major layout directly, no transpose needed) our_state = input_state.clone() - our_o, our_state = gated_delta_rule_decode( + our_o, our_state = gated_delta_rule_decode_kv( q=q, k=k, v=v, diff --git a/tests/gdn/test_gdn_decode.py b/tests/gdn/test_gdn_decode.py new file mode 100644 index 0000000000..299e5a9949 --- /dev/null +++ b/tests/gdn/test_gdn_decode.py @@ -0,0 +1,626 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +"""Tests for the GDN decode API (gated_delta_rule_decode).""" + +import random + +import torch +import pytest + +from flashinfer.gdn_decode import ( + gated_delta_rule_decode, + gated_delta_rule_decode_pretranspose, + _gated_delta_rule_decode_pretranspose_impl, + _gated_delta_rule_decode_kv_impl, + _gated_delta_rule_mtp_impl, +) +from flashinfer.utils import get_compute_capability + + +def _skip_if_not_sm90_or_later(): + cc = get_compute_capability(torch.device("cuda")) + if cc[0] not in [9, 10, 11, 12]: + pytest.skip(f"GDN requires SM90+, got SM{cc[0]}{cc[1]}") + + +def _make_qkv_state_and_params( + B: int, + T: int, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + state_dtype: torch.dtype, + q_dtype: torch.dtype, + seed: int | None = None, +): + """Create q, k, v, state (VK layout), A_log, a, dt_bias, b on CUDA.""" + if seed is not None: + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + device = torch.device("cuda") + q = torch.randn(B, T, num_q_heads, head_size, dtype=q_dtype, device=device) * 0.1 + k = torch.randn(B, T, num_k_heads, head_size, dtype=q_dtype, device=device) * 0.1 + k = torch.nn.functional.normalize(k, p=2.0, dim=-1) + v = torch.randn(B, T, num_v_heads, head_size, dtype=q_dtype, device=device) * 0.1 + # VK state: [B, HV, V, K] + state_vk = ( + torch.randn( + B, num_v_heads, head_size, head_size, dtype=state_dtype, device=device + ) + * 0.1 + ) + A_log = torch.randn(num_v_heads, dtype=torch.float32, device=device) * 0.1 + dt_bias = torch.randn(num_v_heads, dtype=torch.float32, device=device) * 0.1 + a = torch.randn(B, T, num_v_heads, dtype=q_dtype, device=device) * 0.1 + b = torch.randn(B, T, num_v_heads, dtype=q_dtype, device=device) * 0.1 + return q, k, v, state_vk, A_log, a, dt_bias, b + + +# ----------------------------------------------------------------------------- +# gated_delta_rule_decode vs legacy impl: same inputs, compare outputs +# ----------------------------------------------------------------------------- + + +@pytest.mark.parametrize("batch_size", [2, 8]) +@pytest.mark.parametrize("state_dtype", [torch.bfloat16]) +def test_gated_delta_rule_decode_vk_bf16_t1_no_pool_matches_pretranspose( + batch_size: int, state_dtype: torch.dtype +): + """gated_delta_rule_decode (VK, bf16, T=1, no pool) should match pretranspose impl.""" + _skip_if_not_sm90_or_later() + B, T = batch_size, 1 + num_q, num_k, num_v, head_size = 16, 16, 32, 128 + seed = 42 + q, k, v, state_vk, A_log, a, dt_bias, b = _make_qkv_state_and_params( + B, T, num_q, num_k, num_v, head_size, state_dtype, torch.bfloat16, seed=seed + ) + state_legacy = state_vk.clone() + state_unified = state_vk.clone() + + out_legacy, _ = _gated_delta_rule_decode_pretranspose_impl( + q=q, + k=k, + v=v, + state=state_legacy, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=None, + use_qk_l2norm=True, + ) + out_unified, _ = gated_delta_rule_decode( + q=q, + k=k, + v=v, + state=state_unified, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + state_layout="VK", + scale=None, + use_qk_l2norm=True, + ) + + torch.testing.assert_close(out_unified, out_legacy, atol=1e-4, rtol=1e-3) + torch.testing.assert_close(state_unified, state_legacy, atol=1e-4, rtol=1e-3) + + +@pytest.mark.parametrize("batch_size", [2, 8]) +def test_gated_delta_rule_decode_vk_fp32_t1_no_pool_matches_pretranspose(batch_size: int): + """gated_delta_rule_decode (VK, fp32, T=1, no pool) should match pretranspose impl.""" + _skip_if_not_sm90_or_later() + B, T = batch_size, 1 + num_q, num_k, num_v, head_size = 16, 16, 32, 128 + seed = 43 + q, k, v, state_vk, A_log, a, dt_bias, b = _make_qkv_state_and_params( + B, T, num_q, num_k, num_v, head_size, torch.float32, torch.bfloat16, seed=seed + ) + state_legacy = state_vk.clone() + state_unified = state_vk.clone() + + out_legacy, _ = _gated_delta_rule_decode_pretranspose_impl( + q=q, + k=k, + v=v, + state=state_legacy, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=None, + use_qk_l2norm=True, + ) + out_unified, _ = gated_delta_rule_decode( + q=q, + k=k, + v=v, + state=state_unified, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + state_layout="VK", + scale=None, + use_qk_l2norm=True, + ) + + torch.testing.assert_close(out_unified, out_legacy, atol=5e-3, rtol=5e-3) + torch.testing.assert_close(state_unified, state_legacy, atol=5e-3, rtol=5e-3) + + +@pytest.mark.parametrize("batch_size", [2, 8]) +def test_gated_delta_rule_decode_kv_fp32_t1_matches_decode_kv(batch_size: int): + """gated_delta_rule_decode (KV, fp32, T=1) should match KV impl.""" + _skip_if_not_sm90_or_later() + B, T = batch_size, 1 + num_q, num_k, num_v, head_size = 16, 16, 32, 128 + seed = 44 + q, k, v, state_vk, A_log, a, dt_bias, b = _make_qkv_state_and_params( + B, T, num_q, num_k, num_v, head_size, torch.float32, torch.bfloat16, seed=seed + ) + # KV layout: [B, HV, K, V] + state_kv = state_vk.permute(0, 1, 3, 2).contiguous() + state_legacy = state_kv.clone() + state_unified = state_kv.clone() + + out_legacy, _ = _gated_delta_rule_decode_kv_impl( + q=q, + k=k, + v=v, + state=state_legacy, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=None, + use_qk_l2norm=True, + ) + out_unified, _ = gated_delta_rule_decode( + q=q, + k=k, + v=v, + state=state_unified, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + state_layout="KV", + scale=None, + use_qk_l2norm=True, + ) + + torch.testing.assert_close(out_unified, out_legacy, atol=5e-3, rtol=5e-3) + torch.testing.assert_close(state_unified, state_legacy, atol=5e-3, rtol=5e-3) + + +@pytest.mark.parametrize("batch_size", [2, 4]) +@pytest.mark.parametrize("T", [2, 3]) +def test_gated_delta_rule_decode_vk_fp32_mtp_matches_mtp(batch_size: int, T: int): + """gated_delta_rule_decode (VK, fp32, T>1, pool) should match MTP impl.""" + _skip_if_not_sm90_or_later() + B = batch_size + num_q, num_k, num_v, head_size = 16, 16, 32, 128 + pool_size = B + 2 + seed = 45 + device = torch.device("cuda") + if seed is not None: + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + q = torch.randn(B, T, num_q, head_size, dtype=torch.bfloat16, device=device) * 0.1 + k = torch.randn(B, T, num_k, head_size, dtype=torch.bfloat16, device=device) * 0.1 + k = torch.nn.functional.normalize(k, p=2.0, dim=-1) + v = torch.randn(B, T, num_v, head_size, dtype=torch.bfloat16, device=device) * 0.1 + state_pool = ( + torch.randn( + pool_size, num_v, head_size, head_size, dtype=torch.float32, device=device + ) + * 0.1 + ) + state_indices = torch.arange(B, dtype=torch.int32, device=device) + A_log = torch.randn(num_v, dtype=torch.float32, device=device) * 0.1 + dt_bias = torch.randn(num_v, dtype=torch.float32, device=device) * 0.1 + a = torch.randn(B, T, num_v, dtype=torch.bfloat16, device=device) * 0.1 + b = torch.randn(B, T, num_v, dtype=torch.bfloat16, device=device) * 0.1 + + pool_legacy = state_pool.clone() + pool_unified = state_pool.clone() + + out_legacy, _ = _gated_delta_rule_mtp_impl( + q=q, + k=k, + v=v, + initial_state=pool_legacy, + initial_state_indices=state_indices, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=None, + disable_state_update=False, + use_qk_l2norm=True, + ) + out_unified, _ = gated_delta_rule_decode( + q=q, + k=k, + v=v, + state=pool_unified, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + state_layout="VK", + state_indices=state_indices, + scale=None, + disable_state_update=False, + use_qk_l2norm=True, + ) + + torch.testing.assert_close(out_unified, out_legacy, atol=5e-3, rtol=5e-3) + torch.testing.assert_close(pool_unified, pool_legacy, atol=5e-3, rtol=5e-3) + + +def test_gated_delta_rule_decode_vk_fp32_mtp_with_intermediate_buffer_matches_mtp(): + """gated_delta_rule_decode with intermediate_states_buffer should match MTP impl.""" + _skip_if_not_sm90_or_later() + B, T, num_q, num_k, num_v, head_size = 4, 2, 16, 16, 32, 128 + pool_size = B + 2 + cache_steps = T + seed = 46 + device = torch.device("cuda") + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + q = torch.randn(B, T, num_q, head_size, dtype=torch.bfloat16, device=device) * 0.1 + k = torch.randn(B, T, num_k, head_size, dtype=torch.bfloat16, device=device) * 0.1 + k = torch.nn.functional.normalize(k, p=2.0, dim=-1) + v = torch.randn(B, T, num_v, head_size, dtype=torch.bfloat16, device=device) * 0.1 + state_pool = ( + torch.randn( + pool_size, num_v, head_size, head_size, dtype=torch.float32, device=device + ) + * 0.1 + ) + state_indices = torch.arange(B, dtype=torch.int32, device=device) + A_log = torch.randn(num_v, dtype=torch.float32, device=device) * 0.1 + dt_bias = torch.randn(num_v, dtype=torch.float32, device=device) * 0.1 + a = torch.randn(B, T, num_v, dtype=torch.bfloat16, device=device) * 0.1 + b = torch.randn(B, T, num_v, dtype=torch.bfloat16, device=device) * 0.1 + intermed_buf = torch.zeros( + pool_size, + cache_steps, + num_v, + head_size, + head_size, + dtype=torch.float32, + device=device, + ) + + pool_legacy = state_pool.clone() + pool_unified = state_pool.clone() + intermed_legacy = intermed_buf.clone() + intermed_unified = intermed_buf.clone() + + out_legacy, _ = _gated_delta_rule_mtp_impl( + q=q, + k=k, + v=v, + initial_state=pool_legacy, + initial_state_indices=state_indices, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=None, + disable_state_update=False, + use_qk_l2norm=True, + intermediate_states_buffer=intermed_legacy, + ) + out_unified, _ = gated_delta_rule_decode( + q=q, + k=k, + v=v, + state=pool_unified, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + state_layout="VK", + state_indices=state_indices, + scale=None, + disable_state_update=False, + use_qk_l2norm=True, + intermediate_states_buffer=intermed_unified, + ) + + torch.testing.assert_close(out_unified, out_legacy, atol=5e-3, rtol=5e-3) + torch.testing.assert_close(pool_unified, pool_legacy, atol=5e-3, rtol=5e-3) + torch.testing.assert_close( + intermed_unified, intermed_legacy, atol=5e-3, rtol=5e-3 + ) + + +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("pool_size", [1, 4]) +def test_gated_delta_rule_decode_vk_fp32_mtp_edge_pool_size_and_b1(batch_size: int, pool_size: int): + """Edge cases: B=1 and/or pool_size=1 (MTP path).""" + _skip_if_not_sm90_or_later() + B, T = batch_size, 2 + num_q, num_k, num_v, head_size = 16, 16, 32, 128 + seed = 47 + device = torch.device("cuda") + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + q = torch.randn(B, T, num_q, head_size, dtype=torch.bfloat16, device=device) * 0.1 + k = torch.randn(B, T, num_k, head_size, dtype=torch.bfloat16, device=device) * 0.1 + k = torch.nn.functional.normalize(k, p=2.0, dim=-1) + v = torch.randn(B, T, num_v, head_size, dtype=torch.bfloat16, device=device) * 0.1 + state_pool = ( + torch.randn( + pool_size, num_v, head_size, head_size, dtype=torch.float32, device=device + ) + * 0.1 + ) + state_indices = torch.arange(B, dtype=torch.int32, device=device) + A_log = torch.randn(num_v, dtype=torch.float32, device=device) * 0.1 + dt_bias = torch.randn(num_v, dtype=torch.float32, device=device) * 0.1 + a = torch.randn(B, T, num_v, dtype=torch.bfloat16, device=device) * 0.1 + b = torch.randn(B, T, num_v, dtype=torch.bfloat16, device=device) * 0.1 + + pool_legacy = state_pool.clone() + pool_unified = state_pool.clone() + + out_legacy, _ = _gated_delta_rule_mtp_impl( + q=q, + k=k, + v=v, + initial_state=pool_legacy, + initial_state_indices=state_indices, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=None, + disable_state_update=False, + use_qk_l2norm=True, + ) + out_unified, _ = gated_delta_rule_decode( + q=q, + k=k, + v=v, + state=pool_unified, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + state_layout="VK", + state_indices=state_indices, + scale=None, + disable_state_update=False, + use_qk_l2norm=True, + ) + + torch.testing.assert_close(out_unified, out_legacy, atol=5e-3, rtol=5e-3) + torch.testing.assert_close(pool_unified, pool_legacy, atol=5e-3, rtol=5e-3) + + +# ----------------------------------------------------------------------------- +# Error handling: unsupported combinations +# ----------------------------------------------------------------------------- + + +def test_gated_delta_rule_decode_invalid_state_layout_raises(): + _skip_if_not_sm90_or_later() + B, T, H, HV, K, V = 2, 1, 16, 32, 128, 128 + device = torch.device("cuda") + q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=device) + k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=device) + v = torch.randn(B, T, HV, V, dtype=torch.bfloat16, device=device) + state = torch.randn(B, HV, V, K, dtype=torch.bfloat16, device=device) + A_log = torch.randn(HV, dtype=torch.float32, device=device) + a = torch.randn(B, T, HV, dtype=torch.bfloat16, device=device) + dt_bias = torch.randn(HV, dtype=torch.float32, device=device) + b = torch.randn(B, T, HV, dtype=torch.bfloat16, device=device) + + with pytest.raises(ValueError, match="state_layout must be"): + gated_delta_rule_decode( + q, k, v, state, A_log, a, dt_bias, b, state_layout="invalid" + ) + + +def test_gated_delta_rule_decode_kv_with_state_indices_raises(): + _skip_if_not_sm90_or_later() + B, T, H, HV, K, V = 2, 1, 16, 32, 128, 128 + device = torch.device("cuda") + q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=device) + k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=device) + v = torch.randn(B, T, HV, V, dtype=torch.bfloat16, device=device) + state = torch.randn(B, HV, K, V, dtype=torch.float32, device=device) + state_indices = torch.arange(B, dtype=torch.int32, device=device) + A_log = torch.randn(HV, dtype=torch.float32, device=device) + a = torch.randn(B, T, HV, dtype=torch.bfloat16, device=device) + dt_bias = torch.randn(HV, dtype=torch.float32, device=device) + b = torch.randn(B, T, HV, dtype=torch.bfloat16, device=device) + + with pytest.raises(NotImplementedError, match="state_indices.*KV"): + gated_delta_rule_decode( + q, + k, + v, + state, + A_log, + a, + dt_bias, + b, + state_layout="KV", + state_indices=state_indices, + ) + + +def test_gated_delta_rule_decode_state_indices_out_of_bounds_raises(): + """state_indices must be in [0, pool_size); out-of-range raises ValueError.""" + _skip_if_not_sm90_or_later() + B, T, H, HV, K, V = 2, 1, 16, 32, 128, 128 + pool_size = 4 + device = torch.device("cuda") + q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=device) + k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=device) + v = torch.randn(B, T, HV, V, dtype=torch.bfloat16, device=device) + state = torch.randn(pool_size, HV, V, K, dtype=torch.bfloat16, device=device) + A_log = torch.randn(HV, dtype=torch.float32, device=device) + a = torch.randn(B, T, HV, dtype=torch.bfloat16, device=device) + dt_bias = torch.randn(HV, dtype=torch.float32, device=device) + b = torch.randn(B, T, HV, dtype=torch.bfloat16, device=device) + + # index >= pool_size + state_indices = torch.tensor([0, pool_size], dtype=torch.int32, device=device) + with pytest.raises(ValueError, match="state_indices must be in \\[0, pool_size"): + gated_delta_rule_decode( + q, k, v, state, A_log, a, dt_bias, b, + state_layout="VK", state_indices=state_indices, + ) + + # negative index + state_indices_neg = torch.tensor([0, -1], dtype=torch.int32, device=device) + with pytest.raises(ValueError, match="state_indices must be in \\[0, pool_size"): + gated_delta_rule_decode( + q, k, v, state, A_log, a, dt_bias, b, + state_layout="VK", state_indices=state_indices_neg, + ) + + +def test_gated_delta_rule_decode_vk_fp32_t_gt_1_without_pool_raises(): + _skip_if_not_sm90_or_later() + B, T, H, HV, K, V = 2, 3, 16, 32, 128, 128 + device = torch.device("cuda") + q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=device) + k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=device) + v = torch.randn(B, T, HV, V, dtype=torch.bfloat16, device=device) + # state as per-batch (no pool) but T>1 -> should require pool + state = torch.randn(B, HV, V, K, dtype=torch.float32, device=device) + A_log = torch.randn(HV, dtype=torch.float32, device=device) + a = torch.randn(B, T, HV, dtype=torch.bfloat16, device=device) + dt_bias = torch.randn(HV, dtype=torch.float32, device=device) + b = torch.randn(B, T, HV, dtype=torch.bfloat16, device=device) + + with pytest.raises(ValueError, match="state_indices and state as pool"): + gated_delta_rule_decode( + q, + k, + v, + state, + A_log, + a, + dt_bias, + b, + state_layout="VK", + ) + + +def test_gated_delta_rule_decode_kv_t_gt_1_raises(): + _skip_if_not_sm90_or_later() + B, T, H, HV, K, V = 2, 2, 16, 32, 128, 128 + device = torch.device("cuda") + q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=device) + k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=device) + v = torch.randn(B, T, HV, V, dtype=torch.bfloat16, device=device) + state = torch.randn(B, HV, K, V, dtype=torch.float32, device=device) + A_log = torch.randn(HV, dtype=torch.float32, device=device) + a = torch.randn(B, T, HV, dtype=torch.bfloat16, device=device) + dt_bias = torch.randn(HV, dtype=torch.float32, device=device) + b = torch.randn(B, T, HV, dtype=torch.bfloat16, device=device) + + with pytest.raises(ValueError, match="state_layout='KV' only supports T=1"): + gated_delta_rule_decode( + q, k, v, state, A_log, a, dt_bias, b, state_layout="KV" + ) + + +def test_gated_delta_rule_decode_bf16_pool_unsupported_options_raises(): + """VK bf16 path with pool does not support disable_state_update or intermediate_states_buffer.""" + _skip_if_not_sm90_or_later() + B, T, H, HV, K, V = 2, 1, 16, 32, 128, 128 + pool_size = 4 + device = torch.device("cuda") + q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=device) + k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=device) + v = torch.randn(B, T, HV, V, dtype=torch.bfloat16, device=device) + state = torch.randn(pool_size, HV, V, K, dtype=torch.bfloat16, device=device) + state_indices = torch.arange(B, dtype=torch.int32, device=device) + A_log = torch.randn(HV, dtype=torch.float32, device=device) + a = torch.randn(B, T, HV, dtype=torch.bfloat16, device=device) + dt_bias = torch.randn(HV, dtype=torch.float32, device=device) + b = torch.randn(B, T, HV, dtype=torch.bfloat16, device=device) + intermed = torch.zeros( + pool_size, T, HV, V, K, dtype=torch.float32, device=device + ) + + with pytest.raises(NotImplementedError, match="disable_state_update"): + gated_delta_rule_decode( + q, k, v, state, A_log, a, dt_bias, b, + state_layout="VK", state_indices=state_indices, + disable_state_update=True, + ) + with pytest.raises(NotImplementedError, match="intermediate_states_buffer"): + gated_delta_rule_decode( + q, k, v, state, A_log, a, dt_bias, b, + state_layout="VK", state_indices=state_indices, + intermediate_states_buffer=intermed, + ) + + +def test_pretranspose_shim_validation_raises(): + """Legacy gated_delta_rule_decode_pretranspose preserves initial_state/state validation.""" + _skip_if_not_sm90_or_later() + B, T, H, HV, K, V = 2, 1, 16, 32, 128, 128 + device = torch.device("cuda") + q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=device) + k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=device) + v = torch.randn(B, T, HV, V, dtype=torch.bfloat16, device=device) + state = torch.randn(B, HV, V, K, dtype=torch.bfloat16, device=device) + pool = torch.randn(4, HV, V, K, dtype=torch.bfloat16, device=device) + indices = torch.arange(B, dtype=torch.int32, device=device) + A_log = torch.randn(HV, dtype=torch.float32, device=device) + a = torch.randn(B, T, HV, dtype=torch.bfloat16, device=device) + dt_bias = torch.randn(HV, dtype=torch.float32, device=device) + b = torch.randn(B, T, HV, dtype=torch.bfloat16, device=device) + + with pytest.warns(DeprecationWarning, match="deprecated"): + with pytest.raises(ValueError, match="provided together"): + gated_delta_rule_decode_pretranspose( + q, k, v, state=None, A_log=A_log, a=a, dt_bias=dt_bias, b=b, + initial_state=pool, initial_state_indices=None, + ) + with pytest.warns(DeprecationWarning, match="deprecated"): + with pytest.raises(ValueError, match="provided together"): + gated_delta_rule_decode_pretranspose( + q, k, v, state=state, A_log=A_log, a=a, dt_bias=dt_bias, b=b, + initial_state=None, initial_state_indices=indices, + ) + with pytest.warns(DeprecationWarning, match="deprecated"): + with pytest.raises(ValueError, match="Either state or initial_state"): + gated_delta_rule_decode_pretranspose( + q, k, v, state=None, A_log=A_log, a=a, dt_bias=dt_bias, b=b, + initial_state=None, initial_state_indices=None, + )