Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions csrc/ck_tile_gemm_moe_2stages/gen_instances.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
import os
import argparse
from pathlib import Path
Expand All @@ -12,7 +12,7 @@
get_heuristic_dispatch_template,
)
import sys
from chip_info import get_gfx
from chip_info import get_gfx, get_gfx_list

this_dir = os.path.dirname(os.path.abspath(__file__))
AITER_CORE_DIR = os.path.abspath(f"{this_dir}/../../../")
Expand Down Expand Up @@ -222,15 +222,20 @@ def gen_instance(self, k: kernelInstance, a_type):
# else:
def fill_template(name, a_type, b_type, acc_type, c_type):
    """Write one kernel-instance ``.cpp`` file for the given dtype combination.

    The file body is ``INSTANCE_template`` formatted with ``name`` and the
    comma-separated dtype list, and it is written into ``self.instances_path``
    (``self`` is captured from the enclosing ``gen_instance`` call).

    Args:
        name: Kernel instance name; used in the template and the file name.
        a_type: Dtype string of operand A.
        b_type: Dtype string of operand B; any value containing ``"fp4"``
            is treated as an FP4 instance.
        acc_type: Accumulator dtype string.
        c_type: Output dtype string.

    Returns:
        None. Nothing is written when an FP4 instance is requested but
        ``"gfx950"`` is not among the architectures from ``get_gfx_list()``.
    """
    is_fp4 = "fp4" in b_type

    # Arch-aware generation: skip FP4 instances entirely unless gfx950 is
    # one of the targeted architectures.
    if is_fp4 and "gfx950" not in get_gfx_list():
        return

    body = INSTANCE_template.format(
        name=name, dtypes=f"{a_type}, {b_type}, {acc_type}, {c_type}"
    )
    if is_fp4:
        # Keep the FP4 instance out of gfx942 compilation passes.
        # NOTE(review): presumably needed for multi-arch builds that target
        # both gfx942 and gfx950 -- confirm.
        body = "#ifndef __gfx942__\n" + body + "\n#endif\n"

    file_name = f"{name}_a{a_type}_b{b_type}_acc{acc_type}_C{c_type}.cpp"
    Path(self.instances_path, file_name).write_text(body)

if (k.QuantType == "1x32") and (a_type in ["bf16", "fp16", "fp8"]):
fill_template(k.name, a_type, "pk_fp4", self.acc_dtype, self.c_dtype)
Expand Down