Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions csrc/flat_prefill_kernel_delta_rule_sm90_extern.inc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// Extern template declarations to prevent implicit instantiation in the dispatcher.
// Explicit instantiations are in separate generated files for parallel compilation.

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "cutlass/arch/arch.h"

namespace flat {

// clang-format off

// Invokes MACRO(b0, b1, b2, b3, ...) once for each of the 16 combinations of
// four booleans, forwarding any extra arguments. At the expansion sites below
// the four booleans map, in order, to the template parameters
// (is_gva, needs_beta, needs_alpha, init_state).
#define FOR_EACH_BOOL_4(MACRO, ...) \
MACRO(false, false, false, false, __VA_ARGS__) \
MACRO(false, false, false, true, __VA_ARGS__) \
MACRO(false, false, true, false, __VA_ARGS__) \
MACRO(false, false, true, true, __VA_ARGS__) \
MACRO(false, true, false, false, __VA_ARGS__) \
MACRO(false, true, false, true, __VA_ARGS__) \
MACRO(false, true, true, false, __VA_ARGS__) \
MACRO(false, true, true, true, __VA_ARGS__) \
MACRO(true, false, false, false, __VA_ARGS__) \
MACRO(true, false, false, true, __VA_ARGS__) \
MACRO(true, false, true, false, __VA_ARGS__) \
MACRO(true, false, true, true, __VA_ARGS__) \
MACRO(true, true, false, false, __VA_ARGS__) \
MACRO(true, true, false, true, __VA_ARGS__) \
MACRO(true, true, true, false, __VA_ARGS__) \
MACRO(true, true, true, true, __VA_ARGS__)

// Declares (without defining) one explicit instantiation of the inner launch
// function template for compute type `ctype` on SM90. The template arguments
// and parameter list must exactly match the explicit instantiations emitted by
// the generated gdn_prefill_kernel_*.cu files, or the linker will not find them.
#define DECLARE_TEMPLATE_INSTANCE(is_gva, needs_beta, needs_alpha, init_state, ctype) \
extern template void launch_delta_rule_prefill_kernel_gbai<is_gva, needs_beta, needs_alpha, init_state, cutlass::arch::Sm90, ctype, ctype, float>( \
cudaStream_t, ctype*, float*, ctype const*, ctype const*, ctype const*, \
float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t, \
int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t);

// Extern template declarations for half
FOR_EACH_BOOL_4(DECLARE_TEMPLATE_INSTANCE, half)

// Extern template declarations for nv_bfloat16
FOR_EACH_BOOL_4(DECLARE_TEMPLATE_INSTANCE, nv_bfloat16)

// Helper macros are file-local; undefine them so they cannot leak into
// translation units that include this .inc file.
#undef DECLARE_TEMPLATE_INSTANCE
#undef FOR_EACH_BOOL_4

// clang-format on

} // namespace flat
2 changes: 1 addition & 1 deletion csrc/gdn_prefill_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#include <iostream>
#include <sstream>

#include "flat/prefill/prefill_kernel.hpp"
#include "flashinfer/flat/prefill/prefill_kernel.hpp"

using tvm::ffi::Optional;
using tvm::ffi::TensorView;
Expand Down
37 changes: 37 additions & 0 deletions csrc/gdn_prefill_sm90_kernel_inst.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// Auto-generated file for separate compilation of GDN prefill kernel variants.
// Template parameters: dtype={{ dtype }}, is_gva={{ is_gva }}, needs_beta={{ needs_beta }},
// needs_alpha={{ needs_alpha }}, init_state={{ init_state }}
// Rendered by flashinfer/jit/gdn.py; boolean parameters are substituted as the
// lowercase C++ literals "true"/"false", dtype as "half" or "nv_bfloat16".

// CUDA type definitions for half and nv_bfloat16
#include <cuda_bf16.h>
#include <cuda_fp16.h>

// Include the header which defines the function template
// The header includes all necessary CUTLASS type definitions
#include "flashinfer/flat/prefill/prefill_kernel_delta_rule_sm90.cuh"

namespace flat {

// Explicit template instantiation for launch_delta_rule_prefill_kernel_gbai
// Parameter types must exactly match the extern template declaration in prefill_kernel_delta_rule_sm90_extern.inc
template void launch_delta_rule_prefill_kernel_gbai<{{ is_gva }}, {{ needs_beta }}, {{ needs_alpha }}, {{ init_state }}, cutlass::arch::Sm90, {{ dtype }}, {{ dtype }}, float>(
    cudaStream_t, {{ dtype }}*, float*, {{ dtype }} const*, {{ dtype }} const*, {{ dtype }} const*,
    float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t,
    int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t);

} // namespace flat
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
*/
#include <cuda_bf16.h>

