diff --git a/flashinfer/aot.py b/flashinfer/aot.py index f11ac238bb..d7837593d5 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -512,11 +512,15 @@ def gen_all_modules( jit_specs.append(gen_fp4_quantization_sm110_module()) if has_sm120: jit_specs.append(gen_fp4_quantization_sm120_module()) + if has_sm121: + jit_specs.append(gen_fp4_quantization_sm121_module()) + if has_sm120 or has_sm121: + # SM120 and SM121 share the same CUTLASS kernels for fused MOE and GEMM. + # The SM120 module generators use supported_major_versions=[12] which + # compiles for all SM12x targets. jit_specs.append(gen_cutlass_fused_moe_sm120_module()) jit_specs.append(gen_gemm_sm120_module()) jit_specs.append(gen_gemm_sm120_module_cutlass_fp4()) - if has_sm121: - jit_specs.append(gen_fp4_quantization_sm121_module()) if add_comm: from .jit.comm import (