diff --git a/flashinfer/aot.py b/flashinfer/aot.py index d2b23b7726..7afc4e54bb 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -44,6 +44,7 @@ gen_single_decode_module, gen_single_prefill_module, gen_trtllm_gen_fmha_module, + gen_trtllm_fmha_v2_sm120_module, ) from .jit.cascade import gen_cascade_module from .jit.cpp_ext import get_cuda_version @@ -522,13 +523,14 @@ def gen_all_modules( if has_sm121: jit_specs.append(gen_fp4_quantization_sm121_module()) if has_sm120 or has_sm121: - # SM120 and SM121 share the same CUTLASS kernels for fused MOE and GEMM. + # SM120 and SM121 share the same kernels for fused MOE, GEMM, and attention. # The SM120 module generators use supported_major_versions=[12] which # compiles for all SM12x targets. jit_specs.append(gen_cutlass_fused_moe_sm120_module()) jit_specs.append(gen_gemm_sm120_module()) jit_specs.append(gen_gemm_sm120_module_cutlass_fp4()) jit_specs.append(gen_gemm_sm120_module_cutlass_mxfp8()) + jit_specs.append(gen_trtllm_fmha_v2_sm120_module()) if has_sm120f: jit_specs.append(gen_fp4_quantization_sm120f_module())