Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions benchmarks/bench_gdn_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@

from flashinfer.gdn_decode import (
gated_delta_rule_decode_pretranspose,
gated_delta_rule_decode,
gated_delta_rule_decode_kv,
gated_delta_rule_mtp,
)
from flashinfer.testing import bench_gpu_time
Expand Down Expand Up @@ -1111,7 +1111,7 @@ def bench_gdn_decode(
if version == "pretranspose":
decode_func = gated_delta_rule_decode_pretranspose
elif version == "nontranspose":
decode_func = gated_delta_rule_decode
decode_func = gated_delta_rule_decode_kv
else:
raise ValueError(f"Unknown version: {version}")

Expand Down Expand Up @@ -1334,7 +1334,7 @@ def bench_comparison(
)

flashinfer_times = bench_gpu_time(
lambda: gated_delta_rule_decode(
lambda: gated_delta_rule_decode_kv(
q, k, v, state_fi, A_log, a, dt_bias, b, scale, output_fi, use_qk_l2norm
),
enable_cupti=True,
Expand Down Expand Up @@ -1709,7 +1709,7 @@ def verify_correctness(
output_fi = torch.empty(
batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda"
)
gated_delta_rule_decode(
gated_delta_rule_decode_kv(
q, k, v, state_fi, A_log, a, dt_bias, b, scale, output_fi, use_qk_l2norm
)

Expand Down Expand Up @@ -1983,7 +1983,7 @@ def bench_all_layouts(

try:
times = bench_gpu_time(
lambda: gated_delta_rule_decode(
lambda: gated_delta_rule_decode_kv(
q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm
),
enable_cupti=True,
Expand Down
3 changes: 3 additions & 0 deletions flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@
cute_dsl_fused_moe_nvfp4 as cute_dsl_fused_moe_nvfp4,
CuteDslMoEWrapper as CuteDslMoEWrapper,
)
from .gdn_decode import (
gated_delta_rule_decode as gated_delta_rule_decode,
)
from .gdn_prefill import chunk_gated_delta_rule as chunk_gated_delta_rule
from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper
from .gemm import bmm_bf16 as bmm_bf16
Expand Down
Loading
Loading