diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index abab9cfc48..212978cbae 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -15,7 +15,7 @@ from aiter import get_hip_quant as get_quant from aiter import logger from aiter.jit.core import ( - AITER_CONFIG_FMOE_FILE, + AITER_CONFIGS, PY, bd_dir, get_asm_dir, @@ -573,8 +573,8 @@ def get_cfg_2stages(tune_file): return cfg_2stages global cfg_2stages - config_path = os.path.dirname(AITER_CONFIG_FMOE_FILE) - tune_file = AITER_CONFIG_FMOE_FILE + config_path = os.path.dirname(AITER_CONFIGS.AITER_CONFIG_FMOE_FILE) + tune_file = AITER_CONFIGS.AITER_CONFIG_FMOE_FILE untune_file = os.path.join(config_path, "untuned_fmoe.csv") profile_file = os.path.join(config_path, "profile_fmoe.csv") if cfg_2stages is None: diff --git a/aiter/jit/core.py b/aiter/jit/core.py index 0c0e8a49d6..df7241ce24 100644 --- a/aiter/jit/core.py +++ b/aiter/jit/core.py @@ -66,81 +66,6 @@ def mp_lock( # config_env start here -def update_config_files(file_path: str, merge_name: str): - path_list = file_path.split(os.pathsep) if file_path else [] - if len(path_list) <= 1: - return file_path - df_list = [] - ## merge config files - ##example: AITER_CONFIG_GEMM_A4W4="/path1:/path2" - import pandas as pd - - df_list.append(pd.read_csv(path_list[0])) - for i, path in enumerate(path_list[1:]): - if os.path.exists(path): - df = pd.read_csv(path) - ## check columns - assert ( - df.columns.tolist() == df_list[0].columns.tolist() - ), f"Column mismatch between {path_list[0]} and {path}, {df_list[0].columns.tolist()}, {df.columns.tolist()}" - - df_list.append(df) - else: - logger.info(f"path {i+1}: {path} (not exist)") - merge_df = pd.concat(df_list, ignore_index=True) if df_list else pd.DataFrame() - ## get keys from untuned file to drop_duplicates - untuned_name = ( - re.sub(r"(?:_)?tuned$", r"\1untuned", merge_name) - if re.search(r"(?:_)?tuned$", merge_name) - else merge_name.replace("tuned", "untuned") - ) - untuned_path = f"{AITER_ROOT_DIR}/aiter/configs/{untuned_name}.csv" - if os.path.exists(untuned_path): - untunedf = pd.read_csv(untuned_path) - keys = untunedf.columns.to_list() - keys.append("cu_num") - merge_df = ( - merge_df.sort_values("us") - .drop_duplicates(subset=keys, keep="first") - .reset_index(drop=True) - ) - else: - logger.warning( - f"Untuned config file not found: {untuned_path}. Using all columns for deduplication." - ) - new_file_path = f"/tmp/{merge_name}.csv" - merge_df.to_csv(new_file_path, index=False) - return new_file_path - - -def get_config_file(env_name, default_file, tuned_file_name): - config_env_file = os.getenv(env_name) - # default_file = f"{AITER_ROOT_DIR}/aiter/configs/{tuned_file_name}.csv" - from pathlib import Path - - if not config_env_file: - model_config_dir = Path(f"{AITER_ROOT_DIR}/aiter/configs/model_configs/") - op_tuned_file_list = [ - p - for p in model_config_dir.glob(f"*{tuned_file_name}*") - if (p.is_file() and "untuned" not in str(p)) - ] - - if not op_tuned_file_list: - config_file = default_file - else: - tuned_files = ":".join(str(p) for p in op_tuned_file_list) - tuned_files = default_file + ":" + tuned_files - logger.info( - f"merge tuned file under model_configs/ and configs/ {tuned_files}" - ) - config_file = update_config_files(tuned_files, tuned_file_name) - else: - config_file = update_config_files(config_env_file, tuned_file_name) - # print(f"get config file from environment ", config_file) - return config_file - - AITER_CONFIG_GEMM_A4W4 = os.getenv( "AITER_CONFIG_GEMM_A4W4", f"{AITER_ROOT_DIR}/aiter/configs/a4w4_blockscale_tuned_gemm.csv", @@ -185,56 +110,170 @@ def get_config_file(env_name, default_file, tuned_file_name): AITER_CONFIG_GEMM_BF16 = os.getenv( "AITER_CONFIG_GEMM_BF16", - f"{AITER_ROOT_DIR}/aiter/configs/tuned_gemm.csv", -) -AITER_CONFIG_GEMM_A4W4_FILE = get_config_file( - "AITER_CONFIG_GEMM_A4W4", AITER_CONFIG_GEMM_A4W4, "a4w4_blockscale_tuned_gemm" + f"{AITER_ROOT_DIR}/aiter/configs/bf16_tuned_gemm.csv", ) -AITER_CONFIG_GEMM_A8W8_FILE = get_config_file( - "AITER_CONFIG_GEMM_A8W8", AITER_CONFIG_GEMM_A8W8, "a8w8_tuned_gemm" -) -AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE = get_config_file( - "AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE", - AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE, - "a8w8_bpreshuffle_tuned_gemm", -) -AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE_FILE = get_config_file( - "AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE", - AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE, - "a8w8_bpreshuffle_cktile_tuned_gemm", -) -AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_FILE = get_config_file( - "AITER_CONFIG_GEMM_A8W8_BLOCKSCALE", - AITER_CONFIG_GEMM_A8W8_BLOCKSCALE, - "a8w8_blockscale_tuned_gemm", -) -AITER_CONFIG_FMOE_FILE = get_config_file( - "AITER_CONFIG_FMOE", AITER_CONFIG_FMOE, "tuned_fmoe" -) -AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE_FILE = get_config_file( - "AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE", - AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE, - "a8w8_blockscale_bpreshuffle_tuned_gemm", -) +class AITER_CONFIG(object): + @property + def AITER_CONFIG_GEMM_A4W4_FILE(self): + return self.get_config_file( + "AITER_CONFIG_GEMM_A4W4", + AITER_CONFIG_GEMM_A4W4, + "a4w4_blockscale_tuned_gemm", + ) -AITER_CONFIG_A8W8_BATCHED_GEMM_FILE = get_config_file( - "AITER_CONFIG_A8W8_BATCHED_GEMM", - AITER_CONFIG_A8W8_BATCHED_GEMM, - "a8w8_tuned_batched_gemm", -) + @property + def AITER_CONFIG_GEMM_A8W8_FILE(self): + return self.get_config_file( + "AITER_CONFIG_GEMM_A8W8", AITER_CONFIG_GEMM_A8W8, "a8w8_tuned_gemm" + ) -AITER_CONFIG_BF16_BATCHED_GEMM_FILE = get_config_file( - "AITER_CONFIG_BF16_BATCHED_GEMM", - AITER_CONFIG_BF16_BATCHED_GEMM, - "bf16_tuned_batched_gemm", -) + @property + def AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE(self): + return self.get_config_file( + "AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE", + AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE, + "a8w8_bpreshuffle_tuned_gemm", + ) + + @property + def AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_FILE(self): + return self.get_config_file( + "AITER_CONFIG_GEMM_A8W8_BLOCKSCALE", + AITER_CONFIG_GEMM_A8W8_BLOCKSCALE, + "a8w8_blockscale_tuned_gemm", + ) + + @property + def AITER_CONFIG_FMOE_FILE(self): + return self.get_config_file( + "AITER_CONFIG_FMOE", AITER_CONFIG_FMOE, "tuned_fmoe" + ) + + @property + def AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE_FILE(self): + return self.get_config_file( + "AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE", + AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE, + "a8w8_blockscale_bpreshuffle_tuned_gemm", + ) + + @property + def AITER_CONFIG_A8W8_BATCHED_GEMM_FILE(self): + return self.get_config_file( + "AITER_CONFIG_A8W8_BATCHED_GEMM", + AITER_CONFIG_A8W8_BATCHED_GEMM, + "a8w8_tuned_batched_gemm", + ) + + @property + def AITER_CONFIG_BF16_BATCHED_GEMM_FILE(self): + return self.get_config_file( + "AITER_CONFIG_BF16_BATCHED_GEMM", + AITER_CONFIG_BF16_BATCHED_GEMM, + "bf16_tuned_batched_gemm", + ) + + @property + def AITER_CONFIG_GEMM_BF16_FILE(self): + return self.get_config_file( + "AITER_CONFIG_GEMM_BF16", AITER_CONFIG_GEMM_BF16, "bf16_tuned_gemm" + ) + + @property + def AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE_FILE(self): + return self.get_config_file( + "AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE", + AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE, + "a8w8_bpreshuffle_cktile_tuned_gemm", + ) + + def update_config_files(self, file_path: str, merge_name: str): + path_list = file_path.split(os.pathsep) if file_path else [] + if len(path_list) <= 1: + return file_path + df_list = [] + ## merge config files + ##example: AITER_CONFIG_GEMM_A4W4="/path1:/path2" + import pandas as pd + + df_list.append(pd.read_csv(path_list[0])) + for i, path in enumerate(path_list[1:]): + if os.path.exists(path): + df = pd.read_csv(path) + ## check columns + assert ( + df.columns.tolist() == df_list[0].columns.tolist() + ), f"Column mismatch between {path_list[0]} and {path}, {df_list[0].columns.tolist()}, {df.columns.tolist()}" + + df_list.append(df) + else: + logger.info(f"path {i+1}: {path} (not exist)") + merge_df = pd.concat(df_list, ignore_index=True) if df_list else pd.DataFrame() + ## get keys from untuned file to drop_duplicates + untuned_name = ( + re.sub(r"(?:_)?tuned$", r"\1untuned", merge_name) + if re.search(r"(?:_)?tuned$", merge_name) + else merge_name.replace("tuned", "untuned") + ) + untuned_path = f"{AITER_ROOT_DIR}/aiter/configs/{untuned_name}.csv" + if os.path.exists(untuned_path): + untunedf = pd.read_csv(untuned_path) + keys = untunedf.columns + merge_df = ( + merge_df.sort_values("us") + .drop_duplicates(subset=keys, keep="first") + .reset_index(drop=True) + ) + else: + logger.warning( + f"Untuned config file not found: {untuned_path}. Using all columns for deduplication." + ) + from pathlib import Path + + config_path = Path("/tmp/aiter_configs/") + if not config_path.exists(): + config_path.mkdir(parents=True, exist_ok=True) + new_file_path = f"{config_path}/{merge_name}.csv" + lock_path = f"{new_file_path}.lock" + + def write_config(): + merge_df.to_csv(new_file_path, index=False) + + mp_lock(lock_path, write_config) + return new_file_path + + @functools.lru_cache(maxsize=20) + def get_config_file(self, env_name, default_file, tuned_file_name): + config_env_file = os.getenv(env_name) + # default_file = f"{AITER_ROOT_DIR}/aiter/configs/{tuned_file_name}.csv" + from pathlib import Path + + if not config_env_file: + model_config_dir = Path(f"{AITER_ROOT_DIR}/aiter/configs/model_configs/") + op_tuned_file_list = [ + p + for p in model_config_dir.glob(f"*{tuned_file_name}*") + if (p.is_file() and "untuned" not in str(p)) + ] + + if not op_tuned_file_list: + config_file = default_file + else: + tuned_files = ":".join(str(p) for p in op_tuned_file_list) + tuned_files = default_file + ":" + tuned_files + logger.info( + f"merge tuned file under model_configs/ and configs/ {tuned_files}" + ) + config_file = self.update_config_files(tuned_files, tuned_file_name) + else: + config_file = self.update_config_files(config_env_file, tuned_file_name) + # print(f"get config file from environment ", config_file) + return config_file -AITER_CONFIG_GEMM_BF16_FILE = get_config_file( - "AITER_CONFIG_GEMM_BF16", AITER_CONFIG_GEMM_BF16, "bf16_tuned_gemm" -) +AITER_CONFIGS = AITER_CONFIG() # config_env end here find_aiter = importlib.util.find_spec("aiter") diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json index 0ef416b8e9..51b271f00f 100755 --- a/aiter/jit/optCompilerConfig.json +++ b/aiter/jit/optCompilerConfig.json @@ -188,7 +188,7 @@ "extra_include": [ "f'{AITER_CSRC_DIR}/ck_batched_gemm_bf16/include'" ], - "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_batched_gemm_bf16/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIG_BF16_BATCHED_GEMM_FILE}'" + "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_batched_gemm_bf16/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIGS.AITER_CONFIG_BF16_BATCHED_GEMM_FILE}'" }, "module_batched_gemm_a8w8": { "srcs": [ @@ -198,7 +198,7 @@ "extra_include": [ "f'{AITER_CSRC_DIR}/ck_batched_gemm_a8w8/include'" ], - "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_batched_gemm_a8w8/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIG_A8W8_BATCHED_GEMM_FILE}'" + "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_batched_gemm_a8w8/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIGS.AITER_CONFIG_A8W8_BATCHED_GEMM_FILE}'" }, "module_gemm_a8w8": { "srcs": [ @@ -208,7 +208,7 @@ "extra_include": [ "f'{AITER_CSRC_DIR}/ck_gemm_a8w8/include'" ], - "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIG_GEMM_A8W8_FILE}'" + "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_FILE}'" }, "module_gemm_a8w8_blockscale": { "srcs": [ @@ -223,7 +223,7 @@ "'-mllvm -greedy-reverse-local-assignment=1'", "'-mllvm --amdgpu-use-amdgpu-trackers=1'" ], - "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_FILE}'" + "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_FILE}'" }, "module_gemm_a8w8_blockscale_bpreshuffle": { "srcs": [ @@ -238,7 +238,7 @@ "'-mllvm -greedy-reverse-local-assignment=1'", "'-mllvm --amdgpu-use-amdgpu-trackers=1'" ], - "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale_bpreshuffle/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE_FILE}'" + "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale_bpreshuffle/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE_FILE}'" }, "module_gemm_a4w4_blockscale": { "srcs": [ @@ -254,7 +254,7 @@ "'-mllvm --amdgpu-use-amdgpu-trackers=1'" ], "hip_clang_path": "os.environ.get('GEMM_A4W4_BLOCKWISE_HIP_CLANG_PATH')", - "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a4w4_blockscale/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIG_GEMM_A4W4_FILE}'" + "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a4w4_blockscale/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIGS.AITER_CONFIG_GEMM_A4W4_FILE}'" }, "module_gemm_a8w8_bpreshuffle": { "srcs": [ @@ -267,7 +267,7 @@ ], "is_python_module": "True", "is_standalone": "False", - "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_bpreshuffle/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE}'" + "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_bpreshuffle/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE}'" }, "module_deepgemm": { "srcs": [ @@ -305,7 +305,7 @@ "is_standalone": "False", "verbose": "False", "hip_clang_path": "os.environ.get('FLATMM_HIP_CLANG_PATH')", - "blob_gen_cmd": "f'{AITER_CSRC_DIR}/cktile_gemm_a8w8_bpreshuffle/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE_FILE}'" + "blob_gen_cmd": "f'{AITER_CSRC_DIR}/cktile_gemm_a8w8_bpreshuffle/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE_FILE}'" }, "module_gemm_a8w8_asm": { "srcs": [ diff --git a/aiter/ops/batched_gemm_op_a8w8.py b/aiter/ops/batched_gemm_op_a8w8.py index 4c36049883..cbab06e679 100644 --- a/aiter/ops/batched_gemm_op_a8w8.py +++ b/aiter/ops/batched_gemm_op_a8w8.py @@ -9,7 +9,7 @@ from ..jit.core import ( compile_ops, AITER_ROOT_DIR, - AITER_CONFIG_A8W8_BATCHED_GEMM_FILE, + AITER_CONFIGS, AITER_LOG_TUNED_CONFIG, ) from ..utility import dtypes @@ -66,9 +66,12 @@ def get_CKBatchedGEMM_config( K: int, ): if not hasattr(get_CKBatchedGEMM_config, "ck_batched_gemm_dict"): - print("Loading CKBatchedGEMM config from:", AITER_CONFIG_A8W8_BATCHED_GEMM_FILE) + print( + "Loading CKBatchedGEMM config from:", + AITER_CONFIGS.AITER_CONFIG_A8W8_BATCHED_GEMM_FILE, + ) ck_batched_gemm_dict = pd.read_csv( - AITER_CONFIG_A8W8_BATCHED_GEMM_FILE + AITER_CONFIGS.AITER_CONFIG_A8W8_BATCHED_GEMM_FILE ).drop_duplicates() get_CKBatchedGEMM_config.ck_batched_gemm_dict = ck_batched_gemm_dict.set_index( @@ -81,7 +84,7 @@ def get_CKBatchedGEMM_config( if config is not None: if AITER_LOG_TUNED_CONFIG: logger.info( - f"shape is B:{B}, M:{M}, N:{N}, K:{K}, is tuned on cu_num = {cu_num} in {AITER_CONFIG_A8W8_BATCHED_GEMM_FILE}, kernel name is {config['kernelName']}, splitK is {config['splitK']}!" + f"shape is B:{B}, M:{M}, N:{N}, K:{K}, is tuned on cu_num = {cu_num} in {AITER_CONFIGS.AITER_CONFIG_A8W8_BATCHED_GEMM_FILE}, kernel name is {config['kernelName']}, splitK is {config['splitK']}!" ) mnk = config["kernelName"].split("_")[3].split("x")[1:] config["tile_m"] = int(mnk[0]) diff --git a/aiter/ops/batched_gemm_op_bf16.py b/aiter/ops/batched_gemm_op_bf16.py index 66b2fd1caa..43625fa429 100644 --- a/aiter/ops/batched_gemm_op_bf16.py +++ b/aiter/ops/batched_gemm_op_bf16.py @@ -9,7 +9,7 @@ from ..jit.core import ( compile_ops, AITER_ROOT_DIR, - AITER_CONFIG_BF16_BATCHED_GEMM_FILE, + AITER_CONFIGS, AITER_LOG_TUNED_CONFIG, ) from ..utility import dtypes @@ -56,7 +56,7 @@ def get_CKBatchedGEMM_config( ): if not hasattr(get_CKBatchedGEMM_config, "ck_batched_gemm_dict"): ck_batched_gemm_dict = pd.read_csv( - AITER_CONFIG_BF16_BATCHED_GEMM_FILE + AITER_CONFIGS.AITER_CONFIG_BF16_BATCHED_GEMM_FILE ).drop_duplicates() get_CKBatchedGEMM_config.ck_batched_gemm_dict = ck_batched_gemm_dict.set_index( ["cu_num", "B", "M", "N", "K"] @@ -68,7 +68,7 @@ def get_CKBatchedGEMM_config( if config is not None: if AITER_LOG_TUNED_CONFIG: logger.info( - f"shape is B:{B}, M:{M}, N:{N}, K:{K} dtype is bf16, is tuned on cu_num = {cu_num} in {AITER_CONFIG_BF16_BATCHED_GEMM_FILE}, kernel name is {config['kernelName']}, splitK is {config['splitK']}!" + f"shape is B:{B}, M:{M}, N:{N}, K:{K} dtype is bf16, is tuned on cu_num = {cu_num} in {AITER_CONFIGS.AITER_CONFIG_BF16_BATCHED_GEMM_FILE}, kernel name is {config['kernelName']}, splitK is {config['splitK']}!" ) mnk = config["kernelName"].split("_")[2].split("x")[1:] config["tile_m"] = int(mnk[0]) diff --git a/aiter/ops/gemm_op_a4w4.py b/aiter/ops/gemm_op_a4w4.py index 8ec03cea1e..bd3759f98c 100644 --- a/aiter/ops/gemm_op_a4w4.py +++ b/aiter/ops/gemm_op_a4w4.py @@ -12,7 +12,7 @@ from aiter import logger from ..jit.core import ( - AITER_CONFIG_GEMM_A4W4_FILE, + AITER_CONFIGS, AITER_LOG_TUNED_CONFIG, compile_ops, ) @@ -36,7 +36,9 @@ def compute_gemm_SplitK(M: int, N: int, K: int, tile_m: int, tile_n: int, tile_k @functools.lru_cache(maxsize=1024) def get_GEMM_config(M: int, N: int, K: int): if not hasattr(get_GEMM_config, "gemm_dict"): - gemm_dict = pd.read_csv(AITER_CONFIG_GEMM_A4W4_FILE).drop_duplicates() + gemm_dict = pd.read_csv( + AITER_CONFIGS.AITER_CONFIG_GEMM_A4W4_FILE + ).drop_duplicates() get_GEMM_config.gemm_dict = gemm_dict.set_index( ["cu_num", "M", "N", "K"] ).to_dict("index") @@ -49,7 +51,7 @@ def get_GEMM_config(M: int, N: int, K: int): if config is not None: if AITER_LOG_TUNED_CONFIG: logger.info( - f"shape is M:{M}, N:{N}, K:{K}, found padded_M: {padded_M}, N:{N}, K:{K} is tuned on cu_num = {cu_num} in {AITER_CONFIG_GEMM_A4W4_FILE}, kernel name is {config['kernelName']}, splitK is {config['splitK']}!" + f"shape is M:{M}, N:{N}, K:{K}, found padded_M: {padded_M}, N:{N}, K:{K} is tuned on cu_num = {cu_num} in {AITER_CONFIGS.AITER_CONFIG_GEMM_A4W4_FILE}, kernel name is {config['kernelName']}, splitK is {config['splitK']}!" ) break else: diff --git a/aiter/ops/gemm_op_a8w8.py b/aiter/ops/gemm_op_a8w8.py index 6cd2da4758..6a20086881 100644 --- a/aiter/ops/gemm_op_a8w8.py +++ b/aiter/ops/gemm_op_a8w8.py @@ -10,11 +10,7 @@ from ..jit.core import ( compile_ops, AITER_ROOT_DIR, - AITER_CONFIG_GEMM_A8W8_FILE, - AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_FILE, - AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE, - AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE_FILE, - AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE_FILE, + AITER_CONFIGS, AITER_LOG_TUNED_CONFIG, ) from ..jit.utils.torch_guard import torch_compile_guard @@ -367,7 +363,11 @@ def gemm_a8w8_ASM( and w_scale.dtype == dtypes.fp32 and ( asm_config := get_bpreshuffle_GEMM_config( - m, n, k, dtypes.i8, AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE + m, + n, + k, + dtypes.i8, + AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE, ) ) is not None @@ -402,7 +402,7 @@ def gemm_a8w8_CK( m = XQ.shape[0] n = WQ.shape[0] k = XQ.shape[-1] - ck_config = get_CKGEMM_config(m, n, k, AITER_CONFIG_GEMM_A8W8_FILE) + ck_config = get_CKGEMM_config(m, n, k, AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_FILE) if splitK is None: if ck_config is not None: splitK = ck_config["splitK"] @@ -458,13 +458,17 @@ def gemm_a8w8_bpreshuffle( # CKTile only supports bf16 dtype if dtype == dtypes.bf16: cktile_config = get_bpreshuffle_GEMM_config( - m, n, k, dtypes.fp8, AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE_FILE + m, + n, + k, + dtypes.fp8, + AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_CKTILE_FILE, ) else: cktile_config = None ck_config = get_bpreshuffle_GEMM_config( - m, n, k, dtypes.fp8, AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE + m, n, k, dtypes.fp8, AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE ) if cktile_config is not None and ck_config is not None: cktile_time = cktile_config.get("us", float("inf")) @@ -529,7 +533,7 @@ def gemm_a8w8_blockscale( else: assert 0, "asm kernel only support B preshuffle and m >= 16" else: - get_CKGEMM_config(m, n, k, AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_FILE) + get_CKGEMM_config(m, n, k, AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_FILE) return gemm_a8w8_blockscale_ck(XQ, WQ, x_scale, w_scale, Y) @@ -575,7 +579,9 @@ def gemm_a8w8_blockscale_bpreshuffle( m = XQ.shape[0] n = WQ.shape[0] k = XQ.shape[1] - get_CKGEMM_config(m, n, k, AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE_FILE) + get_CKGEMM_config( + m, n, k, AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE_FILE + ) Y = torch.empty(m, n, dtype=dtype, device=XQ.device) return gemm_a8w8_blockscale_bpreshuffle_ck(XQ, WQ, x_scale, w_scale, Y) diff --git a/aiter/tuned_gemm.py b/aiter/tuned_gemm.py index a0f5122220..719aa77ebd 100644 --- a/aiter/tuned_gemm.py +++ b/aiter/tuned_gemm.py @@ -31,7 +31,7 @@ hipb_mm, logger, ) -from aiter.jit.core import AITER_CONFIG_GEMM_BF16_FILE, AITER_LOG_TUNED_CONFIG +from aiter.jit.core import AITER_CONFIGS, AITER_LOG_TUNED_CONFIG from aiter.jit.utils.chip_info import get_cu_num from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.gemm_op_common import get_padded_m @@ -55,7 +55,7 @@ def get_solfunc(soltype: int): @functools.lru_cache(maxsize=1) def get_GEMM_A16W16_config_(): - tuned_file = AITER_CONFIG_GEMM_BF16_FILE + tuned_file = AITER_CONFIGS.AITER_CONFIG_GEMM_BF16_FILE gemm_dict = {} if os.path.exists(tuned_file): gemm_dict = pd.read_csv(f"{tuned_file}").drop_duplicates() @@ -82,13 +82,13 @@ def get_GEMM_A16W16_config( if AITER_LOG_TUNED_CONFIG: kernelName = config["kernelName"] if config["libtype"] == "asm" else "" logger.info( - f"shape is M:{M}, N:{N}, K:{K} {dtype=} {otype=} {bias=}, {scaleAB=}, found padded_M: {padded_M}, N:{N}, K:{K} is tuned on cu_num = {cu_num} in {AITER_CONFIG_GEMM_BF16_FILE}, libtype is {config['libtype']}, kernel name is {kernelName}" + f"shape is M:{M}, N:{N}, K:{K} {dtype=} {otype=} {bias=}, {scaleAB=}, found padded_M: {padded_M}, N:{N}, K:{K} is tuned on cu_num = {cu_num} in {AITER_CONFIGS.AITER_CONFIG_GEMM_BF16_FILE}, libtype is {config['libtype']}, kernel name is {kernelName}" ) return config if config is None: default_config = {} logger.info( - f"shape is M:{M}, N:{N}, K:{K}, not found tuned config in {AITER_CONFIG_GEMM_BF16_FILE}, will use default config!" + f"shape is M:{M}, N:{N}, K:{K}, not found tuned config in {AITER_CONFIGS.AITER_CONFIG_GEMM_BF16_FILE}, will use default config!" ) if dtype in [dtypes.fp16, dtypes.bf16] and K % 8 == 0: if ( @@ -288,7 +288,7 @@ def __init__(self): self.extensions_created = False self.save_gemm = int(os.environ.get("AITER_TUNE_GEMM", 0)) self.untune_path = f"{this_dir}/configs/bf16_untuned_gemm.csv" - self.tune_path = AITER_CONFIG_GEMM_BF16_FILE + self.tune_path = AITER_CONFIGS.AITER_CONFIG_GEMM_BF16_FILE def mm( self, diff --git a/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile.cu b/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile.cu index 54727af540..44a9336f2d 100755 --- a/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile.cu +++ b/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile.cu @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_a8w8_bpreshuffle_cktile_common.cuh" #include "gemm_a8w8_bpreshuffle_cktile_lookup.h" diff --git a/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_common.py b/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_common.py index 670e73894a..23cc114a55 100644 --- a/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_common.py +++ b/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_common.py @@ -97,7 +97,7 @@ def name(self) -> str: 19: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 256, 128, 1, 4, 1, 16, 16, 64, "Default"), 20: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 256, 128, 1, 4, 1, 16, 16, 64, "Default"), 21: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), - 22: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 22: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), 23: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), 24: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), 25: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), @@ -153,7 +153,7 @@ def name(self) -> str: 75: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), 76: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), 77: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), - 78: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 78: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), 79: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), 80: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), 81: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), @@ -168,7 +168,7 @@ def name(self) -> str: 90: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 256, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), 91: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), 92: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 80, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), - 93: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 224, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), + 93: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 224, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), 94: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 112, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), 95: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), 96: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 192, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), @@ -185,7 +185,7 @@ def name(self) -> str: 107: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), 108: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), 109: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), - 110: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), + 110: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 64, 256, 1, 4, 1, 16, 16, 64, "Default"), 111: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), 112: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), 113: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), @@ -201,7 +201,7 @@ def name(self) -> str: 123: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 256, 192, 128, 1, 4, 1, 16, 16, 64, "Default"), 124: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 256, 256, 1, 4, 1, 16, 16, 64, "Default"), 125: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 80, 128, 256, 1, 4, 1, 16, 16, 64, "Default"), - 126: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 224, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), + 126: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 224, 64, 128, 1, 4, 1, 16, 16, 64, "Default"), 127: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 112, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), 128: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 128, 192, 256, 1, 4, 1, 16, 16, 64, "Default"), 129: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 192, 128, 128, 1, 4, 1, 16, 16, 64, "Default"), @@ -251,7 +251,7 @@ def name(self) -> str: 19: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 256, 128, 1, 4, 1, 16, 16, 128, "Default"), 20: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 256, 128, 1, 4, 1, 16, 16, 128, "Default"), 21: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), - 22: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 22: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), 23: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), 24: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 64, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), 25: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 32, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), @@ -307,7 +307,7 @@ def name(self) -> str: 75: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), 76: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), 77: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), - 78: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 78: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), 79: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), 80: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), 81: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 96, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), @@ -322,7 +322,7 @@ def name(self) -> str: 90: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 256, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), 91: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 48, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), 92: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 80, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), - 93: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 224, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), + 93: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 224, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), 94: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 112, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), 95: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 128, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), 96: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 1, 192, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), @@ -339,7 +339,7 @@ def name(self) -> str: 107: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), 108: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), 109: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), - 110: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), + 110: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 64, 256, 1, 4, 1, 16, 16, 128, "Default"), 111: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), 112: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), 113: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), @@ -355,14 +355,14 @@ def name(self) -> str: 123: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 256, 192, 128, 1, 4, 1, 16, 16, 128, "Default"), 124: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 48, 256, 256, 1, 4, 1, 16, 16, 128, "Default"), 125: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 80, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), - 126: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 224, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), + 126: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 224, 64, 128, 1, 4, 1, 16, 16, 128, "Default"), 127: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 112, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), 128: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 128, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), 129: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 192, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), 130: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 224, 128, 128, 1, 4, 1, 16, 16, 128, "Default"), 131: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 192, 256, 1, 4, 1, 16, 16, 128, "Default"), 132: kernelInstance( 0, 0, 8, 4, 1, 0, 0, 0, 0, 2, 96, 128, 256, 1, 4, 1, 16, 16, 128, "Default"), - + } diff --git a/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_tune.cu b/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_tune.cu index b7658841ff..787915bd0f 100644 --- a/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_tune.cu +++ b/csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_tune.cu @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_a8w8_bpreshuffle_cktile_common.cuh" #include "gemm_a8w8_bpreshuffle_cktile_lookup.h" diff --git a/csrc/cktile_gemm_a8w8_bpreshuffle/gen_instances.py b/csrc/cktile_gemm_a8w8_bpreshuffle/gen_instances.py index 656a273d80..3dec52da3a 100755 --- a/csrc/cktile_gemm_a8w8_bpreshuffle/gen_instances.py +++ b/csrc/cktile_gemm_a8w8_bpreshuffle/gen_instances.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. import os import sys from dataclasses import dataclass @@ -16,9 +16,9 @@ ) -""" +""" -gemm_a8w8_bpreshuffle_cktile instance gen +gemm_a8w8_bpreshuffle_cktile instance gen """ @@ -69,9 +69,9 @@ def gen_instance(self, k: kernelInstance): """ - INSTANCE_CONTENT_nobias = f"""using FlatmmInstance = CustomConfig< - DDataType, EDataType, - {k.sTransposeC},{k.sUseStructuredSparsity}, {k.sTileParitionerGroupNum}, + INSTANCE_CONTENT_nobias = f"""using FlatmmInstance = CustomConfig< + DDataType, EDataType, + {k.sTransposeC},{k.sUseStructuredSparsity}, {k.sTileParitionerGroupNum}, {k.sTileParitionerM01}, {k.sNumWaveGroups}, {k.sDoubleSmemBuffer}, {k.PadM}, {k.PadN}, {k.PadK}, {k.BlockPerCu}, diff --git a/csrc/cktile_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_cktile.h b/csrc/cktile_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_cktile.h index 2eb83d065f..eee22711f6 100644 --- a/csrc/cktile_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_cktile.h +++ b/csrc/cktile_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_cktile.h @@ -1,6 +1,6 @@ #pragma once // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include torch::Tensor gemm_a8w8_bpreshuffle_cktile( diff --git a/csrc/pybind/gemm_a8w8_bpreshuffle_cktile_pybind.cu b/csrc/pybind/gemm_a8w8_bpreshuffle_cktile_pybind.cu index b453764779..6441b4e862 100644 --- a/csrc/pybind/gemm_a8w8_bpreshuffle_cktile_pybind.cu +++ b/csrc/pybind/gemm_a8w8_bpreshuffle_cktile_pybind.cu @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_a8w8_bpreshuffle_cktile.h" #include "rocm_ops.hpp" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { GEMM_A8W8_BPRESHUFFLE_CKTILE_PYBIND; } diff --git a/csrc/pybind/gemm_a8w8_bpreshuffle_cktile_tune_pybind.cu b/csrc/pybind/gemm_a8w8_bpreshuffle_cktile_tune_pybind.cu index aaa0ba69f7..ced86c1a12 100644 --- a/csrc/pybind/gemm_a8w8_bpreshuffle_cktile_tune_pybind.cu +++ b/csrc/pybind/gemm_a8w8_bpreshuffle_cktile_tune_pybind.cu @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_a8w8_bpreshuffle_cktile.h" #include "rocm_ops.hpp" diff --git a/gradlib/gradlib/GemmTuner.py b/gradlib/gradlib/GemmTuner.py index 44ee280019..f8a6ec44aa 100644 --- a/gradlib/gradlib/GemmTuner.py +++ b/gradlib/gradlib/GemmTuner.py @@ -452,7 +452,7 @@ class GemmTuner(GemmCommonTuner): ARG_DEFAULTS = { **GemmCommonTuner.ARG_DEFAULTS, "tune_file": f"{AITER_CONFIG_GEMM_BF16}", - "untune_file": "aiter/configs/untuned_gemm.csv", + "untune_file": "aiter/configs/bf16_untuned_gemm.csv", "batch": 1, }