#include "prefill_kernel_delta_rule_sm90.cuh"
#include "flashinfer/flat/prefill/prefill_kernel_delta_rule_sm90.cuh"

// Extern template declarations prevent implicit instantiation here.
// Explicit instantiations are in separate generated files for parallel compilation.
#include "flat_prefill_kernel_delta_rule_sm90_extern.inc"

namespace flat {

Expand Down Expand Up @@ -87,6 +91,8 @@ void launch_delta_rule_prefill_kernel(cudaStream_t stream, TO* output, TState* o
#undef LAUNCH
}

// Explicit instantiations for the outer dispatch function only.
// The inner launch_delta_rule_prefill_kernel_gbai instantiations are in separate files.
template void launch_delta_rule_prefill_kernel<cutlass::arch::Sm90, half, half, float>(
cudaStream_t stream, half* output, float* state, half const* q, half const* k, half const* v,
float const* input_state, float const* alpha, float const* beta, int64_t const* cu_seqlens,
Expand Down
2 changes: 1 addition & 1 deletion flashinfer/gdn_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def chunk_gated_delta_rule(
Note:
- Supports GQA: ``num_q_heads > num_k_heads = num_v_heads``
- Supports GVA: ``num_v_heads > num_q_heads = num_k_heads``
- The final state is in k-major layout ``[N, H, K, V]``.
- The final state is in k-last layout ``[N, H, V, K]``.
- Requires SM90 (Hopper) architecture.
"""
assert cu_seqlens is not None, "cu_seqlens is required for varlen mode"
Expand Down
17 changes: 15 additions & 2 deletions flashinfer/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,16 +417,29 @@ def gen_jit_spec(
verbose_env = os.environ.get("FLASHINFER_JIT_VERBOSE", "0")
debug = (debug_env if debug_env is not None else verbose_env) == "1"

cflags = ["-std=c++17", "-Wno-switch-bool"]
# Only add default C++ standard if not specified in extra flags
cflags_has_std = extra_cflags is not None and any(
f.startswith("-std=") for f in extra_cflags
)
cuda_cflags_has_std = extra_cuda_cflags is not None and any(
f.startswith("-std=") for f in extra_cuda_cflags
)

cflags = ["-Wno-switch-bool"]
if not cflags_has_std:
cflags.insert(0, "-std=c++17")

cuda_cflags = [
"-std=c++17",
f"--threads={os.environ.get('FLASHINFER_NVCC_THREADS', '1')}",
"-use_fast_math",
"-DFLASHINFER_ENABLE_F16",
"-DFLASHINFER_ENABLE_BF16",
"-DFLASHINFER_ENABLE_FP8_E4M3",
"-DFLASHINFER_ENABLE_FP8_E5M2",
]
if not cuda_cflags_has_std:
cuda_cflags.insert(0, "-std=c++17")

if debug:
cflags += ["-O0", "-g"]
cuda_cflags += [
Expand Down
64 changes: 55 additions & 9 deletions flashinfer/jit/gdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,70 @@
limitations under the License.
"""

import itertools
import os

import jinja2

from . import env as jit_env
from .core import (
JitSpec,
gen_jit_spec,
sm90a_nvcc_flags,
)
from .utils import write_if_different


def gen_gdn_prefill_sm90_module() -> JitSpec:
    """Generate the JIT module for the GDN prefill kernel with separate compilation.

    Renders 32 kernel instantiation files (2 dtypes x 16 boolean combinations)
    plus the launcher/dispatcher sources into the per-URI generated-source
    directory. Splitting the explicit template instantiations into separate
    translation units lets ninja compile them in parallel, significantly
    reducing build time on multi-core machines.

    Returns:
        JitSpec: build spec covering all generated and copied sources,
        compiled with SM90a flags.
    """
    uri = "gdn_prefill_sm90"
    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
    os.makedirs(gen_directory, exist_ok=True)

    source_paths = []

    # Load the per-variant kernel instantiation template.
    with open(jit_env.FLASHINFER_CSRC_DIR / "gdn_prefill_sm90_kernel_inst.jinja") as f:
        kernel_inst_templ = jinja2.Template(f.read())

    # Render one .cu file per (dtype, is_gva, needs_beta, needs_alpha, init_state)
    # combination: 2 dtypes x 2^4 boolean flags = 32 files.
    dtypes = [("half", "half"), ("bf16", "nv_bfloat16")]
    for dtype_name, dtype in dtypes:
        for is_gva, needs_beta, needs_alpha, init_state in itertools.product(
            [False, True], repeat=4
        ):
            suffix = f"{dtype_name}_g{int(is_gva)}b{int(needs_beta)}a{int(needs_alpha)}i{int(init_state)}"
            filename = f"gdn_prefill_kernel_{suffix}.cu"
            dest_path = gen_directory / filename
            source_paths.append(dest_path)

            source = kernel_inst_templ.render(
                dtype=dtype,
                # The template substitutes these as C++ boolean literals.
                is_gva=str(is_gva).lower(),
                needs_beta=str(needs_beta).lower(),
                needs_alpha=str(needs_alpha).lower(),
                init_state=str(init_state).lower(),
            )
            # Avoid touching mtimes (and triggering rebuilds) when unchanged.
            write_if_different(dest_path, source)

    # Copy source files to gen_directory for compilation.
    # Headers are now in include/flashinfer/flat/ and accessible via standard include paths.
    for filename in [
        "gdn_prefill_launcher.cu",
        "prefill_kernel_delta_rule_sm90.cu",
    ]:
        src_path = jit_env.FLASHINFER_CSRC_DIR / filename
        dest_path = gen_directory / src_path.name
        source_paths.append(dest_path)
        write_if_different(dest_path, src_path.read_text())

    return gen_jit_spec(
        uri,
        source_paths,
        extra_cuda_cflags=sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED", "-std=c++20"],
        extra_include_paths=[jit_env.FLASHINFER_CSRC_DIR],
    )
Comment on lines 79 to 83
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Check gen_jit_spec function signature and supported_major_versions usage
rg -n "supported_major_versions" --type py -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 7679


🏁 Script executed:

# Find gen_jit_spec definition
rg -n "def gen_jit_spec" --type py -A 10

Repository: flashinfer-ai/flashinfer

Length of output: 829


🏁 Script executed:

# Check for other JIT module generators to see usage patterns
fd "\.py$" flashinfer/jit/ | head -10 | xargs grep -l "gen_jit_spec" | xargs -I {} sh -c 'echo "=== {} ===" && grep -A 5 "gen_jit_spec(" {}'

Repository: flashinfer-ai/flashinfer

Length of output: 4164


🏁 Script executed:

cat -n flashinfer/jit/gdn.py

Repository: flashinfer-ai/flashinfer

Length of output: 3736


🏁 Script executed:

grep -n "sm90a_nvcc_flags" flashinfer/jit/core.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 92


🏁 Script executed:

# Get context around the definition
rg -n "sm90a_nvcc_flags\s*=" flashinfer/jit/core.py -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 238


🏁 Script executed:

sed -n '120,135p' flashinfer/jit/core.py

Repository: flashinfer-ai/flashinfer

Length of output: 810


🏁 Script executed:

# Check how other SM90-specific generators handle the architecture restriction
rg -n "SM90\|sm_90" flashinfer/jit/ --type py -B 3 -A 3 | grep -E "(supported_major_versions|sm90)" | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

head -30 flashinfer/jit/gdn.py

Repository: flashinfer-ai/flashinfer

Length of output: 839


🏁 Script executed:

# Check if CompilationContext is available and how it's used
rg -n "CompilationContext\|current_compilation_context" flashinfer/jit/ --type py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check what common_nvcc_flags contains
sed -n '1,130p' flashinfer/jit/core.py | grep -A 10 "common_nvcc_flags"

Repository: flashinfer-ai/flashinfer

Length of output: 567


Refactor to use CompilationContext.get_nvcc_flags_list(supported_major_versions=[9]) pattern.

Per coding guidelines for JIT modules, supported_major_versions should be specified when creating nvcc flags. Other JIT modules (fused_moe.py, xqa.py, mamba, gemm, comm, attention) consistently use CompilationContext().get_nvcc_flags_list(supported_major_versions=[...]) before passing flags to gen_jit_spec(). This module should follow the same pattern instead of using the pre-defined sm90a_nvcc_flags constant:

compilation_context = CompilationContext()
nvcc_flags = compilation_context.get_nvcc_flags_list(supported_major_versions=[9])
nvcc_flags += ["-DFLAT_SM90A_ENABLED", "-std=c++20"]
return gen_jit_spec(
    uri,
    source_paths,
    extra_cuda_cflags=nvcc_flags,
    extra_include_paths=[jit_env.FLASHINFER_CSRC_DIR],
)
🧰 Tools
πŸͺ› Ruff (0.14.14)

84-84: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)

πŸ€– Prompt for AI Agents
In `@flashinfer/jit/gdn.py` around lines 81 - 86, Replace the hard-coded
sm90a_nvcc_flags usage with a CompilationContext-derived nvcc flags list: create
a CompilationContext(), call
CompilationContext.get_nvcc_flags_list(supported_major_versions=[9]) to get
nvcc_flags, append ["-DFLAT_SM90A_ENABLED","-std=c++20"], and pass that as
extra_cuda_cflags to gen_jit_spec (keep
extra_include_paths=[jit_env.FLASHINFER_CSRC_DIR] unchanged); update references
to sm90a_nvcc_flags in this function to use the new nvcc_flags variable and
ensure CompilationContext is imported or available.

Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include "cute/tensor.hpp"
#include "cutlass/arch/barrier.h"
#include "cutlass/cutlass.h"
#include "flat/cute_ext.hpp"
#include "flashinfer/flat/cute_ext.hpp"

namespace flat::collective {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/pipeline/sm90_pipeline.hpp"
#include "flat/unused.hpp"
#include "flashinfer/flat/unused.hpp"

namespace flat::collective {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include <stdexcept>
#include <string>

#include "debug.hpp"
#include "flashinfer/flat/debug.hpp"

#define FLAT_UNUSED_PARAMETER(x) (void)x

Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "flat/cute_ext.hpp"
#include "flashinfer/flat/cute_ext.hpp"

namespace flat::collective {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@
*/
#pragma once

#include "../../cute_ext.hpp"
#include "../../math_order_barrier.hpp"
#include "../../unused.hpp"
#include "../collective/flat_collective_load.hpp"
#include "../collective/flat_collective_store.hpp"
#include "../collective/flat_common.hpp"
#include "../collective/flat_named_barriers.hpp"
#include "../kernel/flat_options.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "flat/ampere/collective/flat_collective_inverse.hpp"
#include "flat/ampere/collective/flat_collective_load.hpp"
#include "flashinfer/flat/ampere/collective/flat_collective_inverse.hpp"
#include "flashinfer/flat/ampere/collective/flat_collective_load.hpp"
#include "flashinfer/flat/cute_ext.hpp"
#include "flashinfer/flat/hopper/collective/flat_collective_load.hpp"
#include "flashinfer/flat/hopper/collective/flat_collective_store.hpp"
#include "flashinfer/flat/hopper/collective/flat_common.hpp"
#include "flashinfer/flat/hopper/collective/flat_named_barriers.hpp"
#include "flashinfer/flat/hopper/kernel/flat_options.hpp"
#include "flashinfer/flat/math_order_barrier.hpp"
#include "flashinfer/flat/unused.hpp"

// #define INLINE_LAMBDA [[gnu::always_inline]]
#define INLINE_LAMBDA __attribute__((always_inline))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
*/
#pragma once

#include "../collective/flat_collective_tma_warpspecialized_delta_rule.hpp"
#include "../kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp"
#include "../kernel/flat_options.hpp"
#include "../kernel/flat_tile_scheduler.hpp"
#include "flat/type_traits.hpp"
#include "flashinfer/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp"
#include "flashinfer/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp"
#include "flashinfer/flat/hopper/kernel/flat_options.hpp"
#include "flashinfer/flat/hopper/kernel/flat_tile_scheduler.hpp"
#include "flashinfer/flat/type_traits.hpp"

namespace flat::kernel {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
*/
#pragma once

#include "../kernel/flat_options.hpp"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cutlass.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "flat/common.hpp"
#include "flat/unused.hpp"
#include "flashinfer/flat/common.hpp"
#include "flashinfer/flat/hopper/kernel/flat_options.hpp"
#include "flashinfer/flat/unused.hpp"

namespace flat::kernel {

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.h"
#include "cutlass/util/device_memory.h"
#include "flat/common.hpp"
#include "flat/hopper/device/device_universal.hpp"
#include "flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp"
#include "flashinfer/flat/common.hpp"
#include "flashinfer/flat/hopper/device/device_universal.hpp"
#include "flashinfer/flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp"

namespace flat {

Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions tests/gdn/test_prefill_delta_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _test_prefill_kernel(

torch.cuda.synchronize()

# postprocessing raw output, ref_state is v-major, our_state is k-major, unify to v-major for testing
# postprocessing raw output: ref_state is v-last [H,K,V], our_state is k-last [H,V,K], transpose to match
our_state = our_state.transpose(-1, -2)

ref_o, ref_state = blockwise_delta_rule(
Expand Down Expand Up @@ -330,7 +330,7 @@ def _test_chunked_prefill(

torch.cuda.synchronize()

# postprocessing raw output, ref_state is v-major, our_state is k-major, unify to v-major for testing
# postprocessing raw output: ref_state is v-last [H,K,V], our_state is k-last [H,V,K], transpose to match
our_state = our_state.transpose(-1, -2)

def concat_varlen(t1, cu_seq_lens1, t2, cu_seq_lens2):
Expand Down
Loading