-
Notifications
You must be signed in to change notification settings - Fork 829
refactor: reduce hopper's gdn prefill compilation time and fix docstring. #2422
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9039bf0
e803d94
1879112
501034e
413cc5f
a7cafc0
2491b2b
bfe6991
ed77def
2aa8127
3e313ad
1322d60
e5236d7
d7499c8
37c254c
c74181f
a5127e4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// Extern template declarations to prevent implicit instantiation in the dispatcher.
// Explicit instantiations are in separate generated files for parallel compilation.

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "cutlass/arch/arch.h"

namespace flat {

// clang-format off

// Invokes MACRO once for every combination of four boolean arguments
// (2^4 = 16 expansions); any extra arguments are forwarded verbatim.
#define FOR_EACH_BOOL_4(MACRO, ...) \
  MACRO(false, false, false, false, __VA_ARGS__) \
  MACRO(false, false, false, true, __VA_ARGS__) \
  MACRO(false, false, true, false, __VA_ARGS__) \
  MACRO(false, false, true, true, __VA_ARGS__) \
  MACRO(false, true, false, false, __VA_ARGS__) \
  MACRO(false, true, false, true, __VA_ARGS__) \
  MACRO(false, true, true, false, __VA_ARGS__) \
  MACRO(false, true, true, true, __VA_ARGS__) \
  MACRO(true, false, false, false, __VA_ARGS__) \
  MACRO(true, false, false, true, __VA_ARGS__) \
  MACRO(true, false, true, false, __VA_ARGS__) \
  MACRO(true, false, true, true, __VA_ARGS__) \
  MACRO(true, true, false, false, __VA_ARGS__) \
  MACRO(true, true, false, true, __VA_ARGS__) \
  MACRO(true, true, true, false, __VA_ARGS__) \
  MACRO(true, true, true, true, __VA_ARGS__)

// Declares (without instantiating) one launcher specialization.
// NOTE(review): this signature must stay exactly in sync with the explicit
// instantiations in the generated per-variant .cu files — a mismatch would
// silently re-enable implicit instantiation in every including TU.
#define DECLARE_TEMPLATE_INSTANCE(is_gva, needs_beta, needs_alpha, init_state, ctype) \
  extern template void launch_delta_rule_prefill_kernel_gbai<is_gva, needs_beta, needs_alpha, init_state, cutlass::arch::Sm90, ctype, ctype, float>( \
      cudaStream_t, ctype*, float*, ctype const*, ctype const*, ctype const*, \
      float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t, \
      int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t);

// Extern template declarations for half
FOR_EACH_BOOL_4(DECLARE_TEMPLATE_INSTANCE, half)

// Extern template declarations for nv_bfloat16
FOR_EACH_BOOL_4(DECLARE_TEMPLATE_INSTANCE, nv_bfloat16)

#undef DECLARE_TEMPLATE_INSTANCE
#undef FOR_EACH_BOOL_4

// clang-format on

}  // namespace flat
yzh119 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// Auto-generated file for separate compilation of GDN prefill kernel variants.
// Template parameters: dtype={{ dtype }}, is_gva={{ is_gva }}, needs_beta={{ needs_beta }},
// needs_alpha={{ needs_alpha }}, init_state={{ init_state }}
// Each {{ ... }} placeholder is substituted by Jinja2 with a C++ type name or
// a lowercase bool literal ("true"/"false") supplied by the Python generator.

// CUDA type definitions for half and nv_bfloat16
#include <cuda_bf16.h>
#include <cuda_fp16.h>

// Include the header which defines the function template
// The header includes all necessary CUTLASS type definitions
#include "flashinfer/flat/prefill/prefill_kernel_delta_rule_sm90.cuh"

