Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion flashinfer/jit/gemm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def gen_gemm_sm100_module_cutlass_fp4() -> JitSpec:
+ [
"-DENABLE_BF16",
"-DENABLE_FP4",
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
"-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
],
Comment on lines 91 to 96
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve maintainability and reduce code duplication, consider defining these new GDC (Grid Dependency Control) compile flags as a constant list at the module level. This constant can then be reused across all the JIT spec generation functions that require these flags, as this pattern is repeated in 5 other places in this file.

For example, you could define the following at the top of flashinfer/jit/gemm/core.py:

GDC_COMPILE_FLAGS = [
    "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
    "-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
]

And then use it as suggested below.

        + [
            "-DENABLE_BF16",
            "-DENABLE_FP4",
        ] + GDC_COMPILE_FLAGS

extra_cflags=[
"-DFAST_BUILD",
Expand Down Expand Up @@ -158,6 +160,8 @@ def gen_gemm_sm103_module_cutlass_fp4() -> JitSpec:
+ [
"-DENABLE_BF16",
"-DENABLE_FP4",
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
"-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
],
Comment on lines 160 to 165
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To avoid code duplication, please use the GDC_COMPILE_FLAGS constant as suggested in the previous comment.

        + [
            "-DENABLE_BF16",
            "-DENABLE_FP4",
        ] + GDC_COMPILE_FLAGS

extra_cflags=[
"-DFAST_BUILD",
Expand Down Expand Up @@ -206,6 +210,8 @@ def gen_gemm_sm120_module_cutlass_fp4() -> JitSpec:
+ [
"-DENABLE_BF16",
"-DENABLE_FP4",
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
"-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
],
Comment on lines 210 to 215
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To avoid code duplication, please use the GDC_COMPILE_FLAGS constant as suggested.

        + [
            "-DENABLE_BF16",
            "-DENABLE_FP4",
        ] + GDC_COMPILE_FLAGS

extra_cflags=[
"-DFAST_BUILD",
Expand Down Expand Up @@ -256,6 +262,8 @@ def gen_gemm_sm100_module_cutlass_fp8() -> JitSpec:
extra_cuda_cflags=nvcc_flags
+ [
"-DENABLE_BF16",
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
"-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
],
Comment on lines 263 to 267
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To avoid code duplication, please use the GDC_COMPILE_FLAGS constant as suggested.

        + [
            "-DENABLE_BF16",
        ] + GDC_COMPILE_FLAGS

extra_cflags=[
"-DFAST_BUILD",
Expand Down Expand Up @@ -349,6 +357,8 @@ def gen_gemm_sm100_module_cutlass_mxfp8() -> JitSpec:
extra_cuda_cflags=nvcc_flags
+ [
"-DENABLE_BF16",
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
"-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
],
Comment on lines 358 to 362
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To avoid code duplication, please use the GDC_COMPILE_FLAGS constant as suggested.

        + [
            "-DENABLE_BF16",
        ] + GDC_COMPILE_FLAGS

extra_cflags=[
"-DFAST_BUILD",
Expand Down Expand Up @@ -516,7 +526,11 @@ def gen_gemm_sm120_module() -> JitSpec:
return gen_jit_spec(
"gemm_sm120",
source_paths,
extra_cuda_cflags=nvcc_flags,
extra_cuda_cflags=nvcc_flags
+ [
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
"-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
],
Comment on lines +529 to +533
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To avoid code duplication, please use the GDC_COMPILE_FLAGS constant as suggested.

        extra_cuda_cflags=nvcc_flags + GDC_COMPILE_FLAGS

)


Expand Down
Loading