From 66e738c8fbb8be7deed63598b99d2fbc28e558e3 Mon Sep 17 00:00:00 2001 From: Artem Perevedentsev Date: Thu, 23 Apr 2026 17:13:44 +0300 Subject: [PATCH 1/3] fix(gdn): use physical SM count for SM100 persistent prefill kernel --- flashinfer/gdn_kernels/blackwell/gdn_prefill.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py b/flashinfer/gdn_kernels/blackwell/gdn_prefill.py index 82dcc72b07..cc8151bfcb 100644 --- a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py +++ b/flashinfer/gdn_kernels/blackwell/gdn_prefill.py @@ -33,9 +33,10 @@ import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute -import cutlass.utils as cutlass_utils from cutlass.cute.runtime import from_dlpack +from flashinfer.cute_dsl.utils import get_max_active_clusters, get_num_sm + from .gated_delta_net_chunked import GatedDeltaNetChunkedKernel @@ -157,9 +158,8 @@ def chunk_gated_delta_rule_sm100( if "compiled" not in cache: # --- First call: compile the kernel --- - hardware_info = cutlass_utils.HardwareInfo() - num_sm = hardware_info.get_max_active_clusters(1) - max_active_clusters = hardware_info.get_max_active_clusters(1) + num_sm = get_num_sm(q.device) + max_active_clusters = min(get_max_active_clusters(1), num_sm) gdn = GatedDeltaNetChunkedKernel( io_dtype=io_dtype, From 1f337a3ba7435e6b27ed81d9a829075a318a229c Mon Sep 17 00:00:00 2001 From: Artem Perevedentsev Date: Thu, 23 Apr 2026 17:39:19 +0300 Subject: [PATCH 2/3] Update flashinfer/gdn_kernels/blackwell/gdn_prefill.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- flashinfer/gdn_kernels/blackwell/gdn_prefill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py b/flashinfer/gdn_kernels/blackwell/gdn_prefill.py index cc8151bfcb..078c949c83 100644 --- a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py +++ b/flashinfer/gdn_kernels/blackwell/gdn_prefill.py @@ -159,7 +159,7 @@ def chunk_gated_delta_rule_sm100( if "compiled" not in cache: # --- First call: compile the kernel --- num_sm = get_num_sm(q.device) - max_active_clusters = min(get_max_active_clusters(1), num_sm) + max_active_clusters = min(get_max_active_clusters(1) or num_sm, num_sm) gdn = GatedDeltaNetChunkedKernel( io_dtype=io_dtype, From c7f10201fb1accc6536f69e7b75f03804c3dbf4f Mon Sep 17 00:00:00 2001 From: Artem Perevedentsev Date: Thu, 23 Apr 2026 17:48:07 +0300 Subject: [PATCH 3/3] fix(gdn): drop get_max_active_clusters indirection in SM100 prefill --- flashinfer/gdn_kernels/blackwell/gdn_prefill.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py b/flashinfer/gdn_kernels/blackwell/gdn_prefill.py index 078c949c83..aafcc6716a 100644 --- a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py +++ b/flashinfer/gdn_kernels/blackwell/gdn_prefill.py @@ -35,7 +35,7 @@ import cutlass.cute as cute from cutlass.cute.runtime import from_dlpack -from flashinfer.cute_dsl.utils import get_max_active_clusters, get_num_sm +from flashinfer.cute_dsl.utils import get_num_sm from .gated_delta_net_chunked import GatedDeltaNetChunkedKernel @@ -159,7 +159,7 @@ def chunk_gated_delta_rule_sm100( if "compiled" not in cache: # --- First call: compile the kernel --- num_sm = get_num_sm(q.device) - max_active_clusters = min(get_max_active_clusters(1) or num_sm, num_sm) + max_active_clusters = num_sm gdn = GatedDeltaNetChunkedKernel( io_dtype=io_dtype,