diff --git a/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/grid_dependency_control.h b/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/grid_dependency_control.h index 5f71a94886..2257fbbc5c 100644 --- a/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/grid_dependency_control.h +++ b/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/grid_dependency_control.h @@ -33,6 +33,24 @@ #endif #endif +#ifndef CUTLASS_GDC_ENABLED +#if (CUDA_BARRIER_ENABLED && defined(CUTLASS_ENABLE_GDC_FOR_SM100) && defined(__CUDA_ARCH__) && \ + ((__CUDA_ARCH__ == 1000 && \ + (defined(__CUDA_ARCH_FEAT_SM100_ALL) || CUDA_ARCH_FAMILY(1000))) || \ + (__CUDA_ARCH__ == 1010 && \ + (defined(__CUDA_ARCH_FEAT_SM101_ALL) || CUDA_ARCH_FAMILY(1010))) || \ + (__CUDA_ARCH__ == 1100 && \ + (defined(__CUDA_ARCH_FEAT_SM110_ALL) || CUDA_ARCH_FAMILY(1100))) || \ + (__CUDA_ARCH__ == 1030 && \ + (defined(__CUDA_ARCH_FEAT_SM103_ALL) || CUDA_ARCH_FAMILY(1030))) || \ + (__CUDA_ARCH__ == 1200 && \ + (defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_FAMILY(1200))) || \ + (__CUDA_ARCH__ == 1210 && \ + (defined(__CUDA_ARCH_FEAT_SM121_ALL) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210))))) +#define CUTLASS_GDC_ENABLED +#endif +#endif + namespace cutlass { namespace arch { diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py index c196bc5f8b..e5e864c2c7 100644 --- a/flashinfer/jit/fused_moe.py +++ b/flashinfer/jit/fused_moe.py @@ -38,6 +38,7 @@ def gen_cutlass_fused_moe_sm120_module(use_fast_build: bool = False) -> JitSpec: "-DENABLE_FP8", "-DENABLE_FP4", "-DUSING_OSS_CUTLASS_MOE_GEMM", + "-DCUTLASS_ENABLE_GDC_FOR_SM100=1", ] nvcc_flags += current_compilation_context.get_nvcc_flags_list( @@ -56,6 +57,7 @@ def gen_cutlass_fused_moe_sm103_module(use_fast_build: bool = False) -> JitSpec: "-DENABLE_FP4", "-DUSING_OSS_CUTLASS_MOE_GEMM", "-DCOMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS", + "-DCUTLASS_ENABLE_GDC_FOR_SM100=1", ] nvcc_flags += current_compilation_context.get_nvcc_flags_list( @@ -73,6 +75,7 @@ def gen_cutlass_fused_moe_sm100_module(use_fast_build: bool = False) -> JitSpec: "-DENABLE_FP8", "-DENABLE_FP4", "-DUSING_OSS_CUTLASS_MOE_GEMM", + "-DCUTLASS_ENABLE_GDC_FOR_SM100=1", ] nvcc_flags += current_compilation_context.get_nvcc_flags_list( @@ -91,6 +94,7 @@ def gen_cutlass_fused_moe_sm90_module(use_fast_build: bool = False) -> JitSpec: "-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "", "-DENABLE_FP4" if is_cuda_version_at_least("12.8") else "", "-DUSING_OSS_CUTLASS_MOE_GEMM", + "-DCUTLASS_ENABLE_GDC_FOR_SM90=1", ] return gen_cutlass_fused_moe_module(nvcc_flags, "90", use_fast_build) @@ -304,6 +308,7 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec: "-DENABLE_BF16", "-DENABLE_FP8", "-DENABLE_FP4", + "-DCUTLASS_ENABLE_GDC_FOR_SM100=1", f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_BMM}\\"', ] + nvcc_flags, diff --git a/flashinfer/jit/gemm/fp8_blockscale.py b/flashinfer/jit/gemm/fp8_blockscale.py index e55b0b7c0b..e2290fa650 100755 --- a/flashinfer/jit/gemm/fp8_blockscale.py +++ b/flashinfer/jit/gemm/fp8_blockscale.py @@ -14,6 +14,7 @@ def gen_fp8_blockscale_gemm_sm90_module(use_fast_build: bool = False) -> JitSpec "-DENABLE_BF16", "-DENABLE_FP8", *(("-DENABLE_FP8_BLOCK_SCALE",) if is_cuda_version_at_least("12.8") else ()), + "-DCUTLASS_ENABLE_GDC_FOR_SM90=1", ] return gen_jit_spec(