diff --git a/csrc/flat_prefill_kernel_delta_rule_sm90_extern.inc b/csrc/flat_prefill_kernel_delta_rule_sm90_extern.inc new file mode 100644 index 0000000000..4f41916565 --- /dev/null +++ b/csrc/flat_prefill_kernel_delta_rule_sm90_extern.inc @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// Extern template declarations to prevent implicit instantiation in the dispatcher. +// Explicit instantiations are in separate generated files for parallel compilation. + +#pragma once + +#include <cuda_fp16.h> +#include <cuda_bf16.h> +#include "cutlass/arch/arch.h" + +namespace flat { + +// clang-format off + +#define FOR_EACH_BOOL_4(MACRO, ...) 
\ + MACRO(false, false, false, false, __VA_ARGS__) \ + MACRO(false, false, false, true, __VA_ARGS__) \ + MACRO(false, false, true, false, __VA_ARGS__) \ + MACRO(false, false, true, true, __VA_ARGS__) \ + MACRO(false, true, false, false, __VA_ARGS__) \ + MACRO(false, true, false, true, __VA_ARGS__) \ + MACRO(false, true, true, false, __VA_ARGS__) \ + MACRO(false, true, true, true, __VA_ARGS__) \ + MACRO(true, false, false, false, __VA_ARGS__) \ + MACRO(true, false, false, true, __VA_ARGS__) \ + MACRO(true, false, true, false, __VA_ARGS__) \ + MACRO(true, false, true, true, __VA_ARGS__) \ + MACRO(true, true, false, false, __VA_ARGS__) \ + MACRO(true, true, false, true, __VA_ARGS__) \ + MACRO(true, true, true, false, __VA_ARGS__) \ + MACRO(true, true, true, true, __VA_ARGS__) + +#define DECLARE_TEMPLATE_INSTANCE(is_gva, needs_beta, needs_alpha, init_state, ctype) \ +extern template void launch_delta_rule_prefill_kernel_gbai<is_gva, needs_beta, needs_alpha, init_state, cutlass::arch::Sm90, ctype, ctype, float>( \ + cudaStream_t, ctype*, float*, ctype const*, ctype const*, ctype const*, \ + float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t, \ + int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t); + +// Extern template declarations for half +FOR_EACH_BOOL_4(DECLARE_TEMPLATE_INSTANCE, half) + +// Extern template declarations for nv_bfloat16 +FOR_EACH_BOOL_4(DECLARE_TEMPLATE_INSTANCE, nv_bfloat16) + +#undef DECLARE_TEMPLATE_INSTANCE +#undef FOR_EACH_BOOL_4 + +// clang-format on + +} // namespace flat diff --git a/csrc/gdn_prefill_launcher.cu b/csrc/gdn_prefill_launcher.cu index 5110c351e4..0819167bfb 100644 --- a/csrc/gdn_prefill_launcher.cu +++ b/csrc/gdn_prefill_launcher.cu @@ -25,7 +25,7 @@ #include #include -#include "flat/prefill/prefill_kernel.hpp" +#include "flashinfer/flat/prefill/prefill_kernel.hpp" using tvm::ffi::Optional; using tvm::ffi::TensorView; diff --git a/csrc/gdn_prefill_sm90_kernel_inst.jinja b/csrc/gdn_prefill_sm90_kernel_inst.jinja new file mode 100644 index 0000000000..9297039f1f --- 
/dev/null +++ b/csrc/gdn_prefill_sm90_kernel_inst.jinja @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// Auto-generated file for separate compilation of GDN prefill kernel variants. +// Template parameters: dtype={{ dtype }}, is_gva={{ is_gva }}, needs_beta={{ needs_beta }}, +// needs_alpha={{ needs_alpha }}, init_state={{ init_state }} + +// CUDA type definitions for half and nv_bfloat16 +#include <cuda_fp16.h> +#include <cuda_bf16.h> + +// Include the header which defines the function template +// The header includes all necessary CUTLASS type definitions +#include "flashinfer/flat/prefill/prefill_kernel_delta_rule_sm90.cuh" + +namespace flat { + +// Explicit template instantiation for launch_delta_rule_prefill_kernel_gbai +// Parameter types must exactly match the extern template declaration in flat_prefill_kernel_delta_rule_sm90_extern.inc +template void launch_delta_rule_prefill_kernel_gbai<{{ is_gva }}, {{ needs_beta }}, {{ needs_alpha }}, {{ init_state }}, cutlass::arch::Sm90, {{ dtype }}, {{ dtype }}, float>( + cudaStream_t, {{ dtype }}*, float*, {{ dtype }} const*, {{ dtype }} const*, {{ dtype }} const*, + float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t, + int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t); + +} // namespace flat diff --git a/csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu b/csrc/prefill_kernel_delta_rule_sm90.cu similarity index 91% rename 
from csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu rename to csrc/prefill_kernel_delta_rule_sm90.cu index dec8150576..da5a034b63 100644 --- a/csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu +++ b/csrc/prefill_kernel_delta_rule_sm90.cu @@ -15,7 +15,11 @@ */ #include -#include "prefill_kernel_delta_rule_sm90.cuh" +#include "flashinfer/flat/prefill/prefill_kernel_delta_rule_sm90.cuh" + +// Extern template declarations prevent implicit instantiation here. +// Explicit instantiations are in separate generated files for parallel compilation. +#include "flat_prefill_kernel_delta_rule_sm90_extern.inc" namespace flat { @@ -87,6 +91,8 @@ void launch_delta_rule_prefill_kernel(cudaStream_t stream, TO* output, TState* o #undef LAUNCH } +// Explicit instantiations for the outer dispatch function only. +// The inner launch_delta_rule_prefill_kernel_gbai instantiations are in separate files. template void launch_delta_rule_prefill_kernel( cudaStream_t stream, half* output, float* state, half const* q, half const* k, half const* v, float const* input_state, float const* alpha, float const* beta, int64_t const* cu_seqlens, diff --git a/flashinfer/gdn_prefill.py b/flashinfer/gdn_prefill.py index 1eb2157314..2b10e83305 100644 --- a/flashinfer/gdn_prefill.py +++ b/flashinfer/gdn_prefill.py @@ -152,7 +152,7 @@ def chunk_gated_delta_rule( Note: - Supports GQA: ``num_q_heads > num_k_heads = num_v_heads`` - Supports GVA: ``num_v_heads > num_q_heads = num_k_heads`` - - The final state is in k-major layout ``[N, H, K, V]``. + - The final state is in k-last layout ``[N, H, V, K]``. - Requires SM90 (Hopper) architecture. 
""" assert cu_seqlens is not None, "cu_seqlens is required for varlen mode" diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index 0a856a0cce..945c5c284e 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -417,9 +417,19 @@ def gen_jit_spec( verbose_env = os.environ.get("FLASHINFER_JIT_VERBOSE", "0") debug = (debug_env if debug_env is not None else verbose_env) == "1" - cflags = ["-std=c++17", "-Wno-switch-bool"] + # Only add default C++ standard if not specified in extra flags + cflags_has_std = extra_cflags is not None and any( + f.startswith("-std=") for f in extra_cflags + ) + cuda_cflags_has_std = extra_cuda_cflags is not None and any( + f.startswith("-std=") for f in extra_cuda_cflags + ) + + cflags = ["-Wno-switch-bool"] + if not cflags_has_std: + cflags.insert(0, "-std=c++17") + cuda_cflags = [ - "-std=c++17", f"--threads={os.environ.get('FLASHINFER_NVCC_THREADS', '1')}", "-use_fast_math", "-DFLASHINFER_ENABLE_F16", @@ -427,6 +437,9 @@ def gen_jit_spec( "-DFLASHINFER_ENABLE_FP8_E4M3", "-DFLASHINFER_ENABLE_FP8_E5M2", ] + if not cuda_cflags_has_std: + cuda_cflags.insert(0, "-std=c++17") + if debug: cflags += ["-O0", "-g"] cuda_cflags += [ diff --git a/flashinfer/jit/gdn.py b/flashinfer/jit/gdn.py index a86038c82b..f1b357d32e 100644 --- a/flashinfer/jit/gdn.py +++ b/flashinfer/jit/gdn.py @@ -14,24 +14,70 @@ limitations under the License. """ +import itertools +import os + +import jinja2 + from . import env as jit_env from .core import ( JitSpec, gen_jit_spec, sm90a_nvcc_flags, ) +from .utils import write_if_different def gen_gdn_prefill_sm90_module() -> JitSpec: + """Generate JIT module for GDN prefill kernel with separate compilation. + + This generates 32 separate kernel instantiation files (2 dtypes × 16 boolean combinations) + plus the original launcher file. The separate files enable parallel compilation by ninja, + significantly reducing build time on multi-core machines. 
+ """ + uri = "gdn_prefill_sm90" + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri + os.makedirs(gen_directory, exist_ok=True) + + source_paths = [] + + # Load kernel instantiation template + with open(jit_env.FLASHINFER_CSRC_DIR / "gdn_prefill_sm90_kernel_inst.jinja") as f: + kernel_inst_templ = jinja2.Template(f.read()) + + # Generate 32 separate instance files (2 dtypes × 16 boolean combinations) + dtypes = [("half", "half"), ("bf16", "nv_bfloat16")] + for dtype_name, dtype in dtypes: + for is_gva, needs_beta, needs_alpha, init_state in itertools.product( + [False, True], repeat=4 + ): + suffix = f"{dtype_name}_g{int(is_gva)}b{int(needs_beta)}a{int(needs_alpha)}i{int(init_state)}" + filename = f"gdn_prefill_kernel_{suffix}.cu" + dest_path = gen_directory / filename + source_paths.append(dest_path) + + source = kernel_inst_templ.render( + dtype=dtype, + is_gva=str(is_gva).lower(), + needs_beta=str(needs_beta).lower(), + needs_alpha=str(needs_alpha).lower(), + init_state=str(init_state).lower(), + ) + write_if_different(dest_path, source) + + # Copy source files to gen_directory for compilation + # Headers are now in include/flashinfer/flat/ and accessible via standard include paths + for filename in [ + "gdn_prefill_launcher.cu", + "prefill_kernel_delta_rule_sm90.cu", + ]: + src_path = jit_env.FLASHINFER_CSRC_DIR / filename + dest_path = gen_directory / src_path.name + source_paths.append(dest_path) + write_if_different(dest_path, src_path.read_text()) + return gen_jit_spec( - name="gdn_prefill_launcher", - sources=[ - jit_env.FLASHINFER_CSRC_DIR / "gdn_prefill_launcher.cu", - jit_env.FLASHINFER_CSRC_DIR - / "flat" - / "prefill" - / "prefill_kernel_delta_rule_sm90.cu", - ], + uri, + source_paths, extra_cuda_cflags=sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED", "-std=c++20"], - extra_include_paths=[jit_env.FLASHINFER_CSRC_DIR], ) diff --git a/csrc/flat/ampere/collective/flat_collective_inverse.hpp 
b/include/flashinfer/flat/ampere/collective/flat_collective_inverse.hpp similarity index 99% rename from csrc/flat/ampere/collective/flat_collective_inverse.hpp rename to include/flashinfer/flat/ampere/collective/flat_collective_inverse.hpp index c9535c6396..b36f410b1f 100644 --- a/csrc/flat/ampere/collective/flat_collective_inverse.hpp +++ b/include/flashinfer/flat/ampere/collective/flat_collective_inverse.hpp @@ -18,7 +18,7 @@ #include "cute/tensor.hpp" #include "cutlass/arch/barrier.h" #include "cutlass/cutlass.h" -#include "flat/cute_ext.hpp" +#include "flashinfer/flat/cute_ext.hpp" namespace flat::collective { diff --git a/csrc/flat/ampere/collective/flat_collective_load.hpp b/include/flashinfer/flat/ampere/collective/flat_collective_load.hpp similarity index 99% rename from csrc/flat/ampere/collective/flat_collective_load.hpp rename to include/flashinfer/flat/ampere/collective/flat_collective_load.hpp index 3a7f517eff..ae83ddc7c5 100644 --- a/csrc/flat/ampere/collective/flat_collective_load.hpp +++ b/include/flashinfer/flat/ampere/collective/flat_collective_load.hpp @@ -18,7 +18,7 @@ #include "cute/tensor.hpp" #include "cutlass/cutlass.h" #include "cutlass/pipeline/sm90_pipeline.hpp" -#include "flat/unused.hpp" +#include "flashinfer/flat/unused.hpp" namespace flat::collective { diff --git a/csrc/flat/common.hpp b/include/flashinfer/flat/common.hpp similarity index 98% rename from csrc/flat/common.hpp rename to include/flashinfer/flat/common.hpp index 91939e6085..b3579ee982 100644 --- a/csrc/flat/common.hpp +++ b/include/flashinfer/flat/common.hpp @@ -19,7 +19,7 @@ #include #include -#include "debug.hpp" +#include "flashinfer/flat/debug.hpp" #define FLAT_UNUSED_PARAMETER(x) (void)x diff --git a/csrc/flat/cute_ext.hpp b/include/flashinfer/flat/cute_ext.hpp similarity index 100% rename from csrc/flat/cute_ext.hpp rename to include/flashinfer/flat/cute_ext.hpp diff --git a/csrc/flat/debug.hpp b/include/flashinfer/flat/debug.hpp similarity index 100% rename from 
csrc/flat/debug.hpp rename to include/flashinfer/flat/debug.hpp diff --git a/csrc/flat/hopper/collective/flat_collective_load.hpp b/include/flashinfer/flat/hopper/collective/flat_collective_load.hpp similarity index 100% rename from csrc/flat/hopper/collective/flat_collective_load.hpp rename to include/flashinfer/flat/hopper/collective/flat_collective_load.hpp diff --git a/csrc/flat/hopper/collective/flat_collective_store.hpp b/include/flashinfer/flat/hopper/collective/flat_collective_store.hpp similarity index 99% rename from csrc/flat/hopper/collective/flat_collective_store.hpp rename to include/flashinfer/flat/hopper/collective/flat_collective_store.hpp index 8cca5b4fba..1d16b19f6d 100644 --- a/csrc/flat/hopper/collective/flat_collective_store.hpp +++ b/include/flashinfer/flat/hopper/collective/flat_collective_store.hpp @@ -20,7 +20,7 @@ #include "cute/tensor.hpp" #include "cutlass/cutlass.h" #include "cutlass/epilogue/collective/collective_builder.hpp" -#include "flat/cute_ext.hpp" +#include "flashinfer/flat/cute_ext.hpp" namespace flat::collective { diff --git a/csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp b/include/flashinfer/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp similarity index 98% rename from csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp rename to include/flashinfer/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp index 49f499511a..47e0ef8cd2 100644 --- a/csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp +++ b/include/flashinfer/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp @@ -15,18 +15,18 @@ */ #pragma once -#include "../../cute_ext.hpp" -#include "../../math_order_barrier.hpp" -#include "../../unused.hpp" -#include "../collective/flat_collective_load.hpp" -#include "../collective/flat_collective_store.hpp" -#include "../collective/flat_common.hpp" -#include 
"../collective/flat_named_barriers.hpp" -#include "../kernel/flat_options.hpp" #include "cutlass/cutlass.h" #include "cutlass/gemm/collective/collective_builder.hpp" -#include "flat/ampere/collective/flat_collective_inverse.hpp" -#include "flat/ampere/collective/flat_collective_load.hpp" +#include "flashinfer/flat/ampere/collective/flat_collective_inverse.hpp" +#include "flashinfer/flat/ampere/collective/flat_collective_load.hpp" +#include "flashinfer/flat/cute_ext.hpp" +#include "flashinfer/flat/hopper/collective/flat_collective_load.hpp" +#include "flashinfer/flat/hopper/collective/flat_collective_store.hpp" +#include "flashinfer/flat/hopper/collective/flat_common.hpp" +#include "flashinfer/flat/hopper/collective/flat_named_barriers.hpp" +#include "flashinfer/flat/hopper/kernel/flat_options.hpp" +#include "flashinfer/flat/math_order_barrier.hpp" +#include "flashinfer/flat/unused.hpp" // #define INLINE_LAMBDA [[gnu::always_inline]] #define INLINE_LAMBDA __attribute__((always_inline)) diff --git a/csrc/flat/hopper/collective/flat_common.hpp b/include/flashinfer/flat/hopper/collective/flat_common.hpp similarity index 100% rename from csrc/flat/hopper/collective/flat_common.hpp rename to include/flashinfer/flat/hopper/collective/flat_common.hpp diff --git a/csrc/flat/hopper/collective/flat_named_barriers.hpp b/include/flashinfer/flat/hopper/collective/flat_named_barriers.hpp similarity index 100% rename from csrc/flat/hopper/collective/flat_named_barriers.hpp rename to include/flashinfer/flat/hopper/collective/flat_named_barriers.hpp diff --git a/csrc/flat/hopper/device/device_universal.hpp b/include/flashinfer/flat/hopper/device/device_universal.hpp similarity index 100% rename from csrc/flat/hopper/device/device_universal.hpp rename to include/flashinfer/flat/hopper/device/device_universal.hpp diff --git a/csrc/flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp b/include/flashinfer/flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp similarity index 87% rename 
from csrc/flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp rename to include/flashinfer/flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp index 60f920c3b1..6b02132c7d 100644 --- a/csrc/flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp +++ b/include/flashinfer/flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp @@ -15,11 +15,11 @@ */ #pragma once -#include "../collective/flat_collective_tma_warpspecialized_delta_rule.hpp" -#include "../kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp" -#include "../kernel/flat_options.hpp" -#include "../kernel/flat_tile_scheduler.hpp" -#include "flat/type_traits.hpp" +#include "flashinfer/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp" +#include "flashinfer/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp" +#include "flashinfer/flat/hopper/kernel/flat_options.hpp" +#include "flashinfer/flat/hopper/kernel/flat_tile_scheduler.hpp" +#include "flashinfer/flat/type_traits.hpp" namespace flat::kernel { diff --git a/csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp b/include/flashinfer/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp similarity index 99% rename from csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp rename to include/flashinfer/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp index d332328382..de82ce7db8 100644 --- a/csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp +++ b/include/flashinfer/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp @@ -15,13 +15,13 @@ */ #pragma once -#include "../kernel/flat_options.hpp" #include "cutlass/arch/arch.h" #include "cutlass/arch/reg_reconfig.h" #include "cutlass/cutlass.h" #include "cutlass/pipeline/pipeline.hpp" -#include "flat/common.hpp" -#include "flat/unused.hpp" +#include "flashinfer/flat/common.hpp" +#include "flashinfer/flat/hopper/kernel/flat_options.hpp" +#include "flashinfer/flat/unused.hpp" namespace 
flat::kernel { diff --git a/csrc/flat/hopper/kernel/flat_options.hpp b/include/flashinfer/flat/hopper/kernel/flat_options.hpp similarity index 100% rename from csrc/flat/hopper/kernel/flat_options.hpp rename to include/flashinfer/flat/hopper/kernel/flat_options.hpp diff --git a/csrc/flat/hopper/kernel/flat_tile_scheduler.hpp b/include/flashinfer/flat/hopper/kernel/flat_tile_scheduler.hpp similarity index 100% rename from csrc/flat/hopper/kernel/flat_tile_scheduler.hpp rename to include/flashinfer/flat/hopper/kernel/flat_tile_scheduler.hpp diff --git a/csrc/flat/math.hpp b/include/flashinfer/flat/math.hpp similarity index 100% rename from csrc/flat/math.hpp rename to include/flashinfer/flat/math.hpp diff --git a/csrc/flat/math_order_barrier.hpp b/include/flashinfer/flat/math_order_barrier.hpp similarity index 100% rename from csrc/flat/math_order_barrier.hpp rename to include/flashinfer/flat/math_order_barrier.hpp diff --git a/csrc/flat/prefill/prefill_kernel.hpp b/include/flashinfer/flat/prefill/prefill_kernel.hpp similarity index 100% rename from csrc/flat/prefill/prefill_kernel.hpp rename to include/flashinfer/flat/prefill/prefill_kernel.hpp diff --git a/csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cuh b/include/flashinfer/flat/prefill/prefill_kernel_delta_rule_sm90.cuh similarity index 97% rename from csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cuh rename to include/flashinfer/flat/prefill/prefill_kernel_delta_rule_sm90.cuh index 3eb382d817..a34101f654 100644 --- a/csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cuh +++ b/include/flashinfer/flat/prefill/prefill_kernel_delta_rule_sm90.cuh @@ -21,9 +21,9 @@ #include "cutlass/cutlass.h" #include "cutlass/kernel_hardware_info.h" #include "cutlass/util/device_memory.h" -#include "flat/common.hpp" -#include "flat/hopper/device/device_universal.hpp" -#include "flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp" +#include "flashinfer/flat/common.hpp" +#include 
"flashinfer/flat/hopper/device/device_universal.hpp" +#include "flashinfer/flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp" namespace flat { diff --git a/csrc/flat/type_traits.hpp b/include/flashinfer/flat/type_traits.hpp similarity index 100% rename from csrc/flat/type_traits.hpp rename to include/flashinfer/flat/type_traits.hpp diff --git a/csrc/flat/unused.hpp b/include/flashinfer/flat/unused.hpp similarity index 100% rename from csrc/flat/unused.hpp rename to include/flashinfer/flat/unused.hpp diff --git a/tests/gdn/test_prefill_delta_rule.py b/tests/gdn/test_prefill_delta_rule.py index f1211cd7a9..f2fd06cbce 100644 --- a/tests/gdn/test_prefill_delta_rule.py +++ b/tests/gdn/test_prefill_delta_rule.py @@ -106,7 +106,7 @@ def _test_prefill_kernel( torch.cuda.synchronize() - # postprocessing raw output, ref_state is v-major, our_state is k-major, unify to v-major for testing + # postprocessing raw output: ref_state is v-last [H,K,V], our_state is k-last [H,V,K], transpose to match our_state = our_state.transpose(-1, -2) ref_o, ref_state = blockwise_delta_rule( @@ -330,7 +330,7 @@ def _test_chunked_prefill( torch.cuda.synchronize() - # postprocessing raw output, ref_state is v-major, our_state is k-major, unify to v-major for testing + # postprocessing raw output: ref_state is v-last [H,K,V], our_state is k-last [H,V,K], transpose to match our_state = our_state.transpose(-1, -2) def concat_varlen(t1, cu_seq_lens1, t2, cu_seq_lens2):