fix(jit): GEMM kernels produce NaN under concurrency — missing GDC flags cause PDL synchronization barriers to compile as no-ops #2716
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -91,6 +91,8 @@ def gen_gemm_sm100_module_cutlass_fp4() -> JitSpec: | |
| + [ | ||
| "-DENABLE_BF16", | ||
| "-DENABLE_FP4", | ||
| "-DCUTLASS_ENABLE_GDC_FOR_SM100=1", | ||
| "-DCUTLASS_ENABLE_GDC_FOR_SM90=1", | ||
| ], | ||
| extra_cflags=[ | ||
| "-DFAST_BUILD", | ||
|
|
@@ -158,6 +160,8 @@ def gen_gemm_sm103_module_cutlass_fp4() -> JitSpec: | |
| + [ | ||
| "-DENABLE_BF16", | ||
| "-DENABLE_FP4", | ||
| "-DCUTLASS_ENABLE_GDC_FOR_SM100=1", | ||
| "-DCUTLASS_ENABLE_GDC_FOR_SM90=1", | ||
| ], | ||
|
Comment on lines 160 to 165 (Contributor)
||
| extra_cflags=[ | ||
| "-DFAST_BUILD", | ||
|
|
@@ -206,6 +210,8 @@ def gen_gemm_sm120_module_cutlass_fp4() -> JitSpec: | |
| + [ | ||
| "-DENABLE_BF16", | ||
| "-DENABLE_FP4", | ||
| "-DCUTLASS_ENABLE_GDC_FOR_SM100=1", | ||
| "-DCUTLASS_ENABLE_GDC_FOR_SM90=1", | ||
| ], | ||
|
Comment on lines 210 to 215 (Contributor)
||
| extra_cflags=[ | ||
| "-DFAST_BUILD", | ||
|
|
@@ -256,6 +262,8 @@ def gen_gemm_sm100_module_cutlass_fp8() -> JitSpec: | |
| extra_cuda_cflags=nvcc_flags | ||
| + [ | ||
| "-DENABLE_BF16", | ||
| "-DCUTLASS_ENABLE_GDC_FOR_SM100=1", | ||
| "-DCUTLASS_ENABLE_GDC_FOR_SM90=1", | ||
| ], | ||
|
Comment on lines 263 to 267 (Contributor)
||
| extra_cflags=[ | ||
| "-DFAST_BUILD", | ||
|
|
@@ -349,6 +357,8 @@ def gen_gemm_sm100_module_cutlass_mxfp8() -> JitSpec: | |
| extra_cuda_cflags=nvcc_flags | ||
| + [ | ||
| "-DENABLE_BF16", | ||
| "-DCUTLASS_ENABLE_GDC_FOR_SM100=1", | ||
| "-DCUTLASS_ENABLE_GDC_FOR_SM90=1", | ||
| ], | ||
|
Comment on lines 358 to 362 (Contributor)
||
| extra_cflags=[ | ||
| "-DFAST_BUILD", | ||
|
|
@@ -516,7 +526,11 @@ def gen_gemm_sm120_module() -> JitSpec: | |
| return gen_jit_spec( | ||
| "gemm_sm120", | ||
| source_paths, | ||
| - extra_cuda_cflags=nvcc_flags, | ||
| + extra_cuda_cflags=nvcc_flags | ||
| + + [ | ||
| + "-DCUTLASS_ENABLE_GDC_FOR_SM100=1", | ||
| + "-DCUTLASS_ENABLE_GDC_FOR_SM90=1", | ||
| + ], | ||
|
Comment on lines +529 to +533 (Contributor)
||
| ) | ||
To improve maintainability and reduce duplication, consider defining the new GDC (Grid Dependency Control) compile flags as a constant list at the module level. The constant can then be reused by every JIT spec generation function that requires these flags, since the same pair is repeated in 5 other places in this file.
For example, you could define the constant at the top of flashinfer/jit/gemm/core.py and then append it to `nvcc_flags` in each `extra_cuda_cflags` argument.
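A minimal sketch of what such a constant might look like. The two flag strings are copied from the diff. The names `CUTLASS_GDC_FLAGS`, `with_gdc_flags`, and `example_nvcc_flags` are hypothetical and not taken from the repository, and `example_nvcc_flags` only stands in for whatever `nvcc_flags` holds in each generator function.

```python
# Hypothetical module-level constant; the flag strings come from the diff above.
CUTLASS_GDC_FLAGS = [
    "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
    "-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
]


def with_gdc_flags(base_flags):
    """Return a new list: the base flags followed by the GDC flags.

    Neither the input list nor the shared constant is mutated.
    """
    return list(base_flags) + CUTLASS_GDC_FLAGS


# Stand-in for the per-function nvcc_flags list (assumed contents).
example_nvcc_flags = ["-DENABLE_BF16", "-DENABLE_FP4"]

# Equivalent to `extra_cuda_cflags=nvcc_flags + [...]` in the diff.
fp4_cuda_cflags = with_gdc_flags(example_nvcc_flags)
```

Each generator could then pass `extra_cuda_cflags=nvcc_flags + CUTLASS_GDC_FLAGS` (or the helper's result), so a future change to the flag set happens in one place.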