diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index d5b7e0512e..d8fca4b4c9 100755 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -909,6 +909,12 @@ def FinalFunc(): kernelName1 = cfg["kernelName1"] kernelName2 = cfg["kernelName2"] run_1stage = cfg.get("run_1stage", False) + if not is_shuffled and not run_1stage: + logger.warning( + f"[fused_moe] tuned config found for {keys} but is_shuffled=False. " + "Tuned kernels are optimized for preshuffled weights (preshuffle_on). " + "Running with preshuffle_off may produce incorrect results." + ) tag = f"({kernelName1=}, {kernelName2=})" logger.info( diff --git a/aiter/utility/base_tuner.py b/aiter/utility/base_tuner.py index a580a6ddb8..aefeafdca7 100644 --- a/aiter/utility/base_tuner.py +++ b/aiter/utility/base_tuner.py @@ -3,6 +3,8 @@ import os import sys +import shutil +import tempfile import argparse import torch import pandas as pd @@ -16,6 +18,20 @@ INVALID_TIME = -1 +def _read_csv(filepath, **kwargs): + """Read CSV with automatic cleanup of common formatting issues: + trailing tabs/spaces, extra unnamed columns, whitespace in headers/values. + """ + df = pd.read_csv(filepath, **kwargs) + df.columns = df.columns.str.strip() + df = df.loc[:, ~df.columns.str.startswith("Unnamed:")] + str_cols = df.select_dtypes(include=["object"]).columns + for col in str_cols: + df[col] = df[col].apply(lambda v: v.strip() if isinstance(v, str) else v) + df.dropna(how="all", inplace=True) + return df + + class TunerCommon: ARG_DEFAULTS = { "verbose": False, @@ -27,6 +43,7 @@ class TunerCommon: "timeout": None, # 100s timeout for per test "warmup": 5, # 5 warmup iters for profiling "iters": 101, # 101 run iters for profiling + "min_improvement_pct": 3.0, # only write shapes improved by >= N% } dtype2bpe_dict = { dtypes.fp16: 2, @@ -115,6 +132,14 @@ def _setup_common_arguments(self): required=False, help="Use splitK kernels", ) + self.parser.add_argument( + "--shape_grouped", + action="store_true", + default=False, + required=False, + help="Group all kernel candidates for the same shape onto one GPU " + "to eliminate cross-GPU timing variance (also saves generate_data calls)", + ) self.parser.add_argument( "--sort", type=dtypes.str2bool, @@ -165,9 +190,41 @@ def _setup_common_arguments(self): default=defaults["timeout"], help="timeout for task group", ) + self.parser.add_argument( + "--run_config", + nargs="?", + const=True, + default=False, + metavar="TUNED_CSV", + help="Run production operator benchmark and exit (no tuning). " + "If a tuned CSV path is given, read shapes and kernels from it; " + "otherwise read shapes from -i and run with default kernels.", + ) + self.parser.add_argument( + "--compare", + action="store_true", + required=False, + help="Run production-op benchmark before and after tuning, print compare results, and keep a compare candidate CSV.", + ) + self.parser.add_argument( + "--update_improved", + action="store_true", + required=False, + help="With --compare, update the final tuned CSV for shapes improved by at least --min_improvement_pct, or when pre-run has no valid baseline but post-run passes.", + ) + self.parser.add_argument( + "--min_improvement_pct", + dest="min_improvement_pct", + type=float, + default=defaults.get("min_improvement_pct", 3.0), + help="With --compare --update_improved, update tuned CSV only when a valid pre/post benchmark shows at least this percent improvement. 
Shapes with no valid pre-run baseline but passing post-run are still allowed to update.", + ) def parse_args(self): - return self.parser.parse_args() + args = self.parser.parse_args() + if args.update_improved and not args.compare: + self.parser.error("--update_improved requires --compare") + return args @abstractmethod def _setup_specific_arguments(self): @@ -207,10 +264,10 @@ def update_config_files(self, file_path: str, merge_name: str): ## merge config files ##example: AITER_CONFIG_GEMM_A4W4="/path1:/path2" - df_list.append(pd.read_csv(path_list[0])) + df_list.append(_read_csv(path_list[0])) for i, path in enumerate(path_list[1:]): if os.path.exists(path): - df = pd.read_csv(path) + df = _read_csv(path) base_cols = [c for c in df_list[0].columns if c != "_tag"] new_cols = [c for c in df.columns if c != "_tag"] assert ( @@ -238,7 +295,7 @@ def get_untuned_gemm_list(self, untuned_gemm_file): assert os.path.exists( untuned_gemm_file ), f"Not exist untuned file: {untuned_gemm_file}" - untunedf = pd.read_csv(untuned_gemm_file) + untunedf = _read_csv(untuned_gemm_file) filtered_df = untunedf.drop_duplicates().reset_index(drop=True) return filtered_df @@ -251,9 +308,14 @@ def get_out_file(self, tuned_file): def get_tuned_gemm_list(self, tuned_gemm_file, columns=[]): all_tuned_file = self.update_config_files(tuned_gemm_file, self.name) if os.path.exists(all_tuned_file): - column_order = pd.read_csv(all_tuned_file, nrows=0).columns.tolist() - tunedf = pd.read_csv(all_tuned_file) - tunedf = tunedf[column_order] + try: + column_order = _read_csv(all_tuned_file, nrows=0).columns.tolist() + tunedf = _read_csv(all_tuned_file) + tunedf = tunedf[column_order] + except pd.errors.EmptyDataError: + print(f"Empty tuned file: {all_tuned_file}") + columns = self.columns if not columns else columns + tunedf = pd.DataFrame(columns=columns) else: print(f"Not exist tuned file: {all_tuned_file}") columns = self.columns if not columns else columns @@ -321,7 +383,7 @@ def update_tunedf(self, df_old, df_updates): return df_old def sortResults(self, tune_file, issorted, values): - tunedf = pd.read_csv(tune_file) + tunedf = _read_csv(tune_file) if issorted: tunedf = tunedf.sort_values(by=values) dedup_keys = self.keys @@ -348,7 +410,7 @@ def post_process(self, rets, args, topk=-1, fast_mode=False): logger.info(f"saving profile to {args.profile_file}") profiledf = self.result_to_df(sorted(rets, key=itemgetter(0))) if os.path.exists(args.profile_file): - old_df = pd.read_csv(args.profile_file) + old_df = _read_csv(args.profile_file) else: old_df = pd.DataFrame(columns=self.columns) profiledf = pd.concat([old_df, profiledf], ignore_index=True) @@ -356,7 +418,6 @@ def post_process(self, rets, args, topk=-1, fast_mode=False): if fast_mode or topk == -1: return rets - tol_err_ratio = args.errRatio from collections import defaultdict grouped_rets = defaultdict(list) @@ -368,6 +429,7 @@ def post_process(self, rets, args, topk=-1, fast_mode=False): grouped_results = list(grouped_rets.items()) for info_key, time_list in grouped_results: + tol_err_ratio = args.errRatio sorted_time = sorted(time_list, key=lambda x: x[1]) filtered_time = [ (info_ex, round(us, 4), max_err_ratio) @@ -434,14 +496,709 @@ def update_tflops_bw(self, tune_file): """update tflops and bw from old tune_file""" pass + def run_config(self, args): + """Run the production operator for each shape in the untuned CSV. + Subclasses should override this to call the actual production operator. 
+ Returns a list of dicts: [{"shape": str, "us": float, "status": "ok"/"error"}] + """ + logger.info(f"run_config not implemented for {self.name}, skipping benchmark") + return [] + + def _clear_op_caches(self): + """Clear operator-specific config caches. Subclasses should override this + to clear only their own caches.""" + pass + + def _set_config_env_for_run_config(self, args, config_file=None): + """Set the config env var to point to a tuned config file, clear caches, + and enable AITER_REBUILD so that run_config rebuilds with new configs. + *config_file* overrides the default (``-o`` / ``args.tune_file``). + """ + defaults = self.get_arg_defaults() + env_name = defaults.get("config_env_name") + if not env_name: + # Must return a 2-tuple: callers always unpack into old_val, old_rebuild. + return None, None + output_file = config_file if config_file else self.get_out_file(args.tune_file) + old_val = os.environ.get(env_name) + os.environ[env_name] = output_file + logger.info(f"Setting {env_name}={output_file} for benchmark") + # Clear operator-specific config caches + self._clear_op_caches() + # Enable AITER_REBUILD (level 2: rm .so only, keep build cache for faster rebuild) + # and clear module caches so operators rebuild with new config + from aiter.jit import core as jit_core + + old_rebuild = jit_core.AITER_REBUILD + jit_core.AITER_REBUILD = 2 + jit_core.get_module.cache_clear() + # Reset rebuilded_list so all modules get rebuilt on next call + jit_core.rebuilded_list = ["module_aiter_enum"] + # Clear loaded modules dict (use getattr to avoid Python name mangling of __ prefix in class methods) + mds = getattr(jit_core, "__mds", None) + if mds is not None: + mds.clear() + # Clear get_config_file lru_cache so it re-reads the env var + jit_core.AITER_CONFIGS.get_config_file.cache_clear() + return old_val, old_rebuild + + def _restore_config_env(self, env_name, old_val, old_rebuild=0): + """Restore the config env var and AITER_REBUILD to original values.""" + if env_name is None: + return + if old_val is None: + os.environ.pop(env_name, None) + else: + os.environ[env_name] = old_val + try: + from aiter.jit import core as jit_core + + jit_core.AITER_REBUILD = old_rebuild + except ImportError: + pass + + def _emit_report_lines(self, lines, report_file=None): + if report_file: + with open(report_file, "a") as f: + f.write("\n".join(lines) + "\n") + return + for line in lines: + print(line, flush=True) + + def _split_benchmark_status(self, status): + status = "" if status is None else str(status) + if status == "ok": + return "OK", "" + if status.startswith("error:"): + return "ERROR", status[len("error:") :].strip() + if status == "mismatch": + return "MISMATCH", "output mismatch vs reference" + if not status: + return "UNKNOWN", "" + return status.upper(), "" + + def _format_benchmark_keys(self, row): + parts = [] + for key in self.keys: + value = row.get(key, "") + parts.append(f"{key}={value}") + return "keys: " + ", ".join(parts) + + def _get_benchmark_e2e_us(self, row, suffix=""): + return getattr(row, f"benchmark_e2e_us{suffix}", -1) + + def _get_benchmark_kernel_us(self, row, suffix=""): + return getattr(row, f"benchmark_kernel_us{suffix}", None) + + def _print_benchmark_results( + self, label, results, report_file=None, shapes_df=None + ): + """Print benchmark results to stdout or append them to a report file.""" + if not results: + self._emit_report_lines([f"{label}: no results"], report_file) + return + results_df = self._benchmark_results_to_df(results, shapes_df=shapes_df) + 
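+ # Build a fixed-width text report; if shapes and results could not be aligned into a
+ # DataFrame, fall back to printing the raw result dicts below.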
lines = [f"============= {label} Benchmark Results ============="] + has_kernel_us = ( + not results_df.empty + and "benchmark_kernel_us" in results_df.columns + and results_df["benchmark_kernel_us"].notna().any() + ) + if has_kernel_us: + header = ( + f"{'Shape':<40} | {'Kernel(us)':>10} | {'E2E(us)':>10} | {'Status':>8}" + ) + else: + header = f"{'Shape':<40} | {'E2E(us)':>10} | {'Status':>8}" + lines.append(header) + lines.append("-" * len(header)) + if results_df.empty: + for r in results: + shape_str = r.get("shape", "unknown") + e2e_us = r.get("e2e_us", -1) + status = r.get("status", "unknown") + status_summary, status_detail = self._split_benchmark_status(status) + e2e_str = f"{e2e_us:.2f}" if e2e_us > 0 else "N/A" + lines.append(f"{shape_str:<40} | {e2e_str:>10} | {status_summary:>8}") + if status_detail: + lines.append(f"{'':<40} | {'':>10} | {'reason: ' + status_detail}") + self._emit_report_lines(lines, report_file) + return + for row in results_df.itertuples(index=False): + shape_str = getattr(row, "shape", "unknown") + e2e_us = self._get_benchmark_e2e_us(row) + kernel_us = self._get_benchmark_kernel_us(row) + status = getattr(row, "benchmark_status", "unknown") + status_summary, status_detail = self._split_benchmark_status(status) + e2e_str = f"{e2e_us:.2f}" if e2e_us > 0 else "N/A" + if has_kernel_us: + kernel_str = ( + f"{kernel_us:.2f}" + if kernel_us is not None and pd.notna(kernel_us) and kernel_us > 0 + else "N/A" + ) + lines.append( + f"{shape_str:<40} | {kernel_str:>10} | {e2e_str:>10} | {status_summary:>8}" + ) + else: + lines.append(f"{shape_str:<40} | {e2e_str:>10} | {status_summary:>8}") + lines.append( + self._format_benchmark_keys( + {key: getattr(row, key, "") for key in self.keys} + ) + ) + if status_detail: + lines.append(f"reason: {status_detail}") + self._emit_report_lines(lines, report_file) + + def _print_comparison(self, pre_results, post_results, report_file=None): + """Print comparison to stdout or append it to a report file.""" + if not pre_results or not post_results: + self._emit_report_lines( + ["Cannot print comparison: missing pre or post results"], + report_file, + ) + return + pre_df = self._benchmark_results_to_df(pre_results) + post_df = self._benchmark_results_to_df(post_results) + if pre_df.empty or post_df.empty: + self._emit_report_lines( + ["Cannot print comparison: missing comparable benchmark rows"], + report_file, + ) + return + comparison_df = pre_df.merge( + post_df, + on=self.keys, + how="outer", + suffixes=("_pre", "_post"), + ) + comparison_df["shape"] = comparison_df["shape_pre"] + missing_shape_mask = comparison_df["shape"].isna() | ( + comparison_df["shape"] == "" + ) + comparison_df.loc[missing_shape_mask, "shape"] = comparison_df.loc[ + missing_shape_mask, "shape_post" + ] + lines = ["============= Tune Performance Comparison ============="] + header = f"{'Shape':<40} | {'Pre-E2E(us)':>13} | {'Post-E2E(us)':>14} | {'Speedup':>8} | {'Status':>8}" + lines.append(header) + lines.append("-" * len(header)) + for row in comparison_df.itertuples(index=False): + shape = getattr(row, "shape", "unknown") + pre_us = self._get_benchmark_e2e_us(row, "_pre") + post_us = self._get_benchmark_e2e_us(row, "_post") + post_status = getattr(row, "benchmark_status_post", "error") + if pd.isna(post_status): + pre_str = f"{pre_us:.2f}" if pd.notna(pre_us) and pre_us > 0 else "N/A" + lines.append( + f"{shape:<40} | {pre_str:>13} | {'N/A':>14} | {'N/A':>8} | {'MISS':>8}", + ) + lines.append( + self._format_benchmark_keys( + {key: getattr(row, key, "") 
for key in self.keys} + ) + ) + continue + status_summary, status_detail = self._split_benchmark_status(post_status) + if pre_us > 0 and post_us > 0: + speedup = pre_us / post_us + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + pre_str = f"{pre_us:.2f}" if pre_us > 0 else "N/A" + post_str = f"{post_us:.2f}" if post_us > 0 else "N/A" + lines.append( + f"{shape:<40} | {pre_str:>13} | {post_str:>14} | {speedup_str:>8} | {status_summary:>8}" + ) + lines.append( + self._format_benchmark_keys( + {key: getattr(row, key, "") for key in self.keys} + ) + ) + if status_detail: + lines.append(f"reason: {status_detail}") + self._emit_report_lines(lines, report_file) + + def _benchmark_results_to_df(self, results, shapes_df=None): + columns = self.keys + [ + "shape", + "benchmark_status", + "benchmark_kernel_us", + "benchmark_e2e_us", + ] + if shapes_df is None: + shapes_df = self.untunedf + if shapes_df is None or len(shapes_df) == 0 or not results: + return pd.DataFrame(columns=columns) + + shapes_df = shapes_df[self.keys].reset_index(drop=True) + limit = min(len(shapes_df), len(results)) + if len(shapes_df) != len(results): + logger.warning( + f"benchmark results count mismatch in {self.name}: " + f"{len(results)} results for {len(shapes_df)} shapes; matching by row order" + ) + + rows = [] + for idx in range(limit): + bench = results[idx] or {} + row = shapes_df.iloc[idx].to_dict() + row["shape"] = bench.get("shape", "") + row["benchmark_status"] = bench.get("status", "unknown") + row["benchmark_kernel_us"] = bench.get("kernel_us", None) + row["benchmark_e2e_us"] = bench.get("e2e_us", -1) + rows.append(row) + return pd.DataFrame(rows, columns=columns) + + def _build_compare_update_plan( + self, pre_results, post_results, threshold_percent, shapes_df=None + ): + pre_df = self._benchmark_results_to_df(pre_results, shapes_df=shapes_df) + post_df = self._benchmark_results_to_df(post_results, shapes_df=shapes_df) + columns = self.keys + [ + "shape", + "pre_us", + "post_us", + "pre_status", + "post_status", + "improvement_pct", + "update", + "update_reason", + ] + if pre_df.empty or post_df.empty: + return pd.DataFrame(columns=columns) + + comparison = pre_df.merge( + post_df, + on=self.keys, + how="outer", + suffixes=("_pre", "_post"), + ) + comparison["shape"] = comparison["shape_pre"] + missing_shape_mask = comparison["shape"].isna() | (comparison["shape"] == "") + comparison.loc[missing_shape_mask, "shape"] = comparison.loc[ + missing_shape_mask, "shape_post" + ] + comparison["pre_us"] = comparison["benchmark_e2e_us_pre"] + comparison["post_us"] = comparison["benchmark_e2e_us_post"] + comparison["pre_status"] = comparison["benchmark_status_pre"] + comparison["post_status"] = comparison["benchmark_status_post"] + + valid = ( + (comparison["pre_status"] == "ok") + & (comparison["post_status"] == "ok") + & (comparison["pre_us"] > 0) + & (comparison["post_us"] > 0) + ) + no_baseline = ( + (comparison["post_status"] == "ok") + & (comparison["post_us"] > 0) + & ~((comparison["pre_status"] == "ok") & (comparison["pre_us"] > 0)) + ) + comparison["improvement_pct"] = ( + (comparison["pre_us"] - comparison["post_us"]) + / comparison["pre_us"] + * 100.0 + ) + comparison.loc[~valid, "improvement_pct"] = float("nan") + comparison["update_reason"] = "skip" + comparison.loc[ + valid & (comparison["improvement_pct"] >= threshold_percent), + "update_reason", + ] = "threshold_met" + comparison.loc[no_baseline, "update_reason"] = "no_baseline" + comparison["update"] = comparison["update_reason"] != 
"skip" + return comparison[columns] + + def _print_compare_update_plan( + self, + comparison, + threshold_percent, + tuned_file=None, + report_file=None, + apply_updates=True, + ): + if comparison is None or comparison.empty: + self._emit_report_lines( + ["Compare-gated CSV update skipped: no comparable benchmark rows"], + report_file, + ) + return + + lines = [ + ( + "============= Compare-Gated CSV Updates =============" + if apply_updates + else "============= Compare-Gated CSV Update Preview =============" + ) + ] + target_desc = tuned_file if tuned_file else "tuned csv" + lines.append( + ( + f"Threshold: improve >= {threshold_percent:.2f}% to update {target_desc}" + if apply_updates + else f"Threshold: improve >= {threshold_percent:.2f}% would update {target_desc} with --update_improved" + ) + ) + lines.append( + ( + "Rows with no valid pre-run baseline but passing post-run will also update." + if apply_updates + else "Rows with no valid pre-run baseline but passing post-run would also update." + ) + ) + header = f"{'Shape':<40} | {'Pre-E2E':>10} | {'Post-E2E':>10} | {'Improve':>9} | {'Action':>18}" + lines.append(header) + lines.append("-" * len(header)) + for row in comparison.itertuples(index=False): + pre_str = ( + f"{row.pre_us:.2f}" + if pd.notna(row.pre_us) and row.pre_us > 0 + else "N/A" + ) + post_str = ( + f"{row.post_us:.2f}" + if pd.notna(row.post_us) and row.post_us > 0 + else "N/A" + ) + improve_str = ( + f"{row.improvement_pct:.2f}%" + if pd.notna(row.improvement_pct) + else "N/A" + ) + if row.update_reason == "threshold_met": + action = "UPDATE" + elif row.update_reason == "no_baseline": + action = "UPDATE_NO_BASELINE" + else: + action = "SKIP" + lines.append( + f"{row.shape:<40} | {pre_str:>10} | {post_str:>10} | {improve_str:>9} | {action:>18}" + ) + lines.append( + self._format_benchmark_keys( + {key: getattr(row, key, "") for key in self.keys} + ) + ) + pre_summary, pre_detail = self._split_benchmark_status(row.pre_status) + post_summary, post_detail = self._split_benchmark_status(row.post_status) + if pre_detail: + lines.append(f"pre-{pre_summary.lower()}: {pre_detail}") + if post_detail: + lines.append(f"post-{post_summary.lower()}: {post_detail}") + self._emit_report_lines(lines, report_file) + + def _merge_compare_filtered_results(self, base_file, candidate_file, comparison): + old_df = self.get_tuned_gemm_list(base_file) + if not os.path.exists(candidate_file): + return old_df + + candidate_df = self.get_tuned_gemm_list(candidate_file) + if comparison is None or comparison.empty: + return old_df + + improved_keys = set( + comparison.loc[comparison["update"], self.keys] + .astype(str) + .apply(tuple, axis=1) + .tolist() + ) + if not improved_keys: + return old_df + + def key_mask(df): + if df.empty: + return pd.Series([], index=df.index, dtype=bool) + return df[self.keys].astype(str).apply(tuple, axis=1).isin(improved_keys) + + kept_old = old_df[~key_mask(old_df)].copy() + improved_rows = candidate_df[key_mask(candidate_df)].copy() + merged = pd.concat([kept_old, improved_rows], ignore_index=True) + dedup_keys = list(self.keys) + if "_tag" in merged.columns: + merged["_tag"] = merged["_tag"].fillna("") + dedup_keys.append("_tag") + merged = merged.drop_duplicates(subset=dedup_keys, keep="last").reset_index( + drop=True + ) + return merged + + def _run_config_for_shapes(self, args, shapes_df, config_file=None): + original_untunedf = self.untunedf + shapes_df = shapes_df.reset_index(drop=True) + self.untunedf = shapes_df + try: + if config_file is None: + return 
self.run_config(args) + defaults = self.get_arg_defaults() + env_name = defaults.get("config_env_name") + old_val, old_rebuild = self._set_config_env_for_run_config( + args, config_file=config_file + ) + try: + return self.run_config(args) + finally: + self._restore_config_env(env_name, old_val, old_rebuild) + finally: + self.untunedf = original_untunedf + + def _init_compare_report(self, args, output_file, batch_size, total_batches): + if not args.compare or (total_batches <= 1 and len(self.untunedf) <= 30): + return None + + report_root, _ = os.path.splitext(output_file) + compare_report_file = f"{report_root}.compare.txt" + with open(compare_report_file, "w") as f: + f.write( + f"Compare report for {self.name}\n" + f"Shapes: {len(self.untunedf)}\n" + f"Batch size: {batch_size}\n" + f"Total batches: {total_batches}\n\n" + ) + print(f"Compare results will be written to {compare_report_file}", flush=True) + return compare_report_file + + def _init_compare_candidate_file(self, args, output_file): + if not args.compare: + return None + + candidate_root, candidate_ext = os.path.splitext(output_file) + compare_candidate_file = f"{candidate_root}.candidate{candidate_ext or '.csv'}" + if os.path.exists(output_file): + shutil.copyfile(output_file, compare_candidate_file) + elif os.path.exists(compare_candidate_file): + os.remove(compare_candidate_file) + print( + f"Compare candidate CSV will be written to {compare_candidate_file}", + flush=True, + ) + return compare_candidate_file + + def _emit_compare_batch_header(self, header, report_file=None): + print(header, flush=True) + if report_file: + self._emit_report_lines([header], report_file) + + def _run_compare_benchmark( + self, + args, + batch, + header, + result_label, + report_file=None, + config_file=None, + print_results=True, + ): + self._emit_compare_batch_header(header, report_file) + results = self._run_config_for_shapes(args, batch, config_file=config_file) + if print_results: + self._print_benchmark_results( + result_label, results, report_file=report_file + ) + return results + + def _create_batch_compare_output_file( + self, + args, + results, + output_file, + processed_batches, + compare_candidate_file=None, + ): + fd, batch_compare_output_file = tempfile.mkstemp( + prefix=f"{self.name}_compare_batch_{processed_batches}_", + suffix=".csv", + ) + os.close(fd) + candidate_base_file = ( + compare_candidate_file + if compare_candidate_file and os.path.exists(compare_candidate_file) + else output_file + ) + if os.path.exists(candidate_base_file): + shutil.copyfile(candidate_base_file, batch_compare_output_file) + else: + pd.DataFrame(columns=self.columns).to_csv( + batch_compare_output_file, index=False + ) + self.result_to_csv(results, batch_compare_output_file, not args.all) + if os.path.exists(batch_compare_output_file): + self.sortResults(batch_compare_output_file, args.sort, self.sort_keys) + if compare_candidate_file: + shutil.copyfile(batch_compare_output_file, compare_candidate_file) + return batch_compare_output_file + + def _apply_compare_batch_results( + self, + args, + batch, + results, + batch_pre_tune_results, + output_file, + processed_batches, + total_batches, + compare_report_file=None, + compare_candidate_file=None, + ): + batch_compare_output_file = self._create_batch_compare_output_file( + args, + results, + output_file, + processed_batches, + compare_candidate_file=compare_candidate_file, + ) + try: + batch_header = f"=== Running post-tune benchmark (verification) for batch {processed_batches}/{total_batches} ===" + 
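+ # Re-run the production op for this batch against the candidate CSV written above so the
+ # pre/post comparison measures the same shapes under the old vs. freshly tuned configs.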
batch_post_tune_results = self._run_compare_benchmark( + args, + batch, + batch_header, + "Post-tune", + report_file=compare_report_file, + config_file=batch_compare_output_file, + print_results=args.verbose, + ) + batch_compare_plan = self._build_compare_update_plan( + batch_pre_tune_results, + batch_post_tune_results, + args.min_improvement_pct, + shapes_df=batch, + ) + if args.update_improved: + final_df = self._merge_compare_filtered_results( + output_file, + batch_compare_output_file, + batch_compare_plan, + ) + final_df.to_csv(output_file, index=False) + if os.path.exists(output_file): + self.sortResults(output_file, args.sort, self.sort_keys) + self.tunedf = self.get_tuned_gemm_list(output_file) + return batch_post_tune_results, batch_compare_plan + finally: + if os.path.exists(batch_compare_output_file): + os.remove(batch_compare_output_file) + + def _record_completed_compare_batch( + self, + completed_pre_tune_results, + completed_post_tune_results, + compare_plans, + batch_pre_tune_results, + batch_post_tune_results, + batch_compare_plan, + ): + completed_pre_tune_results.extend(batch_pre_tune_results or []) + completed_post_tune_results.extend(batch_post_tune_results or []) + compare_plans.append(batch_compare_plan) + + def _print_compare_summary( + self, + completed_pre_tune_results, + completed_post_tune_results, + compare_plans, + threshold_percent, + tuned_file, + report_file=None, + apply_updates=True, + candidate_file=None, + ): + if not completed_pre_tune_results: + return + + self._print_comparison( + completed_pre_tune_results, + completed_post_tune_results, + report_file=report_file, + ) + combined_compare_plan = ( + pd.concat(compare_plans, ignore_index=True).reset_index(drop=True) + if compare_plans + else pd.DataFrame() + ) + self._print_compare_update_plan( + combined_compare_plan, + threshold_percent, + tuned_file=tuned_file, + report_file=report_file, + apply_updates=apply_updates, + ) + extra_lines = [] + if candidate_file: + extra_lines.append(f"Compare candidate CSV written to {candidate_file}") + if not apply_updates: + extra_lines.append( + "Final tuned CSV was not updated. Re-run with --update_improved to apply improved shapes." + ) + if extra_lines: + self._emit_report_lines(extra_lines, report_file) + if report_file: + print(f"Compare results written to {report_file}", flush=True) + # def run(self, args, fast_mode=False): """tuner run function""" self.pre_process(args) + + # Resolve --run_config: can be False, True (no file), or a file path string. + # Strict semantics: + # --run_config -> tuned kernels using that config file + # --run_config -> default kernels (no config env override) + run_config_file = args.run_config if isinstance(args.run_config, str) else None + + # --run_config with tuned file: load shapes from the tuned CSV. + # --run_config without file: keep shapes from -i (pre_process), run default kernels. + # --compare: always use untuned shapes from -i (pre_process). 
+ if args.run_config and run_config_file: + tunedf = self.get_tuned_gemm_list(run_config_file) + if not tunedf.empty and self.keys[0] in tunedf.columns: + cu = self.get_cu_num() + if "cu_num" in tunedf.columns: + tunedf = tunedf[tunedf["cu_num"] == cu] + self.untunedf = tunedf.drop_duplicates(subset=self.keys).reset_index( + drop=True + ) + print(self.untunedf) output_file = self.get_out_file(args.tune_file) if args.verbose: logger.info(f"args: {args}") + + # --run_config: only run benchmark and exit (no tuning) + if args.run_config: + if self.untunedf.empty: + logger.info("No shapes to benchmark, nothing to run") + return pd.DataFrame() + if run_config_file: + defaults = self.get_arg_defaults() + env_name = defaults.get("config_env_name") + old_val, old_rebuild = self._set_config_env_for_run_config( + args, config_file=run_config_file + ) + try: + print( + "=== Running production operator benchmark (tuned) ===", + flush=True, + ) + results = self.run_config(args) + self._print_benchmark_results("Benchmark (tuned)", results) + finally: + self._restore_config_env(env_name, old_val, old_rebuild) + else: + print( + "=== Running production operator benchmark (default) ===", + flush=True, + ) + results = self.run_config(args) + self._print_benchmark_results("Benchmark (default)", results) + return self.tunedf if self.tunedf is not None else pd.DataFrame() + + # Only include batches that fully completed compare+update in the final summary. + completed_pre_tune_results = [] + completed_post_tune_results = [] + compare_plans = [] + if len(self.untunedf) == 0: # self.update_tflops_bw(args.tune_file) self.sortResults(output_file, args.sort, self.sort_keys) @@ -451,12 +1208,22 @@ def run(self, args, fast_mode=False): return self.tunedf if self.tunedf is not None else pd.DataFrame() batch_size = min(args.batch, len(self.untunedf)) total_batches = (len(self.untunedf) + batch_size - 1) // batch_size + compare_report_file = self._init_compare_report( + args, output_file, batch_size, total_batches + ) + compare_candidate_file = self._init_compare_candidate_file(args, output_file) if args.verbose: logger.info( f"total shapes to be tuned: {len(self.untunedf) }, total_batches: {total_batches}, batch_size: {batch_size}" ) - logger.info(f"results will be written to {output_file}") + if args.compare and not args.update_improved: + logger.info( + f"compare candidate results will be written to {compare_candidate_file}" + ) + else: + logger.info(f"results will be written to {output_file}") processed_batches = 0 + completed_batches = 0 results = [] topk = -1 if fast_mode else 1 self.tune_start_time = time.time() @@ -465,31 +1232,93 @@ def run(self, args, fast_mode=False): for i in range(0, len(self.untunedf), batch_size): batch = self.untunedf.iloc[i : i + batch_size].reset_index(drop=True) processed_batches += 1 + batch_pre_tune_results = None + if args.compare: + batch_header = f"=== Running pre-tune benchmark (batch {processed_batches}/{total_batches}) ===" + batch_pre_tune_results = self._run_compare_benchmark( + args, + batch, + batch_header, + "Pre-tune", + report_file=compare_report_file, + print_results=args.verbose, + ) all_results = self.tune(batch, self.tunedf, args) if all_results: results = self.post_process(all_results, args, topk) - self.result_to_csv(results, output_file, not args.all) + if args.compare: + batch_post_tune_results, batch_compare_plan = ( + self._apply_compare_batch_results( + args, + batch, + results, + batch_pre_tune_results, + output_file, + processed_batches, + total_batches, + 
compare_report_file=compare_report_file, + compare_candidate_file=compare_candidate_file, + ) + ) + self._record_completed_compare_batch( + completed_pre_tune_results, + completed_post_tune_results, + compare_plans, + batch_pre_tune_results, + batch_post_tune_results, + batch_compare_plan, + ) + else: + self.result_to_csv(results, output_file, not args.all) + completed_batches += 1 logger.info( - f"processed {processed_batches} batches of {total_batches}, Processing Status ====> {round(processed_batches / total_batches,2)*100:.1f}% tuned in {self.name}" + f"processed {completed_batches} batches of {total_batches}, Processing Status ====> {round(completed_batches / total_batches,2)*100:.1f}% tuned in {self.name}" ) else: logger.info( f"tune result is none or all shape is tuned in {args.tune_file}!" ) - self.sortResults(output_file, args.sort, self.sort_keys) + if os.path.exists(output_file): + self.sortResults(output_file, args.sort, self.sort_keys) except KeyboardInterrupt: tuning_status = "Interrupted" logger.error( - f"interrupted by user, tuning stopped, {processed_batches-1} batches processed" + f"interrupted by user, tuning stopped, {completed_batches} batches processed" ) except Exception as e: tuning_status = "Error" logger.error( - f"error in batch {processed_batches} of {total_batches}: {str(e)}", + f"error in batch {processed_batches} of {total_batches} after {completed_batches} completed batches: {str(e)}", exc_info=True, ) finally: - self.tune_summary(tuning_status) + tune_exit = None + summary_exc = None + try: + self.tune_summary(tuning_status) + except SystemExit as e: + tune_exit = e + except Exception as e: + summary_exc = e + logger.error( + f"tune_summary failed (tuning may still have written results): {e}", + exc_info=True, + ) + if args.compare: + self._print_compare_summary( + completed_pre_tune_results, + completed_post_tune_results, + compare_plans, + args.min_improvement_pct, + output_file, + report_file=compare_report_file, + apply_updates=args.update_improved, + candidate_file=compare_candidate_file, + ) + if tune_exit is not None: + raise tune_exit + if summary_exc is not None: + raise summary_exc class GemmCommonTuner(TunerCommon): @@ -541,7 +1370,7 @@ def pre_process(self, args): if args.verbose: logger.info("skiped tuned shapes:") print(self.untunedf[mask]) - self.untunedf = self.untunedf[~mask] + self.untunedf = self.untunedf[~mask].reset_index(drop=True) def calculate(self, results, bpes=(2, 2, 2)): """calculate TFLOPS and bandwidth""" @@ -565,11 +1394,15 @@ def result_to_df(self, results): for el in results: info, time, err_ratio = el keys, kernelId, splitK, kernelName = info - kernelName = ( - "None" - if time == self.INVALID_TIME or time == self.INF_TIME - else self.getKernelName(kernelId) if kernelName == "" else kernelName - ) + # Resolve kernel name for both success and failure (profile CSV / debugging). + # Treat missing/NA like "" so we always look up CK kernel names; otherwise NaN + # would serialize as "Null" via na_rep in to_csv. 
+ need_lookup = kernelName == "" or pd.isna(kernelName) + resolved = self.getKernelName(kernelId) if need_lookup else kernelName + if resolved is None or pd.isna(resolved): + kernelName = "None" + else: + kernelName = str(resolved) tflops, bw = self.calculate(el) key_dict = dict(zip(self.keys, keys)) diff --git a/aiter/utility/mp_tuner.py b/aiter/utility/mp_tuner.py index f19750f26c..5011c8ef06 100644 --- a/aiter/utility/mp_tuner.py +++ b/aiter/utility/mp_tuner.py @@ -290,60 +290,34 @@ def mp_tuner( task_group = [] # dispatch per shape to one pid if shape_grouped: - # Group tasks by info_keys (info[0]) from collections import OrderedDict info_key_groups = OrderedDict() - for task in tasks: - # Extract info_keys from task (task[0] is info, task[0][0] is info_keys) info_keys = task[0][0] if task and len(task) > 0 else None - if info_keys not in info_key_groups: info_key_groups[info_keys] = [] info_key_groups[info_keys].append(task) - # Convert to list of groups task_group = list(info_key_groups.values()) print( f"[Task Grouping] Grouped {len(tasks)} tasks into {len(task_group)} groups by info_keys" ) - # Update in_datas to reflect the actual group sizes - # Each group gets one entry with (group_size, original_data) - new_in_datas = [] - for group_idx, group in enumerate(task_group): - group_size = len(group) - # Use the first task's data configuration, or keep original if within bounds - if group_idx < len(in_datas): - original_data = ( - in_datas[group_idx][1] if len(in_datas[group_idx]) > 1 else None - ) - else: - original_data = ( - in_datas[0][1] if in_datas and len(in_datas[0]) > 1 else None - ) - new_in_datas.append((group_size, original_data)) - - in_datas = new_in_datas - print( - f"[in_datas] Updated to {len(in_datas)} entries with group sizes: {[size for size, _ in in_datas]}" - ) + # in_datas already has one entry per shape from the tuner; + # just verify cardinality matches and use it directly. + assert len(task_group) == len( + in_datas + ), f"shape_grouped: group count ({len(task_group)}) != in_datas count ({len(in_datas)})" + ref_data_index = list(range(len(task_group))) else: task_group = tasks + import numpy as np - # to get index of input data for task_group - import numpy as np - - ref_data_index = [i for i in range(len(in_datas))] - if not shape_grouped: cumulative = np.cumsum([size for size, _ in in_datas]) ref_data_index = np.searchsorted( cumulative, np.arange(len(task_group)), side="right" ) - else: - # For shape_grouped, each group directly maps to its in_data entry - ref_data_index = list(range(len(task_group))) print(f"Distributing {len(task_group)} task groups across {mp_num} GPUs") @@ -410,7 +384,8 @@ def add_dummy_result(k, results_list): while remaining_tasks: completed_this_round = [] dummy_failed_tasks = [] - timeout_count_this_round = 0 # Track timeouts in this round + consecutive_timeouts = 0 + half_gpu = max(1, (mp_num + 1) // 2) for k, async_result in remaining_tasks: try: @@ -430,6 +405,7 @@ def add_dummy_result(k, results_list): # Task completed successfully result_dict[k] = task_result completed_this_round.append((k, async_result)) + consecutive_timeouts = 0 elapsed = time.time() - task_start_times[k] if verbose: print( @@ -442,7 +418,7 @@ def add_dummy_result(k, results_list): elapsed = time.time() - task_start_times[k] if elapsed > timeout: - timeout_count_this_round += 1 + consecutive_timeouts += 1 error_msg = f"[!] 
Task {k} timed out after {elapsed:.1f}s (limit: {timeout}s) - likely GPU hang or infinite loop" print(error_msg) @@ -459,13 +435,15 @@ def add_dummy_result(k, results_list): # Trigger pool restart for timeout (similar to crash) pool_restart_needed = True - # If mp_num tasks timed out, all GPUs are likely stuck - restart immediately - if timeout_count_this_round >= mp_num: + # If half the GPUs worth of consecutive timeouts, pool is in bad shape + if consecutive_timeouts >= half_gpu: print( - f"\n[!] {timeout_count_this_round} tasks timed out (all {mp_num} GPUs likely stuck)" + f"\n[!] {consecutive_timeouts} consecutive tasks timed out (>= {half_gpu}/{mp_num} GPUs likely stuck)" ) print("[!] Triggering immediate pool restart...\n") break + else: + consecutive_timeouts = 0 except Exception as e: # Check if it's a process crash (segfault, memory fault, etc.) @@ -473,11 +451,10 @@ def add_dummy_result(k, results_list): # Special handling for KeyError (PID mapping issue) is_mapping_error = error_type == "KeyError" - + # not restart as this is not root use if is_mapping_error: - error_msg = f"[Mapping Error] Task {k} - Process PID not in GPU map (triggering pool restart): {error_type} - {e}" + error_msg = f"[Mapping Error] Task {k} - Process PID not in GPU map: {error_type} - {e}" dummy_failed_tasks.append((k, "mapping error")) - # pool_restart_needed = True elif error_type == "AcceleratorError": # GPU fault (e.g. illegal memory access): worker returns exception instead of # hanging. Unlike hang->timeout, the faulting worker may stay alive and accept @@ -497,7 +474,6 @@ def add_dummy_result(k, results_list): break else: error_msg = f"[Failed] Task {k} failed with {error_type}: {e}" - failed_tasks.append((k, "timeout")) failed_tasks.append((k, "unknown error")) # Always record a dummy result so reconstruction never sees an empty list @@ -523,16 +499,18 @@ def add_dummy_result(k, results_list): if pool_restart_needed and remaining_tasks: if verbose: print(f"\n{'='*60}") - print("? Pool restart needed due to crash. Restarting pool...") - print(f"Remaining tasks: {len(remaining_tasks)}") - print(f"{'='*60}\n") + print( + "? Pool restart needed due to crash. Restarting pool...", flush=True + ) + print(f"Remaining tasks: {len(remaining_tasks)}", flush=True) + print(f"{'='*60}\n", flush=True) # Terminate old pool try: pool.terminate() pool.join() except Exception as e: - print(f"Warning: Error during pool termination: {e}") + print(f"Warning: Error during pool termination: {e}", flush=True) # Create new pool pool = mp.Pool(processes=parallel_num) @@ -554,7 +532,8 @@ def add_dummy_result(k, results_list): # Reset pool restart flag pool_restart_needed = False print( - f"Pool restarted. Continuing with {len(remaining_tasks)} remaining tasks...\n" + f"Pool restarted. Continuing with {len(remaining_tasks)} remaining tasks...\n", + flush=True, ) # Small sleep to avoid busy waiting diff --git a/csrc/ck_batched_gemm_a8w8/README.md b/csrc/ck_batched_gemm_a8w8/README.md index 11cc7e1bda..1b7e00338a 100644 --- a/csrc/ck_batched_gemm_a8w8/README.md +++ b/csrc/ck_batched_gemm_a8w8/README.md @@ -106,6 +106,53 @@ If you have built batched_gemm_a8w8 kernels before tuning new GEMM shapes, pleas --all ``` +#### `--run_config [TUNED_CSV]` +- **Type**: Optional argument +- **Default**: disabled +- **Description**: Run production-operator benchmark only and exit (no tuning). + - `--run_config /path/to/tuned.csv`: read shapes from that tuned CSV and run tuned kernels from that file. 
+ - `--run_config` (no path): read shapes from `-i/--untune_file` and run default kernels. + +**Examples**: +```bash +# benchmark tuned kernels from specified tuned config +python3 csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py \ + --run_config aiter/configs/a8w8_tuned_batched_gemm.csv + +# benchmark default kernels using shapes from -i +python3 csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py \ + -i aiter/configs/a8w8_untuned_batched_gemm.csv --run_config +``` + +#### `--compare` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: Run pre-tune and post-tune production benchmark, print compare results, and keep a compare candidate CSV. + - Pre-tune reads shapes from `-i/--untune_file`. + - Post-tune uses configs written to `.candidate.csv` during the compare run. + - The final tuned CSV is only updated when `--update_improved` is also set. + - Shapes with no valid pre-run baseline can still update when the post-tune benchmark passes. + +**Example**: +```bash +--compare +``` + +#### `--update_improved` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: With `--compare`, update the final tuned CSV for shapes improved by at least `--min_improvement_pct`, or for shapes with no valid pre-run baseline when the post-tune benchmark passes. + +**Example**: +```bash +--compare --update_improved +``` + +#### `--min_improvement_pct` +- **Type**: Float +- **Default**: `3.0` +- **Description**: With `--compare --update_improved`, the minimum percentage improvement required before a compared result replaces the final tuned CSV entry when both pre/post benchmarks are valid. Shapes with no valid pre-run baseline but passing post-tune are still allowed to update. + ### Profiling Configuration #### `--warmup` diff --git a/csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py b/csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py index ca3a59c7ee..27797405e4 100644 --- a/csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py +++ b/csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py @@ -61,11 +61,53 @@ class BatchedGemma8W8Tuner(GemmCommonTuner): "errRatio": 0.05, "batch": 100, "profile_file": "", + "config_env_name": "AITER_CONFIG_A8W8_BATCHED_GEMM", } + def _clear_op_caches(self): + from aiter.ops.batched_gemm_op_a8w8 import get_CKBatchedGEMM_config + + get_CKBatchedGEMM_config.cache_clear() + if hasattr(get_CKBatchedGEMM_config, "ck_batched_gemm_dict"): + del get_CKBatchedGEMM_config.ck_batched_gemm_dict + def _setup_specific_arguments(self): pass + def run_config(self, args): + from aiter.ops.batched_gemm_op_a8w8 import batched_gemm_a8w8 + from aiter.test_common import run_perftest, checkAllclose + + untunedf = self.untunedf + results = [] + for i in range(len(untunedf)): + B = int(untunedf.loc[i, "B"]) + M = int(untunedf.loc[i, "M"]) + N = int(untunedf.loc[i, "N"]) + K = int(untunedf.loc[i, "K"]) + shape_str = f"({B}, {M}, {N}, {K})" + try: + x, weight, x_scale, w_scale, out = generate_data(B, M, N, K) + out, us = run_perftest( + batched_gemm_a8w8, + x, + weight, + x_scale, + w_scale, + out, + num_warmup=args.warmup, + num_iters=args.iters, + ) + ref = run_torch(x, weight, x_scale, w_scale) + err_ratio = checkAllclose(out, ref, msg=f"run_config {shape_str}") + status = "ok" if err_ratio <= args.errRatio else "mismatch" + results.append({"shape": shape_str, "e2e_us": us, "status": status}) + except Exception as e: + results.append( + {"shape": shape_str, "e2e_us": -1, "status": f"error:{e}"} + ) + return results + def calculate(self, results, bpes=(1, 1, 2)): 
info, time, err_ratio = results if time == -1: @@ -95,10 +137,9 @@ def tune( tunedf, args, ): - issorted = args.sort useSplitK = args.splitK mp_num = args.mp - shape_grouped = False + shape_grouped = args.shape_grouped errRatio = args.errRatio cu_num = self.get_cu_num() task = [] @@ -116,8 +157,8 @@ def tune( ) # kernelId, splitK, time = tune_batched_gemm(B, M, N, K, useSplitK) total_kernel_nums = 0 - for i in range(kernels_num): - kernel = kernels_list[i] + for kid in range(kernels_num): + kernel = kernels_list[kid] maxsplitK = ( aiter.compute_batched_gemm_SplitK( M, @@ -131,7 +172,7 @@ def tune( else 0 ) for splitK in range(maxsplitK + 1): - info = ((cu_num, B, M, N, K), i, splitK, "") + info = ((cu_num, B, M, N, K), kid, splitK, "") task.append( ( info, @@ -140,7 +181,7 @@ def tune( kernel_instance_test, ( [0, 1, 2, 3, 4], - i, + kid, splitK, ), # [0, 1, 2, 3, 4] is index of paramters for kernel_instance_test in generate_data { @@ -160,7 +201,6 @@ def tune( tasks_data.append((total_kernel_nums, ())) ret = [] if task: - shape_grouped = False ret = mp_tuner( task, tasks_data, diff --git a/csrc/ck_batched_gemm_bf16/README.md b/csrc/ck_batched_gemm_bf16/README.md index e5b1590be3..714435e3c0 100644 --- a/csrc/ck_batched_gemm_bf16/README.md +++ b/csrc/ck_batched_gemm_bf16/README.md @@ -106,6 +106,53 @@ If you have built batched_gemm_bf16 kernels before tuning new GEMM shapes, pleas --all ``` +#### `--run_config [TUNED_CSV]` +- **Type**: Optional argument +- **Default**: disabled +- **Description**: Run production-operator benchmark only and exit (no tuning). + - `--run_config /path/to/tuned.csv`: read shapes from that tuned CSV and run tuned kernels from that file. + - `--run_config` (no path): read shapes from `-i/--untune_file` and run default kernels. + +**Examples**: +```bash +# benchmark tuned kernels from specified tuned config +python3 csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py \ + --run_config aiter/configs/bf16_tuned_batched_gemm.csv + +# benchmark default kernels using shapes from -i +python3 csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py \ + -i aiter/configs/bf16_untuned_batched_gemm.csv --run_config +``` + +#### `--compare` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: Run pre-tune and post-tune production benchmark, print compare results, and keep a compare candidate CSV. + - Pre-tune reads shapes from `-i/--untune_file`. + - Post-tune uses configs written to `.candidate.csv` during the compare run. + - The final tuned CSV is only updated when `--update_improved` is also set. + - Shapes with no valid pre-run baseline can still update when the post-tune benchmark passes. + +**Example**: +```bash +--compare +``` + +#### `--update_improved` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: With `--compare`, update the final tuned CSV for shapes improved by at least `--min_improvement_pct`, or for shapes with no valid pre-run baseline when the post-tune benchmark passes. + +**Example**: +```bash +--compare --update_improved +``` + +#### `--min_improvement_pct` +- **Type**: Float +- **Default**: `3.0` +- **Description**: With `--compare --update_improved`, the minimum percentage improvement required before a compared result replaces the final tuned CSV entry when both pre/post benchmarks are valid. Shapes with no valid pre-run baseline but passing post-tune are still allowed to update. 
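+
+A typical compare-gated tuning run combines these flags. This is a minimal sketch; the paths below are the defaults used elsewhere in this README and may differ in your setup:
+
+```bash
+# tune, benchmark the production op before/after, and update the tuned CSV only for
+# shapes that improve by >= 5% (or have no valid pre-run baseline but pass post-tune)
+python3 csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py \
+    -i aiter/configs/bf16_untuned_batched_gemm.csv \
+    --compare --update_improved --min_improvement_pct 5.0
+```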
+ ### Profiling Configuration #### `--warmup` diff --git a/csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py b/csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py index e03891aa12..74e5fa936c 100644 --- a/csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py +++ b/csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py @@ -43,11 +43,51 @@ class BatchedGemmBf16Tuner(GemmCommonTuner): "errRatio": 0.05, "batch": 100, "profile_file": "", + "config_env_name": "AITER_CONFIG_BF16_BATCHED_GEMM", } + def _clear_op_caches(self): + from aiter.ops.batched_gemm_op_bf16 import get_CKBatchedGEMM_config + + get_CKBatchedGEMM_config.cache_clear() + if hasattr(get_CKBatchedGEMM_config, "ck_batched_gemm_dict"): + del get_CKBatchedGEMM_config.ck_batched_gemm_dict + def _setup_specific_arguments(self): pass + def run_config(self, args): + from aiter.ops.batched_gemm_op_bf16 import batched_gemm_bf16 + from aiter.test_common import run_perftest, checkAllclose + + untunedf = self.untunedf + results = [] + for i in range(len(untunedf)): + B = int(untunedf.loc[i, "B"]) + M = int(untunedf.loc[i, "M"]) + N = int(untunedf.loc[i, "N"]) + K = int(untunedf.loc[i, "K"]) + shape_str = f"({B}, {M}, {N}, {K})" + try: + x, weight, out = generate_data(B, M, N, K) + out, us = run_perftest( + batched_gemm_bf16, + x, + weight, + out, + num_warmup=args.warmup, + num_iters=args.iters, + ) + ref = run_torch(x, weight) + err_ratio = checkAllclose(out, ref, msg=f"run_config {shape_str}") + status = "ok" if err_ratio <= args.errRatio else "mismatch" + results.append({"shape": shape_str, "e2e_us": us, "status": status}) + except Exception as e: + results.append( + {"shape": shape_str, "e2e_us": -1, "status": f"error:{e}"} + ) + return results + def calculate(self, results, bpes=(2, 2, 2)): info, time, err_ratio = results if time == -1: @@ -76,10 +116,9 @@ def tune( tunedf, args, ): - issorted = args.sort useSplitK = args.splitK mp_num = args.mp - shape_grouped = False + shape_grouped = args.shape_grouped errRatio = args.errRatio cu_num = self.get_cu_num() @@ -95,8 +134,8 @@ def tune( print(f"tuning B:{B}, M:{M}, N:{N}, K:{K}") # kernelId, splitK, time = tune_batched_gemm(B, M, N, K, useSplitK) total_kernel_nums = 0 - for i in range(kernels_num): - kernel = kernels_list[i] + for kid in range(kernels_num): + kernel = kernels_list[kid] maxsplitK = ( aiter.compute_batched_gemm_SplitK( M, @@ -110,7 +149,7 @@ def tune( else 0 ) for splitK in range(maxsplitK + 1): - info = ((cu_num, B, M, N, K), i, splitK, "") + info = ((cu_num, B, M, N, K), kid, splitK, "") task.append( ( info, @@ -119,7 +158,7 @@ def tune( run_batched_gemm, ( [0, 1, 2], - i, + kid, splitK, ), # [0, 1, 2] is index of paramters for run_batched_gemm in generate_data { diff --git a/csrc/ck_gemm_a4w4_blockscale/README.md b/csrc/ck_gemm_a4w4_blockscale/README.md index 25b270aeb2..f95cbbbba2 100755 --- a/csrc/ck_gemm_a4w4_blockscale/README.md +++ b/csrc/ck_gemm_a4w4_blockscale/README.md @@ -107,6 +107,53 @@ If you have built gemm_a4w4 kernels before tuning new GEMM shapes, please add `A --all ``` +#### `--run_config [TUNED_CSV]` +- **Type**: Optional argument +- **Default**: disabled +- **Description**: Run production-operator benchmark only and exit (no tuning). + - `--run_config /path/to/tuned.csv`: read shapes from that tuned CSV and run tuned kernels from that file. + - `--run_config` (no path): read shapes from `-i/--untune_file` and run default kernels. 
+ +**Examples**: +```bash +# benchmark tuned kernels from specified tuned config +python3 csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py \ + --run_config aiter/configs/a4w4_blockscale_tuned_gemm.csv + +# benchmark default kernels using shapes from -i +python3 csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py \ + -i aiter/configs/a4w4_blockscale_untuned_gemm.csv --run_config +``` + +#### `--compare` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: Run pre-tune and post-tune production benchmark, print compare results, and keep a compare candidate CSV. + - Pre-tune reads shapes from `-i/--untune_file`. + - Post-tune uses configs written to `.candidate.csv` during the compare run. + - The final tuned CSV is only updated when `--update_improved` is also set. + - Shapes with no valid pre-run baseline can still update when the post-tune benchmark passes. + +**Example**: +```bash +--compare +``` + +#### `--update_improved` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: With `--compare`, update the final tuned CSV for shapes improved by at least `--min_improvement_pct`, or for shapes with no valid pre-run baseline when the post-tune benchmark passes. + +**Example**: +```bash +--compare --update_improved +``` + +#### `--min_improvement_pct` +- **Type**: Float +- **Default**: `3.0` +- **Description**: With `--compare --update_improved`, the minimum percentage improvement required before a compared result replaces the final tuned CSV entry when both pre/post benchmarks are valid. Shapes with no valid pre-run baseline but passing post-tune are still allowed to update. + ### Profiling Configuration #### `--warmup` diff --git a/csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py b/csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py index 9d4aed1c4a..52076538cd 100755 --- a/csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py +++ b/csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. 
import os import pandas as pd @@ -132,11 +132,64 @@ class GemmA4W4BlockScaleTuner(GemmCommonTuner): **GemmCommonTuner.ARG_DEFAULTS, "tune_file": f"{AITER_CONFIG_GEMM_A4W4}", "untune_file": "aiter/configs/a4w4_blockscale_untuned_gemm.csv", + "config_env_name": "AITER_CONFIG_GEMM_A4W4", } + def _clear_op_caches(self): + from aiter.ops.gemm_op_a4w4 import get_GEMM_config + + get_GEMM_config.cache_clear() + if hasattr(get_GEMM_config, "gemm_dict"): + del get_GEMM_config.gemm_dict + def _setup_specific_arguments(self): pass + def run_config(self, args): + from aiter.ops.gemm_op_a4w4 import gemm_a4w4 + from aiter.test_common import run_perftest, checkAllclose + + untunedf = self.untunedf + results = [] + for i in range(len(untunedf)): + row = untunedf.iloc[i] + M = int(row["M"]) + N = int(row["N"]) + K = int(row["K"]) + shape_str = f"({M}, {N}, {K})" + try: + ( + x, + w, + x_scales, + w_scales, + w_shuffle, + x_scales_shuffle, + w_scales_shuffle, + out_ck, + bias_f32, + ) = generate_data(M, N, K, 0) + out, us = run_perftest( + gemm_a4w4, + x, + w_shuffle, + x_scales_shuffle, + w_scales_shuffle, + num_warmup=args.warmup, + num_iters=args.iters, + ) + ref = run_torch(x, w, x_scales, w_scales, dtypes.bf16) + err_ratio = checkAllclose( + out[:M].to(dtypes.bf16), ref, msg=f"run_config {shape_str}" + ) + status = "ok" if err_ratio <= args.errRatio else "mismatch" + results.append({"shape": shape_str, "e2e_us": us, "status": status}) + except Exception as e: + results.append( + {"shape": shape_str, "e2e_us": -1, "status": f"error:{e}"} + ) + return results + def calculate(self, results, bpes=(1 / 2, 1 / 2, 2)): return super().calculate(results, bpes=bpes) @@ -158,7 +211,8 @@ def get_asm_kernels(self, file): return kernel_dict def getKernelName(self, kernelId): - if kernelId < 0 or kernelId > len(kernels_list): + # kernels_list is a dict keyed by kernel index; do not use len() bounds only. + if kernelId is None or kernelId < 0 or kernelId not in kernels_list: return None return kernels_list[kernelId].name @@ -168,10 +222,9 @@ def tune( tunedf, args, ): - issorted = args.sort useSplitK = args.splitK mp_num = args.mp - shape_grouped = False + shape_grouped = args.shape_grouped errRatio = args.errRatio from aiter.jit.utils.chip_info import get_gfx @@ -180,7 +233,7 @@ def tune( return [] gpu = torch.cuda.current_device() device_properties = torch.cuda.get_device_properties(gpu) - cu_num = device_properties.multi_processor_count + cu_num = int(device_properties.multi_processor_count) task = [] tasks_in_data = [] @@ -189,16 +242,16 @@ def tune( gemm_asm_data_idx = [0, 4, 5, 6, 7, 8] torch_data_idx = [0, 1, 2, 3] seed = 1000 - for i in range(len(untunedf)): - M = untunedf.loc[i, "M"] - N = untunedf.loc[i, "N"] - K = untunedf.loc[i, "K"] + for shape_idx in range(len(untunedf)): + row = untunedf.iloc[shape_idx] + # Native int keys so post_process grouping matches single-shape runs (no np.int64 vs int split). 
+ M, N, K = int(row["M"]), int(row["N"]), int(row["K"]) total_kernel_nums = 0 seed = seed + 1 - for i in range(ck_kernels_num): - kernel = kernels_list[i] + for kernel_idx in range(ck_kernels_num): + kernel = kernels_list[kernel_idx] maxsplitK = ( aiter.compute_gemm_SplitK( M, @@ -212,7 +265,7 @@ def tune( else 0 ) for splitK in range(maxsplitK + 1): - info = ((cu_num, M, N, K), i, splitK, "") + info = ((cu_num, M, N, K), kernel_idx, splitK, "") task.append( ( info, @@ -221,7 +274,7 @@ def tune( run_gemm_a4w4_blockscale, ( gemm_a4w4_data_idx, - i, + kernel_idx, splitK, ), { diff --git a/csrc/ck_gemm_a8w8/README.md b/csrc/ck_gemm_a8w8/README.md index 6388dbc051..857d193e99 100644 --- a/csrc/ck_gemm_a8w8/README.md +++ b/csrc/ck_gemm_a8w8/README.md @@ -105,6 +105,53 @@ If you have built gemm_a8w8 kernels before tuning new GEMM shapes, please add `A --all ``` +#### `--run_config [TUNED_CSV]` +- **Type**: Optional argument +- **Default**: disabled +- **Description**: Run production-operator benchmark only and exit (no tuning). + - `--run_config /path/to/tuned.csv`: read shapes from that tuned CSV and run tuned kernels from that file. + - `--run_config` (no path): read shapes from `-i/--untune_file` and run default kernels. + +**Examples**: +```bash +# benchmark tuned kernels from specified tuned config +python3 csrc/ck_gemm_a8w8/gemm_a8w8_tune.py \ + --run_config aiter/configs/a8w8_tuned_gemm.csv + +# benchmark default kernels using shapes from -i +python3 csrc/ck_gemm_a8w8/gemm_a8w8_tune.py \ + -i aiter/configs/a8w8_untuned_gemm.csv --run_config +``` + +#### `--compare` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: Run pre-tune and post-tune production benchmark, print compare results, and keep a compare candidate CSV. + - Pre-tune reads shapes from `-i/--untune_file`. + - Post-tune uses configs written to `.candidate.csv` during the compare run. + - The final tuned CSV is only updated when `--update_improved` is also set. + - Shapes with no valid pre-run baseline can still update when the post-tune benchmark passes. + +**Example**: +```bash +--compare +``` + +#### `--update_improved` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: With `--compare`, update the final tuned CSV for shapes improved by at least `--min_improvement_pct`, or for shapes with no valid pre-run baseline when the post-tune benchmark passes. + +**Example**: +```bash +--compare --update_improved +``` + +#### `--min_improvement_pct` +- **Type**: Float +- **Default**: `3.0` +- **Description**: With `--compare --update_improved`, the minimum percentage improvement required before a compared result replaces the final tuned CSV entry when both pre/post benchmarks are valid. Shapes with no valid pre-run baseline but passing post-tune are still allowed to update. 
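+
+When in doubt, run `--compare` alone first: it benchmarks pre/post tuning and keeps the candidate CSV, but leaves the final tuned CSV untouched until `--update_improved` is added. The paths below are the defaults shown earlier in this README:
+
+```bash
+# preview: benchmark before/after tuning, keep the candidate CSV, do not touch the tuned CSV
+python3 csrc/ck_gemm_a8w8/gemm_a8w8_tune.py \
+    -i aiter/configs/a8w8_untuned_gemm.csv --compare
+
+# apply: also update the tuned CSV for shapes improved by >= 3% (the default threshold)
+python3 csrc/ck_gemm_a8w8/gemm_a8w8_tune.py \
+    -i aiter/configs/a8w8_untuned_gemm.csv --compare --update_improved
+```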
+ ### Profiling Configuration #### `--warmup` diff --git a/csrc/ck_gemm_a8w8/gemm_a8w8_tune.py b/csrc/ck_gemm_a8w8/gemm_a8w8_tune.py index a1bd121146..85df81ea54 100644 --- a/csrc/ck_gemm_a8w8/gemm_a8w8_tune.py +++ b/csrc/ck_gemm_a8w8/gemm_a8w8_tune.py @@ -99,6 +99,7 @@ class GemmA8W8Tuner(GemmCommonTuner): "errRatio": 0.05, "batch": 100, "profile_file": "", + "config_env_name": "AITER_CONFIG_GEMM_A8W8", } def getKernelName(self, kernelId): @@ -106,22 +107,72 @@ def getKernelName(self, kernelId): return None return kernels_list[kernelId].name + def _clear_op_caches(self): + from aiter.ops.gemm_op_a8w8 import get_GEMM_config_with_quant_type + + get_GEMM_config_with_quant_type.cache_clear() + if hasattr(get_GEMM_config_with_quant_type, "file_cache"): + get_GEMM_config_with_quant_type.file_cache.clear() + def _setup_specific_arguments(self): pass def calculate(self, results, bpes=(1, 1, 2)): return super().calculate(results, bpes=(1, 1, 2)) + def run_config(self, args): + from aiter.ops.gemm_op_a8w8 import gemm_a8w8 + from aiter.test_common import run_perftest, checkAllclose + + untunedf = self.untunedf + results = [] + for i in range(len(untunedf)): + M = int(untunedf.loc[i, "M"]) + N = int(untunedf.loc[i, "N"]) + K = int(untunedf.loc[i, "K"]) + q_dtype_w = untunedf.loc[i, "q_dtype_w"] + shape_str = f"({M}, {N}, {K}, {q_dtype_w})" + try: + x, weight, x_scale, w_scale, out = generate_data( + M, N, K, 0, dtypes.bf16, eval(q_dtype_w) + ) + out, us = run_perftest( + gemm_a8w8, + x, + weight, + x_scale, + w_scale, + num_warmup=args.warmup, + num_iters=args.iters, + ) + ref = gemm_a8w8_ref( + x, + weight, + x_scale, + w_scale, + dtype=dtypes.bf16, + q_dtype_w=eval(q_dtype_w), + ) + err_ratio = checkAllclose( + out.to(dtypes.bf16), ref, msg=f"run_config {shape_str}" + ) + status = "ok" if err_ratio <= args.errRatio else "mismatch" + results.append({"shape": shape_str, "e2e_us": us, "status": status}) + except Exception as e: + results.append( + {"shape": shape_str, "e2e_us": -1, "status": f"error:{e}"} + ) + return results + def tune( self, untunedf, tunedf, args, ): - issorted = args.sort useSplitK = args.splitK mp_num = args.mp - shape_grouped = False + shape_grouped = args.shape_grouped errRatio = args.errRatio cu_num = self.get_cu_num() diff --git a/csrc/ck_gemm_a8w8_blockscale/README.md b/csrc/ck_gemm_a8w8_blockscale/README.md index ec853fcf5c..a514ee7f43 100755 --- a/csrc/ck_gemm_a8w8_blockscale/README.md +++ b/csrc/ck_gemm_a8w8_blockscale/README.md @@ -104,6 +104,53 @@ If you have built gemm_a8w8 kernels before tuning new GEMM shapes, please add `A --all ``` +#### `--run_config [TUNED_CSV]` +- **Type**: Optional argument +- **Default**: disabled +- **Description**: Run production-operator benchmark only and exit (no tuning). + - `--run_config /path/to/tuned.csv`: read shapes from that tuned CSV and run tuned kernels from that file. + - `--run_config` (no path): read shapes from `-i/--untune_file` and run default kernels. 
+ +**Examples**: +```bash +# benchmark tuned kernels from specified tuned config +python3 csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.py \ + --run_config aiter/configs/a8w8_blockscale_tuned_gemm.csv + +# benchmark default kernels using shapes from -i +python3 csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.py \ + -i aiter/configs/a8w8_blockscale_untuned_gemm.csv --run_config +``` + +#### `--compare` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: Run pre-tune and post-tune production benchmark, print compare results, and keep a compare candidate CSV. + - Pre-tune reads shapes from `-i/--untune_file`. + - Post-tune uses configs written to `.candidate.csv` during the compare run. + - The final tuned CSV is only updated when `--update_improved` is also set. + - Shapes with no valid pre-run baseline can still update when the post-tune benchmark passes. + +**Example**: +```bash +--compare +``` + +#### `--update_improved` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: With `--compare`, update the final tuned CSV for shapes improved by at least `--min_improvement_pct`, or for shapes with no valid pre-run baseline when the post-tune benchmark passes. + +**Example**: +```bash +--compare --update_improved +``` + +#### `--min_improvement_pct` +- **Type**: Float +- **Default**: `3.0` +- **Description**: With `--compare --update_improved`, the minimum percentage improvement required before a compared result replaces the final tuned CSV entry when both pre/post benchmarks are valid. Shapes with no valid pre-run baseline but passing post-tune are still allowed to update. + ### Profiling Configuration #### `--warmup` diff --git a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.py b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.py index 09c2b6457b..8244c36a6c 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.py +++ b/csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.py @@ -172,6 +172,7 @@ class GemmA8W8BlockScaleTuner(GemmCommonTuner): "errRatio": 0.05, "batch": 100, "profile_file": "", # for both results + "config_env_name": "AITER_CONFIG_GEMM_A8W8_BLOCKSCALE", } def __init__(self, name, keys, resultList, description=""): @@ -181,6 +182,13 @@ def __init__(self, name, keys, resultList, description=""): super().__init__(name, keys, resultList, description) + def _clear_op_caches(self): + from aiter.ops.gemm_op_a8w8 import get_GEMM_config_with_quant_type + + get_GEMM_config_with_quant_type.cache_clear() + if hasattr(get_GEMM_config_with_quant_type, "file_cache"): + get_GEMM_config_with_quant_type.file_cache.clear() + def _setup_specific_arguments(self): """ Setup specific arguments for the tuner. 
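A minimal compare-and-update invocation of this tuner, using only the flags documented in the README above (the CSV path follows the README example; the threshold value is illustrative):

```bash
python3 csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.py \
    -i aiter/configs/a8w8_blockscale_untuned_gemm.csv \
    --compare --update_improved --min_improvement_pct 3.0
```
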
@@ -264,8 +272,7 @@ def get_gemm_a8w8_blockscale_cktile_tune_task( seed, preshuffleB, block_per_cu, - num_warmup, - num_iters, + run_kwargs, ): cu_num, M, N, K = info_keys # kernel_list = candidate_kernels_bpreshuffle_cktile_dict if preshuffleB else candidate_kernels_cktile_dict @@ -311,10 +318,7 @@ def get_gemm_a8w8_blockscale_cktile_tune_task( splitK, preshuffleB, ), - { - "num_warmup": num_warmup, - "num_iters": num_iters, - }, + dict(run_kwargs), run_torch, ( ref_data_idx, @@ -335,8 +339,7 @@ def get_gemm_a8w8_blockscale_tune_task( useSplitK, seed, preshuffleB, - num_warmup, - num_iters, + run_kwargs, ): cu_num, M, N, K = info_keys kernel_list = ( @@ -376,10 +379,7 @@ def get_gemm_a8w8_blockscale_tune_task( splitK, preshuffleB, ), - { - "num_warmup": num_warmup, - "num_iters": num_iters, - }, + dict(run_kwargs), run_torch, ( ref_data_idx, @@ -394,14 +394,64 @@ def get_gemm_a8w8_blockscale_tune_task( ) return tasks_ck + def run_config(self, args): + from aiter.ops.gemm_op_a8w8 import ( + gemm_a8w8_blockscale, + gemm_a8w8_blockscale_bpreshuffle, + ) + from aiter.test_common import run_perftest, checkAllclose + + is_preshuffle = args.preshuffle + untunedf = self.untunedf + run_kwargs = { + "num_warmup": args.warmup, + "num_iters": args.iters, + } + results = [] + for i in range(len(untunedf)): + M = int(untunedf.loc[i, "M"]) + N = int(untunedf.loc[i, "N"]) + K = int(untunedf.loc[i, "K"]) + shape_str = f"({M}, {N}, {K})" + try: + x, weight, x_scale, w_scale, out, weight_shuffle, x_scale_t, _ = ( + generate_data(M, N, K, 0) + ) + if is_preshuffle: + out, us = run_perftest( + gemm_a8w8_blockscale_bpreshuffle, + x, + weight_shuffle, + x_scale_t, + w_scale, + **run_kwargs, + ) + else: + out, us = run_perftest( + gemm_a8w8_blockscale, + x, + weight, + x_scale, + w_scale, + **run_kwargs, + ) + ref = run_torch(x, weight, x_scale, w_scale) + err_ratio = checkAllclose(out, ref, msg=f"run_config {shape_str}") + status = "ok" if err_ratio <= args.errRatio else "mismatch" + results.append({"shape": shape_str, "e2e_us": us, "status": status}) + except Exception as e: + results.append( + {"shape": shape_str, "e2e_us": -1, "status": f"error:{e}"} + ) + return results + def get_gemm_a8w8_blockscale_asm_tune_task( self, info_keys, useSplitK, seed, preshuffleB, - num_warmup, - num_iters, + run_kwargs, ): cu_num, M, N, K = info_keys asm_kernel_list_csv = ( @@ -447,10 +497,7 @@ def get_gemm_a8w8_blockscale_asm_tune_task( splitK, preshuffleB, ), - { - "num_warmup": num_warmup, - "num_iters": num_iters, - }, + dict(run_kwargs), run_torch, ( ref_data_idx, @@ -475,12 +522,14 @@ def tune( useSplitK = args.splitK mp_num = args.mp isPreshuffleB = args.preshuffle - shape_grouped = False + shape_grouped = args.shape_grouped errRatio = args.errRatio - num_warmup = args.warmup - num_iters = args.iters block_per_cu = args.blockPerCu cu_num = self.get_cu_num() + run_kwargs = { + "num_warmup": args.warmup, + "num_iters": args.iters, + } task = [] tasks_data = [] # [(kernel_nums, datas)] seed = 10000 @@ -489,7 +538,7 @@ def tune( N = untunedf.loc[i, "N"] K = untunedf.loc[i, "K"] seed = seed + 1 - total_kernel_nums = 0 + prev_task_count = len(task) info_keys = (cu_num, M, N, K) lib = args.libtype if lib in ("ck", "both", "all"): @@ -499,8 +548,7 @@ def tune( useSplitK, seed, isPreshuffleB, - num_warmup, - num_iters, + run_kwargs, ) ) if lib in ("cktile", "both", "all"): @@ -511,8 +559,7 @@ def tune( seed, isPreshuffleB, block_per_cu, - num_warmup, - num_iters, + run_kwargs, ) ) if lib in ("asm", "all"): @@ -522,13 +569,12 @@ 
def tune( useSplitK, seed, isPreshuffleB, - num_warmup, - num_iters, + run_kwargs, ) ) - total_kernel_nums = len(task) + shape_kernel_nums = len(task) - prev_task_count - tasks_data.append((total_kernel_nums, ())) + tasks_data.append((shape_kernel_nums, ())) ret = [] if task: ret = mp_tuner( diff --git a/csrc/ck_gemm_a8w8_blockscale_bpreshuffle/README.md b/csrc/ck_gemm_a8w8_blockscale_bpreshuffle/README.md index e88b50477f..4d4c6c183a 100755 --- a/csrc/ck_gemm_a8w8_blockscale_bpreshuffle/README.md +++ b/csrc/ck_gemm_a8w8_blockscale_bpreshuffle/README.md @@ -119,6 +119,53 @@ The tuning uses `csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.py` with --all ``` +#### `--run_config [TUNED_CSV]` +- **Type**: Optional argument +- **Default**: disabled +- **Description**: Run production-operator benchmark only and exit (no tuning). + - `--run_config /path/to/tuned.csv`: read shapes from that tuned CSV and run tuned kernels from that file. + - `--run_config` (no path): read shapes from `-i/--untune_file` and run default kernels. + +**Examples**: +```bash +# benchmark tuned kernels from specified tuned config +python3 csrc/ck_gemm_a8w8_blockscale_bpreshuffle/gemm_a8w8_blockscale_bpreshuffle_tune.py \ + --run_config aiter/configs/a8w8_blockscale_bpreshuffle_tuned_gemm.csv + +# benchmark default kernels using shapes from -i +python3 csrc/ck_gemm_a8w8_blockscale_bpreshuffle/gemm_a8w8_blockscale_bpreshuffle_tune.py \ + -i aiter/configs/a8w8_blockscale_bpreshuffle_untuned_gemm.csv --run_config +``` + +#### `--compare` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: Run pre-tune and post-tune production benchmark, print compare results, and keep a compare candidate CSV. + - Pre-tune reads shapes from `-i/--untune_file`. + - Post-tune uses configs written to `.candidate.csv` during the compare run. + - The final tuned CSV is only updated when `--update_improved` is also set. + - Shapes with no valid pre-run baseline can still update when the post-tune benchmark passes. + +**Example**: +```bash +--compare +``` + +#### `--update_improved` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: With `--compare`, update the final tuned CSV for shapes improved by at least `--min_improvement_pct`, or for shapes with no valid pre-run baseline when the post-tune benchmark passes. + +**Example**: +```bash +--compare --update_improved +``` + +#### `--min_improvement_pct` +- **Type**: Float +- **Default**: `3.0` +- **Description**: With `--compare --update_improved`, the minimum percentage improvement required before a compared result replaces the final tuned CSV entry when both pre/post benchmarks are valid. Shapes with no valid pre-run baseline but passing post-tune are still allowed to update. + ### Profiling Configuration #### `--warmup` diff --git a/csrc/ck_gemm_a8w8_bpreshuffle/README.md b/csrc/ck_gemm_a8w8_bpreshuffle/README.md index aa62e53256..eaf19a691d 100644 --- a/csrc/ck_gemm_a8w8_bpreshuffle/README.md +++ b/csrc/ck_gemm_a8w8_bpreshuffle/README.md @@ -125,6 +125,53 @@ If you have built gemm_a8w8_bpreshuffle kernels before tuning new GEMM shapes, p --all ``` +#### `--run_config [TUNED_CSV]` +- **Type**: Optional argument +- **Default**: disabled +- **Description**: Run production-operator benchmark only and exit (no tuning). + - `--run_config /path/to/tuned.csv`: read shapes from that tuned CSV and run tuned kernels from that file. + - `--run_config` (no path): read shapes from `-i/--untune_file` and run default kernels. 
+ +**Examples**: +```bash +# benchmark tuned kernels from specified tuned config +python3 csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py \ + --run_config aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv + +# benchmark default kernels using shapes from -i +python3 csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py \ + -i aiter/configs/a8w8_bpreshuffle_untuned_gemm.csv --run_config +``` + +#### `--compare` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: Run pre-tune and post-tune production benchmark, print compare results, and keep a compare candidate CSV. + - Pre-tune reads shapes from `-i/--untune_file`. + - Post-tune uses configs written to `.candidate.csv` during the compare run. + - The final tuned CSV is only updated when `--update_improved` is also set. + - Shapes with no valid pre-run baseline can still update when the post-tune benchmark passes. + +**Example**: +```bash +--compare +``` + +#### `--update_improved` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: With `--compare`, update the final tuned CSV for shapes improved by at least `--min_improvement_pct`, or for shapes with no valid pre-run baseline when the post-tune benchmark passes. + +**Example**: +```bash +--compare --update_improved +``` + +#### `--min_improvement_pct` +- **Type**: Float +- **Default**: `3.0` +- **Description**: With `--compare --update_improved`, the minimum percentage improvement required before a compared result replaces the final tuned CSV entry when both pre/post benchmarks are valid. Shapes with no valid pre-run baseline but passing post-tune are still allowed to update. + ### Profiling Configuration #### `--warmup` diff --git a/csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py b/csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py index fda7500b00..2b04308a83 100755 --- a/csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py +++ b/csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py @@ -179,21 +179,6 @@ def generate_data( return x, weight_shuffle, x_scale, w_scale, out, weight, bias_f32 -def generate_data_asm( - m, n, k, seed, dtype=dtypes.bf16, q_dtype_w=dtypes.i8, device="cuda" -): - torch.manual_seed(seed) - x = torch.randn((m, k), dtype=dtype, device=device) - weight = torch.randn((n, k), dtype=dtype, device=device) - x, x_scale = aiter.pertoken_quant(x, quant_dtype=q_dtype_w) - weight, w_scale = aiter.pertoken_quant(weight, quant_dtype=q_dtype_w) - weight_shuffle = shuffle_weight(weight, layout=(32, 16)) - bias = torch.rand([1, n], dtype=dtype, device=device) - bias_f32 = bias.to(dtypes.fp32) - out = torch.empty(m, n, dtype=dtype, device=device) - return x, weight, weight_shuffle, x_scale, w_scale, out, bias_f32 - - def libtype_list(string): values = string.split(",") for value in values: @@ -207,8 +192,16 @@ class GemmA8W8BpreShuffleTuner(GemmCommonTuner): **GemmCommonTuner.ARG_DEFAULTS, "tune_file": f"{AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE}", "untune_file": "aiter/configs/a8w8_bpreshuffle_untuned_gemm.csv", + "config_env_name": "AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE", } + def _clear_op_caches(self): + from aiter.ops.gemm_op_a8w8 import get_GEMM_config_with_quant_type + + get_GEMM_config_with_quant_type.cache_clear() + if hasattr(get_GEMM_config_with_quant_type, "file_cache"): + get_GEMM_config_with_quant_type.file_cache.clear() + def _setup_specific_arguments(self): self.parser.add_argument( "--libtype", @@ -526,7 +519,7 @@ def tune( ): useSplitK = args.splitK mp_num = args.mp - shape_grouped = False 
+ shape_grouped = args.shape_grouped errRatio = args.errRatio cu_num = self.get_cu_num() task = [] @@ -538,8 +531,7 @@ def tune( K = untunedf.loc[i, "K"] q_dtype_w = untunedf.loc[i, "q_dtype_w"] seed = seed + 1 - total_kernel_nums = 0 - # kernels_num = len(kernels_list_ck) + prev_task_count = len(task) info_keys = (cu_num, M, N, K, q_dtype_w) if "all" in args.libtype or "ck" in args.libtype: task.extend( @@ -567,9 +559,9 @@ def tune( ) ) - total_kernel_nums = len(task) + shape_kernel_nums = len(task) - prev_task_count - tasks_data.append((total_kernel_nums, ())) + tasks_data.append((shape_kernel_nums, ())) ret = [] if task: ret = mp_tuner( @@ -625,6 +617,56 @@ def result_to_df(self, results): resultdf = pd.concat([resultdf, temp], ignore_index=True) return resultdf + def run_config(self, args): + from aiter.ops.gemm_op_a8w8 import gemm_a8w8_bpreshuffle, gemm_a8w8_ASM + from aiter.test_common import run_perftest, checkAllclose + + untunedf = self.untunedf + results = [] + for i in range(len(untunedf)): + M = int(untunedf.loc[i, "M"]) + N = int(untunedf.loc[i, "N"]) + K = int(untunedf.loc[i, "K"]) + q_dtype_w = untunedf.loc[i, "q_dtype_w"] + shape_str = f"({M}, {N}, {K}, {q_dtype_w})" + try: + is_asm = eval(q_dtype_w) == dtypes.i8 + x, weight_shuffle, x_scale, w_scale, out, weight, bias_f32 = ( + generate_data(M, N, K, 0, dtypes.bf16, eval(q_dtype_w), is_asm) + ) + if is_asm: + out, us = run_perftest( + gemm_a8w8_ASM, + x, + weight_shuffle, + x_scale, + w_scale, + bias_f32, + num_warmup=args.warmup, + num_iters=args.iters, + ) + else: + out, us = run_perftest( + gemm_a8w8_bpreshuffle, + x, + weight_shuffle, + x_scale, + w_scale, + num_warmup=args.warmup, + num_iters=args.iters, + ) + ref = run_torch(x, weight, x_scale, w_scale, dtype=dtypes.bf16) + err_ratio = checkAllclose( + out.to(dtypes.bf16), ref, msg=f"run_config {shape_str}" + ) + status = "ok" if err_ratio <= args.errRatio else "mismatch" + results.append({"shape": shape_str, "e2e_us": us, "status": status}) + except Exception as e: + results.append( + {"shape": shape_str, "e2e_us": -1, "status": f"error:{e}"} + ) + return results + if __name__ == "__main__": ## use default key and resultList diff --git a/csrc/ck_gemm_moe_2stages_codegen/README.md b/csrc/ck_gemm_moe_2stages_codegen/README.md index 7d7e701d12..8cf1ac9041 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/README.md +++ b/csrc/ck_gemm_moe_2stages_codegen/README.md @@ -117,6 +117,53 @@ If you have built moe kernels before tuning new MoE shapes, please add `AITER_RE --all ``` +#### `--run_config [TUNED_CSV]` +- **Type**: Optional argument +- **Default**: disabled +- **Description**: Run production-operator benchmark only and exit (no tuning). + - `--run_config /path/to/tuned.csv`: read shapes from that tuned CSV and run tuned kernels from that file. + - `--run_config` (no path): read shapes from `-i/--untune_file` and run default kernels. + +**Examples**: +```bash +# benchmark tuned kernels from specified tuned config +python3 csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py \ + --run_config aiter/configs/tuned_fmoe.csv + +# benchmark default kernels using shapes from -i +python3 csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py \ + -i aiter/configs/untuned_fmoe.csv --run_config +``` + +#### `--compare` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: Run pre-tune and post-tune production benchmark, print compare results, and keep a compare candidate CSV. + - Pre-tune reads shapes from `-i/--untune_file`. 
+ - Post-tune uses configs written to `.candidate.csv` during the compare run. + - The final tuned CSV is only updated when `--update_improved` is also set. + - Shapes with no valid pre-run baseline can still update when the post-tune benchmark passes. + +**Example**: +```bash +--compare +``` + +#### `--update_improved` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: With `--compare`, update the final tuned CSV for shapes improved by at least `--min_improvement_pct`, or for shapes with no valid pre-run baseline when the post-tune benchmark passes. + +**Example**: +```bash +--compare --update_improved +``` + +#### `--min_improvement_pct` +- **Type**: Float +- **Default**: `3.0` +- **Description**: With `--compare --update_improved`, the minimum percentage improvement required before a compared result replaces the final tuned CSV entry when both pre/post benchmarks are valid. Shapes with no valid pre-run baseline but passing post-tune are still allowed to update. + ### Profiling Configuration #### `--warmup` diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py index 633b25c7e2..1107310c1f 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py @@ -24,6 +24,8 @@ from aiter import ck_moe_stage1_fwd, ck_moe_stage2_fwd, dtype2str_dict from aiter.ops.shuffle import ( shuffle_weight, + shuffle_scale_a16w4, + shuffle_weight_a16w4, ) from aiter.utility.mp_tuner import mp_tuner from aiter.int4_utils import ( @@ -70,8 +72,20 @@ class FmoeTuner(TunerCommon): "errRatio": 0.5, "batch": 100, "profile_file": "", # for all results + "config_env_name": "AITER_CONFIG_FMOE", } + def _clear_op_caches(self): + try: + import aiter.fused_moe as fmoe_module + + if hasattr(fmoe_module, "cfg_2stages"): + fmoe_module.cfg_2stages = None + if hasattr(fmoe_module, "get_2stage_cfgs"): + fmoe_module.get_2stage_cfgs.cache_clear() + except ImportError: + pass + def _setup_specific_arguments(self): self.parser.add_argument( @@ -2121,6 +2135,200 @@ def gen_flydsl_2stages_task(self, info, blockMs): return tasks_flydsl + def run_config(self, args): + from aiter.fused_moe import fused_moe, fused_topk + from aiter.test_common import run_perftest, checkAllclose + + untunedf = self.untunedf + results = [] + for i in range(len(untunedf)): + row = untunedf.iloc[i] + token = int(row["token"]) + model_dim = int(row["model_dim"]) + inter_dim = int(row["inter_dim"]) + expert = int(row["expert"]) + topk = int(row["topk"]) + act_type = eval(row["act_type"]) + dtype = eval(row["dtype"]) + q_dtype_a = eval(row["q_dtype_a"]) + q_dtype_w = eval(row["q_dtype_w"]) + q_type = eval(row["q_type"]) + q_type = QuantType.per_1x128 if q_type == QuantType.per_128x128 else q_type + use_g1u1 = bool(row["use_g1u1"]) + doweight_stage1 = bool(row["doweight_stage1"]) + shape_str = f"({token}, {model_dim}, {inter_dim}, E={expert}, topk={topk})" + kernel_us = None + if "us" in row and pd.notna(row["us"]): + try: + kernel_us = float(row["us"]) + except (TypeError, ValueError): + kernel_us = None + try: + torch.manual_seed(0) + hidden = ( + torch.randn((token, model_dim), dtype=dtype, device="cuda") / 10 + ) + if use_g1u1: + w1 = ( + torch.randn( + (expert, inter_dim * 2, model_dim), + dtype=dtype, + device="cuda", + ) + / 10 + ) + else: + w1 = ( + torch.randn( + (expert, inter_dim, model_dim), dtype=dtype, device="cuda" + ) + / 10 + ) + w2 = torch.randn( + (expert, model_dim, inter_dim), dtype=dtype, device="cuda" + 
) + w1_qt, w1_scale = self.weight_quant(w1, q_type, quant_dtype=q_dtype_w) + w2_qt, w2_scale = self.weight_quant(w2, q_type, quant_dtype=q_dtype_w) + if q_dtype_w is not dtypes.fp4x2: + w1_qt = w1_qt.view(w1.shape) + w2_qt = w2_qt.view(w2.shape) + else: + w1_qt = w1_qt.view(w1.shape[0], w1.shape[1], w1.shape[2] // 2) + w2_qt = w2_qt.view(w2.shape[0], w2.shape[1], w2.shape[2] // 2) + + # Match the production/test path used by op_tests/test_moe_2stage.py. + w1_qt_fmoe = w1_qt + w2_qt_fmoe = w2_qt + w1_scale_fmoe = w1_scale + w2_scale_fmoe = w2_scale + if q_dtype_w == torch.int4: + w1_qt_fmoe = rearrange_4bit_elements( + convert_int8_to_uint32_int4( + shuffle_weight(w1_qt_fmoe, (16, 16), use_int4=True) + ) + ) + w2_qt_fmoe = rearrange_4bit_elements( + convert_int8_to_uint32_int4( + shuffle_weight(w2_qt_fmoe, (16, 16), use_int4=True) + ) + ) + w1_scale_fmoe = ( + fp4_utils.e8m0_shuffle(w1_scale) + if w1_scale is not None + else None + ) + w2_scale_fmoe = ( + fp4_utils.e8m0_shuffle(w2_scale) + if w2_scale is not None + else None + ) + elif ( + q_type == QuantType.per_1x32 + and q_dtype_a in [dtypes.bf16, dtypes.fp16, dtypes.fp8] + and q_dtype_w == dtypes.fp4x2 + ): + w1_qt_fmoe = shuffle_weight_a16w4(w1_qt_fmoe, 16, True) + w1_scale_fmoe = shuffle_scale_a16w4(w1_scale, expert, True) + w2_qt_fmoe = shuffle_weight_a16w4(w2_qt_fmoe, 16, False) + w2_scale_fmoe = shuffle_scale_a16w4(w2_scale, expert, False) + elif q_dtype_w != dtypes.fp4x2: + w1_qt_fmoe = shuffle_weight(w1_qt_fmoe, (16, 16)) + w2_qt_fmoe = shuffle_weight(w2_qt_fmoe, (16, 16)) + w1_scale_fmoe = ( + fp4_utils.e8m0_shuffle(w1_scale) + if w1_scale is not None + else None + ) + w2_scale_fmoe = ( + fp4_utils.e8m0_shuffle(w2_scale) + if w2_scale is not None + else None + ) + else: + w1_scale_fmoe = ( + fp4_utils.e8m0_shuffle(w1_scale) + if w1_scale is not None + else None + ) + w2_scale_fmoe = ( + fp4_utils.e8m0_shuffle(w2_scale) + if w2_scale is not None + else None + ) + + w1_qt_fmoe.is_shuffled = True + w2_qt_fmoe.is_shuffled = True + + score = torch.randn((token, expert), dtype=dtype, device="cuda") + topk_weights, topk_ids = fused_topk(hidden, score, topk, True) + if q_type == QuantType.per_1x128: + a1_qt, a1_scale = aiter.pertoken_quant( + hidden.view(token, -1, 128), quant_dtype=q_dtype_a + ) + a1_qt = a1_qt.view(token, model_dim) + a1_scale = a1_scale.squeeze(-1) + elif ( + q_type == QuantType.per_1x32 + and q_dtype_a in [dtypes.bf16, dtypes.fp16] + and q_dtype_w == dtypes.fp4x2 + ): + a1_qt = hidden.to(dtype) + a1_scale = None + else: + torch_quant = aiter.get_torch_quant(q_type) + a1_qt, a1_scale = torch_quant(hidden, quant_dtype=q_dtype_a) + + out, us = run_perftest( + fused_moe, + hidden, + w1_qt_fmoe, + w2_qt_fmoe, + topk_weights, + topk_ids, + activation=act_type, + quant_type=q_type, + doweight_stage1=doweight_stage1, + w1_scale=w1_scale_fmoe, + w2_scale=w2_scale_fmoe, + dtype=dtype, + num_warmup=args.warmup, + num_iters=args.iters, + ) + ref = self.torch_moe_2stages( + a1_qt, + w1_qt, + w2_qt, + topk_weights, + topk_ids, + a1_scale=a1_scale, + w1_scale=w1_scale, + w2_scale=w2_scale, + dtype=dtype, + activation=act_type, + quant_type=q_type, + doweight_stage1=doweight_stage1, + ) + err_ratio = checkAllclose(out, ref, msg=f"run_config {shape_str}") + status = "ok" if err_ratio <= args.errRatio else "mismatch" + results.append( + { + "shape": shape_str, + "e2e_us": us, + "kernel_us": kernel_us, + "status": status, + } + ) + except Exception as e: + results.append( + { + "shape": shape_str, + "e2e_us": -1, + "kernel_us": 
kernel_us, + "status": f"error:{e}", + } + ) + return results + def tune( self, untunedf, diff --git a/gradlib/README.md b/gradlib/README.md index ea554be8e9..95f0af8142 100644 --- a/gradlib/README.md +++ b/gradlib/README.md @@ -1,119 +1,182 @@ ``` - _ _ _ _ - __ _ _ __ __ _ __| | (_) |__ - / _` | '__/ _` |/ _` | | | '_ \ + _ _ _ _ + __ _ _ __ __ _ __| | (_) |__ + / _` | '__/ _` |/ _` | | | '_ \ | (_| | | | (_| | (_| | | | |_) | - \__, |_| \__,_|\__,_|_|_|_.__/ - |___/ + \__, |_| \__,_|\__,_|_|_|_.__/ + |___/ +``` + +## What Is gradlib +`gradlib` is a vLLM-derived tuning toolkit for GEMM kernels. It helps you find the best kernel parameters for your current hardware to improve model inference performance. + +## Quick Start + +### 1) Capture Untuned GEMM Shapes +Replace `F.linear` with `tgemm.mm` in `aiter/tuned_gemm.py`, then run your workload: + +```bash +AITER_TUNE_GEMM=1 python {workload_tests} +``` + +Captured shapes are written to `aiter/configs/bf16_untuned_gemm.csv`. + +### 2) Tune GEMMs +Run the tuner: + +```bash +python3 gradlib/gradlib/gemm_tuner.py \ + --tuned_file aiter/configs/bf16_tuned_gemm.csv \ + --input_file aiter/configs/bf16_untuned_gemm.csv +``` + +Tuned results are saved to `aiter/configs/bf16_tuned_gemm.csv`. + +Example columns: + +|**cu_num**|**M**|**N**|**K**|**bias**|**dtype**|**outdtype**|**scaleAB**|**bpreshuffle**|**libtype**|**solidx**|**splitK**|**soltimes**|**kernelName**|**tflops**|**bw**| +|----------|-----|-----|-----|--------|---------|-----------|-----------|---------------|-----------|----------|----------|------------|--------------|----------|------| +|80|128|1536|7168|False|torch.bfloat16|torch.float32|False|False|hipblaslt|667788|0|10.6|xxxxxxx|xx|xx| + +Notes: +- `cu_num`: compute units for current GPU. +- `bpreshuffle`: whether weight is shuffled. +- `dtype`: input dtype (`hipblaslt` supports fp8/bf16/fp16; asm/triton supports bf16/fp16). +- `libtype`: kernel backend (`hipblaslt` / `rocblas` / `asm` / `triton`). +- `splitK`: valid when `libtype == asm`. +- `tflops`: throughput in TFLOPS. +- `bw`: bandwidth in GB/s. + +### 3) Run Your Workload Normally +After tuning, run your model/tests as usual. + +## More Features + +#### `-o2, --profile_file` +- **Type**: String +- **Default**: `""` (empty string) +- **Required**: No +- **Description**: Optional output file storing **all** tuning candidates (not only the best). + +**Example**: +```bash +--profile_file /path/to/all_results.csv +``` + +#### `--mp` +- **Type**: Integer +- **Default**: `torch.cuda.device_count()` +- **Description**: Number of parallel processes / GPUs used for tuning. + +**Example**: +```bash +--mp 1 +``` + +### Tuning Configuration + +#### `--errRatio` +- **Type**: Float +- **Default**: `0.05` (5%) +- **Description**: Max tolerable error ratio for valid kernels. + +**Example**: +```bash +--errRatio 0.01 +--errRatio 0.10 +``` + +#### `--sort` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: Sort output by key columns. + +**Example**: +```bash +--sort +``` + +#### `--all` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: Retune all shapes based on file relationship. + - If `tune_file == untune_file`: retune all shapes in tune file. + - If `tune_file != untune_file`: retune shapes that exist in untuned file. + +**Example**: +```bash +--all +``` + +#### `--run_config [TUNED_CSV]` +- **Type**: Optional argument +- **Default**: disabled +- **Description**: Run production benchmark only and exit (no tuning). 
+ - `--run_config /path/to/tuned.csv`: read shapes from that tuned CSV and run tuned kernels from that file. + - `--run_config` (no path): read shapes from `--input_file` (or auto-generated shapes when `--input_file` is omitted) and run default kernels. + +**Examples**: +```bash +# benchmark tuned kernels from specified tuned config +python3 gradlib/gradlib/gemm_tuner.py \ + --input_file aiter/configs/bf16_untuned_gemm.csv \ + --run_config aiter/configs/bf16_tuned_gemm.csv + +# benchmark default kernels using shapes from --input_file +python3 gradlib/gradlib/gemm_tuner.py \ + --input_file aiter/configs/bf16_untuned_gemm.csv \ + --run_config +``` + +#### `--compare` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: Run pre-tune and post-tune production benchmark, print compare results, and keep a compare candidate CSV. + - Pre-tune reads shapes from `--input_file` (or auto-generated shapes). + - Post-tune uses configs written to `.candidate.csv` during the compare run. + - The final tuned CSV is only updated when `--update_improved` is also set. + - Shapes with no valid pre-run baseline can still update when the post-tune benchmark passes. + +**Example**: +```bash +--compare +``` + +#### `--update_improved` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: With `--compare`, update the final tuned CSV for shapes improved by at least `--min_improvement_pct`, or for shapes with no valid pre-run baseline when the post-tune benchmark passes. + +**Example**: +```bash +--compare --update_improved +``` + +#### `--min_improvement_pct` +- **Type**: Float +- **Default**: `3.0` +- **Description**: With `--compare --update_improved`, the minimum percentage improvement required before a compared result replaces the final tuned CSV entry when both pre/post benchmarks are valid. Shapes with no valid pre-run baseline but passing post-tune are still allowed to update. + +### Debugging and Verbose Output + +#### `-v, --verbose` +- **Type**: Flag (boolean) +- **Default**: `False` +- **Description**: Enable verbose logs. + +**Example**: +```bash +--verbose +-v ``` -## What is gradlib -It is a library of tools derived from vLLM for optimization and tuning, mainly used for performance tuning of matrix multiplication (GEMM). - -By gradlib, we can confirm the parameter of GEMMs with best performance in the specific hardware currently in use. As a result, we can **improve the inference speed of the model**. - -## How to use gradlib - -1. to get GEMM shapes to be tuned, replace F.linear by tgemm.mm under aiter/tuned_gemm.py, - run - - ` - AITER_TUNE_GEMM=1 python {workload_tests} - ` - - then shapes will be captured in aiter/configs/bf16_untuned_gemm.csv -2. to tune GEMMs in aiter/configs/bf16_untuned_gemm.csv, - You can find the results of this tuning in `aiter/configs/bf16_tuned_gemm.csv`. - |**cu_num**|**M**|**N**|**K**|**bias**| **dtype** | **outdtype** |**scaleAB**|**bpreshuffle**|**libtype**|**solidx**|**splitK**|**soltimes**|**kernelName**|**tflops**|**bw**| - |----------|-----|-----|-----|--------|--------------|--------------|-----------|---------------|-----------|----------|----------|------------|--------------|----------|------| - |80 |128 |1536 |7168 | False |torch.bfloat16|torch.float32 | False | False | hipblast |667788 |0 | 10.6 | xxxxxxx | xx | xx | - - `cu_num` means the number of compute units, and it is used to distinguish between graphics. 
- `bpreshuffle` means whether weight will be shuffled - `dtype` means the input data type, hipblaslt support fp8/bf16/fp16 tuning, asm/triton support bf16/fp16 only - `libtype` means the kernel library type: hipblaslt or rocblas or asm - `splitK` only be valid in libtype==asm - `tflops` TFLOPS - `bw` means bandwidth of the implement, GB/s - - run - - ` - python3 gradlib/gradlib/gemm_tuner.py --tuned_file aiter/configs/bf16_tuned_gemm.csv --input_file aiter/configs/bf16_untuned_gemm.csv - ` - more features: - #### `-o2, --profile_file` - - **Type**: String - - **Default**: `""` (empty string) - - **Required**: No - - **Description**: Optional output file to store **all** tuning results (not just the best ones). Useful for profiling and analyzing all kernel candidates. - - **Example**: - ```bash - --profile_file /path/to/all_results.csv - ``` - #### `--mp` - - **Type**: Integer - - **Default**: `torch.cuda.device_count()` (number of available GPUs) - - **Description**: Number of parallel processes to use for tuning across multiple GPUs. Each process runs on a separate GPU. - - **Examples**: - ```bash - --mp 1 # Single GPU tuning - ``` - ### Tuning Configuration - - #### `--errRatio` - - **Type**: Float - - **Default**: `0.05` (5%) - - **Description**: Tolerable error ratio threshold. Only kernels with error ratios below this threshold will be considered valid candidates. - - **Example**: - ```bash - --errRatio 0.01 # Stricter tolerance (1% error) - --errRatio 0.10 # More lenient tolerance (10% error) - ``` - - #### `--sort` - - **Type**: Flag (boolean) - - **Default**: `False` - - **Description**: Sort the output file according to the key columns (e.g., `cu_num`, `N`, `M`, `K` for GEMM). Useful for maintaining consistent ordering in result files. - - **Example**: - ```bash - --sort # Sort results by keys - ``` - - #### `--all` - - **Type**: Flag (boolean) - - **Default**: `False` - - **Description**: Retune all shapes based on file relationship: - - If `tune_file` == `untune_file`: Retune all shapes in the tune file - - If `tune_file` != `untune_file`: Retune shapes that exist in untuned file - - **Example**: - ```bash - --all # Retune all shapes - ``` - - ### Debugging and Verbose Output - - #### `-v, --verbose` - - **Type**: Flag (boolean) - - **Default**: `False` - - **Description**: Enable verbose output with detailed logging information, including skipped shapes, tuning progress, and detailed error messages. - - **Example**: - ```bash - --verbose # Enable verbose mode - -v # Short form - ``` -3. then run your test as normal~ ## hipBLASLt Online Tuning -The hipBLASLt GEMM online tuning feature can be enabled by setting environment variable HIP_ONLINE_TUNING. +Enable hipBLASLt online tuning: + ```bash export HIP_ONLINE_TUNING=1 ``` -The one-time overhead of online tuning will take several minutes. The result of hipBLASLt online tuning will be saved at hip_online_tuning_res.csv. + +The one-time overhead can take several minutes. Results are saved to `hip_online_tuning_res.csv`. 
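+
+For example, using the same `{workload_tests}` placeholder as in the Quick Start:
+
+```bash
+export HIP_ONLINE_TUNING=1
+python {workload_tests}
+# tuning results are written to hip_online_tuning_res.csv
+```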
diff --git a/gradlib/gradlib/GemmTuner.py b/gradlib/gradlib/GemmTuner.py index e188e2fa62..160695ee65 100644 --- a/gradlib/gradlib/GemmTuner.py +++ b/gradlib/gradlib/GemmTuner.py @@ -113,10 +113,12 @@ def run_flydsl_gemm_bf16(input, weight, bias=None, otype=dtypes.bf16, config=Non auto_shuffle_b=False, c_to_lds=config["c_to_lds"], ) - if bias is not None: - out = out + bias if otype is not None and out.dtype != otype: out = out.to(otype) + if bias is not None: + if bias.dtype != out.dtype: + bias = bias.to(out.dtype) + out = out + bias return out @@ -814,6 +816,7 @@ class GemmTuner(GemmCommonTuner): "tune_file": f"{AITER_CONFIG_GEMM_BF16}", "untune_file": "aiter/configs/bf16_untuned_gemm.csv", "batch": 1, + "config_env_name": "AITER_CONFIG_GEMM_BF16", } def _setup_specific_arguments(self): @@ -901,6 +904,78 @@ def __init__( self.gemmobj = None self.num_warmup = 10 + def _clear_op_caches(self): + from aiter.tuned_gemm import get_GEMM_A16W16_config_, get_GEMM_A16W16_config + + get_GEMM_A16W16_config_.cache_clear() + get_GEMM_A16W16_config.cache_clear() + + def run_config(self, args): + from aiter.tuned_gemm import gemm_a16w16 + from aiter.test_common import run_perftest, checkAllclose + + untunedf = self.untunedf + results = [] + for i in range(len(untunedf)): + M = int(untunedf.loc[i, "M"]) + N = int(untunedf.loc[i, "N"]) + K = int(untunedf.loc[i, "K"]) + bias = untunedf.loc[i, "bias"] + indtype = str(untunedf.loc[i, "dtype"]) + outdtype = str(untunedf.loc[i, "outdtype"]) + scaleAB = untunedf.loc[i, "scaleAB"] + bpreshuffle = untunedf.loc[i, "bpreshuffle"] + shape_str = f"({M}, {N}, {K}, {indtype}, bias={bias})" + try: + inp, weights, _, bias_tensor, x_scale, _, shuffleweights, w_scale = ( + generate_data( + M, + N, + K, + eval(indtype), + eval(outdtype), + scaleAB, + bpreshuffle, + 0, + bias, + ) + ) + w = shuffleweights if bpreshuffle else weights + scale_a = x_scale if scaleAB else None + scale_b = w_scale if scaleAB else None + out, us = run_perftest( + gemm_a16w16, + inp, + w, + bias=bias_tensor, + otype=eval(outdtype), + scale_a=scale_a, + scale_b=scale_b, + num_warmup=args.warmup, + num_iters=args.iters, + ) + ref = get_gemm_ref( + inp, + weights, + bias_tensor, + x_scale, + w_scale, + eval(indtype), + eval(outdtype), + ) + _atol = 5e-2 if eval(outdtype) == torch.bfloat16 else 1e-2 + _rtol = 5e-2 if eval(outdtype) == torch.bfloat16 else 1e-2 + err_ratio = checkAllclose( + out, ref, atol=_atol, rtol=_rtol, msg=f"run_config {shape_str}" + ) + status = "ok" if err_ratio <= args.errRatio else "mismatch" + results.append({"shape": shape_str, "e2e_us": us, "status": status}) + except Exception as e: + results.append( + {"shape": shape_str, "e2e_us": -1, "status": f"error:{e}"} + ) + return results + def calculate_perf( self, results,