diff --git a/csrc/ck_tile_gemm_moe_2stages/gen_instances.py b/csrc/ck_tile_gemm_moe_2stages/gen_instances.py index bc76068637..c3a600a370 100644 --- a/csrc/ck_tile_gemm_moe_2stages/gen_instances.py +++ b/csrc/ck_tile_gemm_moe_2stages/gen_instances.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import os import argparse from pathlib import Path @@ -12,7 +12,7 @@ get_heuristic_dispatch_template, ) import sys -from chip_info import get_gfx +from chip_info import get_gfx, get_gfx_list this_dir = os.path.dirname(os.path.abspath(__file__)) AITER_CORE_DIR = os.path.abspath(f"{this_dir}/../../../") @@ -222,15 +222,20 @@ def gen_instance(self, k: kernelInstance, a_type): # else: def fill_template(name, a_type, b_type, acc_type, c_type): nonlocal self - intsance = INSTANCE_template.format( + # Arch-aware scheduling: skip generating FP4 instances unless gfx950 is targeted + if "fp4" in b_type and ("gfx950" not in get_gfx_list()): + return + body = INSTANCE_template.format( name=name, dtypes=f"{a_type}, {b_type}, {acc_type}, {c_type}" ) + if "fp4" in b_type: + body = "#ifndef __gfx942__\n" + body + "\n#endif\n" Path( os.path.join( self.instances_path, f"{name}_a{a_type}_b{b_type}_acc{acc_type}_C{c_type}.cpp", ) - ).write_text(intsance) + ).write_text(body) if (k.QuantType == "1x32") and (a_type in ["bf16", "fp16", "fp8"]): fill_template(k.name, a_type, "pk_fp4", self.acc_dtype, self.c_dtype)