Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions benchmarks/bench_gdn_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@
except ImportError:
GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = False


# ============================================================================
# Utility Functions
# ============================================================================
Expand Down Expand Up @@ -1167,6 +1166,7 @@ def bench_gdn_mtp(
use_beta: bool = True,
use_qk_l2norm: bool = True,
cache_intermediate_states: bool = True,
disable_state_update: bool = True,
warmup_iters: int = 10,
bench_iters: int = 100,
):
Expand Down Expand Up @@ -1243,7 +1243,7 @@ def bench_gdn_mtp(
scale,
output,
intermediate_states_buffer,
disable_state_update=True,
disable_state_update=disable_state_update,
use_qk_l2norm=use_qk_l2norm,
),
enable_cupti=True,
Expand All @@ -1264,7 +1264,7 @@ def bench_gdn_mtp(
head_size,
dtype,
seq_len,
disable_state_update=True, # MTP mode: state is not written back
disable_state_update=disable_state_update,
)

kernel_tflops = flops / kernel_median_ms / 1e9 if kernel_median_ms > 0 else 0
Expand Down Expand Up @@ -1577,7 +1577,7 @@ def bench_mtp_comparison(
scale,
output_fi,
intermediate_fi,
disable_state_update=True,
disable_state_update=False,
use_qk_l2norm=use_qk_l2norm,
),
enable_cupti=True,
Expand Down Expand Up @@ -1627,7 +1627,7 @@ def bench_mtp_comparison(
scale,
output_tr,
intermediate_tr,
disable_state_update=True,
disable_state_update=False,
use_qk_l2norm=use_qk_l2norm,
),
enable_cupti=True,
Expand Down Expand Up @@ -2399,7 +2399,9 @@ def run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm):
f"\nGDN MTP Benchmark "
f"(heads: q={args.num_q_heads}, k={args.num_k_heads}, "
f"v={args.num_v_heads}, d={args.head_size}, dtype={args.dtype}, "
f"qk_l2norm={'ON' if use_qk_l2norm else 'OFF'})"
f"qk_l2norm={'ON' if use_qk_l2norm else 'OFF'}, "
f"cache_intermediate={'ON' if args.cache_intermediate_states else 'OFF'}, "
f"update_state={'ON' if args.update_state else 'OFF'})"
)
print("-" * 100)
print(
Expand All @@ -2419,6 +2421,7 @@ def run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm):
dtype=dtype,
use_qk_l2norm=use_qk_l2norm,
cache_intermediate_states=args.cache_intermediate_states,
disable_state_update=not args.update_state,
warmup_iters=args.warmup,
bench_iters=args.iters,
)
Expand Down Expand Up @@ -2736,6 +2739,11 @@ def main():
action="store_true",
help="Cache intermediate states for MTP benchmark",
)
parser.add_argument(
"--update-state",
action="store_true",
help="Update final state (disable_state_update=False) for MTP benchmark",
)
parser.add_argument(
"--warmup",
type=int,
Expand Down
Loading
Loading