namespace flat {

// Explicit template instantiation for launch_delta_rule_prefill_kernel_gbai
// Parameter types must exactly match the extern template declaration in
// prefill_kernel_delta_rule_sm90_extern.inc
template void launch_delta_rule_prefill_kernel_gbai<{{ is_gva }}, {{ needs_beta }}, {{ needs_alpha }}, {{ init_state }}, cutlass::arch::Sm90, {{ dtype }}, {{ dtype }}, float>(
    cudaStream_t, {{ dtype }}*, float*, {{ dtype }} const*, {{ dtype }} const*, {{ dtype }} const*,
    float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t,
    int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t);

}  // namespace flat
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,24 +14,70 @@ | |
| limitations under the License. | ||
| """ | ||
|
|
||
| import itertools | ||
| import os | ||
|
|
||
| import jinja2 | ||
|
|
||
| from . import env as jit_env | ||
| from .core import ( | ||
| JitSpec, | ||
| gen_jit_spec, | ||
| sm90a_nvcc_flags, | ||
| ) | ||
| from .utils import write_if_different | ||
|
|
||
|
|
||
def gen_gdn_prefill_sm90_module() -> JitSpec:
    """Generate JIT module for GDN prefill kernel with separate compilation.

    This generates 32 separate kernel instantiation files (2 dtypes x 16 boolean
    combinations) plus the original launcher file. The separate files enable
    parallel compilation by ninja, significantly reducing build time on
    multi-core machines.

    Returns:
        JitSpec: Build spec covering the generated instantiation files plus the
        launcher/dispatcher sources, compiled with SM90a flags.
    """
    uri = "gdn_prefill_sm90"
    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
    os.makedirs(gen_directory, exist_ok=True)

    source_paths = []

    # Load kernel instantiation template
    with open(jit_env.FLASHINFER_CSRC_DIR / "gdn_prefill_sm90_kernel_inst.jinja") as f:
        kernel_inst_templ = jinja2.Template(f.read())

    # Generate 32 separate instance files (2 dtypes x 16 boolean combinations)
    dtypes = [("half", "half"), ("bf16", "nv_bfloat16")]
    for dtype_name, dtype in dtypes:
        for is_gva, needs_beta, needs_alpha, init_state in itertools.product(
            [False, True], repeat=4
        ):
            # Encode the variant in the filename, e.g. "half_g1b0a1i0".
            suffix = f"{dtype_name}_g{int(is_gva)}b{int(needs_beta)}a{int(needs_alpha)}i{int(init_state)}"
            filename = f"gdn_prefill_kernel_{suffix}.cu"
            dest_path = gen_directory / filename
            source_paths.append(dest_path)

            # The template expects C++ bool literals, hence str(...).lower().
            source = kernel_inst_templ.render(
                dtype=dtype,
                is_gva=str(is_gva).lower(),
                needs_beta=str(needs_beta).lower(),
                needs_alpha=str(needs_alpha).lower(),
                init_state=str(init_state).lower(),
            )
            write_if_different(dest_path, source)

    # Copy source files to gen_directory for compilation
    # Headers are now in include/flashinfer/flat/ and accessible via standard include paths
    for filename in [
        "gdn_prefill_launcher.cu",
        "prefill_kernel_delta_rule_sm90.cu",
    ]:
        src_path = jit_env.FLASHINFER_CSRC_DIR / filename
        dest_path = gen_directory / src_path.name
        source_paths.append(dest_path)
        write_if_different(dest_path, src_path.read_text())

    return gen_jit_spec(
        uri,
        source_paths,
        # Iterable unpacking (RUF005) instead of list concatenation.
        extra_cuda_cflags=[*sm90a_nvcc_flags, "-DFLAT_SM90A_ENABLED", "-std=c++20"],
        extra_include_paths=[jit_env.FLASHINFER_CSRC_DIR],
    )
|
Comment on lines
79
to
83
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. π§© Analysis chainπ Script executed: # Check gen_jit_spec function signature and supported_major_versions usage
rg -n "supported_major_versions" --type py -B 2 -A 2Repository: flashinfer-ai/flashinfer Length of output: 7679 π Script executed: # Find gen_jit_spec definition
rg -n "def gen_jit_spec" --type py -A 10Repository: flashinfer-ai/flashinfer Length of output: 829 π Script executed: # Check for other JIT module generators to see usage patterns
fd "\.py$" flashinfer/jit/ | head -10 | xargs grep -l "gen_jit_spec" | xargs -I {} sh -c 'echo "=== {} ===" && grep -A 5 "gen_jit_spec(" {}'Repository: flashinfer-ai/flashinfer Length of output: 4164 π Script executed: cat -n flashinfer/jit/gdn.pyRepository: flashinfer-ai/flashinfer Length of output: 3736 π Script executed: grep -n "sm90a_nvcc_flags" flashinfer/jit/core.py | head -20Repository: flashinfer-ai/flashinfer Length of output: 92 π Script executed: # Get context around the definition
rg -n "sm90a_nvcc_flags\s*=" flashinfer/jit/core.py -B 2 -A 2Repository: flashinfer-ai/flashinfer Length of output: 238 π Script executed: sed -n '120,135p' flashinfer/jit/core.pyRepository: flashinfer-ai/flashinfer Length of output: 810 π Script executed: # Check how other SM90-specific generators handle the architecture restriction
rg -n "SM90\|sm_90" flashinfer/jit/ --type py -B 3 -A 3 | grep -E "(supported_major_versions|sm90)" | head -30Repository: flashinfer-ai/flashinfer Length of output: 50 π Script executed: head -30 flashinfer/jit/gdn.pyRepository: flashinfer-ai/flashinfer Length of output: 839 π Script executed: # Check if CompilationContext is available and how it's used
rg -n "CompilationContext\|current_compilation_context" flashinfer/jit/ --type py | head -20Repository: flashinfer-ai/flashinfer Length of output: 50 π Script executed: # Check what common_nvcc_flags contains
sed -n '1,130p' flashinfer/jit/core.py | grep -A 10 "common_nvcc_flags"Repository: flashinfer-ai/flashinfer Length of output: 567 Refactor to use Per coding guidelines for JIT modules, compilation_context = CompilationContext()
nvcc_flags = compilation_context.get_nvcc_flags_list(supported_major_versions=[9])
nvcc_flags += ["-DFLAT_SM90A_ENABLED", "-std=c++20"]
return gen_jit_spec(
uri,
source_paths,
extra_cuda_cflags=nvcc_flags,
extra_include_paths=[jit_env.FLASHINFER_CSRC_DIR],
)π§° Toolsπͺ Ruff (0.14.14)84-84: Consider iterable unpacking instead of concatenation Replace with iterable unpacking (RUF005) π€ Prompt for AI Agents |
||
Uh oh!
There was an error while loading. Please reload this page.