Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
444c416
Add run_config/compare support to GemmTuner (bf16)
yzhou103 Mar 20, 2026
15d667a
Add --run_config and --compare benchmark support to all tuners
yzhou103 Mar 20, 2026
1b2a7f1
Revert unintended composable_kernel submodule change
yzhou103 Mar 20, 2026
c5b13a1
Merge branch 'main' into add_run_configs_in_tuner
yzhou103 Mar 23, 2026
a7e1f4a
Fix review comments and remove intermediate plan docs
yzhou103 Mar 20, 2026
211f3cf
update ref rtol,atol
yzhou103 Mar 20, 2026
2a18861
Fix tuner cache invalidation, run_config preshuffle, and compare work…
yzhou103 Mar 24, 2026
3fc8d99
Merge branch 'main' into add_run_configs_in_tuner
yzhou103 Mar 24, 2026
d982177
fix format
yzhou103 Mar 24, 2026
1aa4c81
fix format
yzhou103 Mar 24, 2026
46d804b
update readme
yzhou103 Mar 24, 2026
500cadc
fix lint error
yzhou103 Mar 24, 2026
883bd44
fix lint
yzhou103 Mar 24, 2026
4ea7fac
Merge branch 'main' into add_run_configs_in_tuner
yzhou103 Mar 25, 2026
9c5b0df
Merge branch 'main' into add_run_configs_in_tuner
yzhou103 Mar 25, 2026
56add0b
Merge branch 'main' into add_run_configs_in_tuner
yzhou103 Mar 30, 2026
eb2c56d
Merge remote-tracking branch 'origin/add_run_configs_in_tuner' into a…
yzhou103 Mar 30, 2026
887b655
Merge branch 'main' into add_run_configs_in_tuner
yzhou103 Mar 30, 2026
8536ea3
Merge branch 'main' into add_run_configs_in_tuner
yzhou103 Apr 3, 2026
5df1d51
Merge branch 'main' into add_run_configs_in_tuner
yzhou103 Apr 10, 2026
75c48eb
update csv only when perf improves
yzhou103 Apr 10, 2026
e6fbd7f
format
yzhou103 Apr 10, 2026
7f23830
fix lint
yzhou103 Apr 10, 2026
44a1d69
Merge branch 'main' into add_run_configs_in_tuner
yzhou103 Apr 10, 2026
668f313
revert format for some files
yzhou103 Apr 10, 2026
85c04f2
Merge branch 'main' into add_run_configs_in_tuner
yzhou103 Apr 10, 2026
e18b56f
clarify compare and gated update flow
yzhou103 Apr 10, 2026
448016a
Merge remote-tracking branch 'origin/add_run_configs_in_tuner' into a…
yzhou103 Apr 10, 2026
f893bd1
fix flydsl GemmTuner review issues
yzhou103 Apr 10, 2026
ee993a0
update
yzhou103 Apr 11, 2026
92fc4b7
revert claude md
yzhou103 Apr 11, 2026
0d188c8
update shape_grouped
yzhou103 Apr 12, 2026
6339379
Merge branch 'main' into add_run_configs_in_tuner
yzhou103 Apr 13, 2026
aacbb6b
fix format
yzhou103 Apr 13, 2026
4e1d287
fix bug
yzhou103 Apr 13, 2026
e7d1b93
fix lint error
yzhou103 Apr 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions aiter/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,12 @@ def FinalFunc():
kernelName1 = cfg["kernelName1"]
kernelName2 = cfg["kernelName2"]
run_1stage = cfg.get("run_1stage", False)
if not is_shuffled and not run_1stage:
logger.warning(
f"[fused_moe] tuned config found for {keys} but is_shuffled=False. "
"Tuned kernels are optimized for preshuffled weights (preshuffle_on). "
"Running with preshuffle_off may produce incorrect results."
)

tag = f"({kernelName1=}, {kernelName2=})"
logger.info(
Expand Down
879 changes: 856 additions & 23 deletions aiter/utility/base_tuner.py

Large diffs are not rendered by default.

73 changes: 26 additions & 47 deletions aiter/utility/mp_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,60 +290,34 @@ def mp_tuner(
task_group = []
# dispatch per shape to one pid
if shape_grouped:
# Group tasks by info_keys (info[0])
from collections import OrderedDict

info_key_groups = OrderedDict()

for task in tasks:
# Extract info_keys from task (task[0] is info, task[0][0] is info_keys)
info_keys = task[0][0] if task and len(task) > 0 else None

if info_keys not in info_key_groups:
info_key_groups[info_keys] = []
info_key_groups[info_keys].append(task)

# Convert to list of groups
task_group = list(info_key_groups.values())
print(
f"[Task Grouping] Grouped {len(tasks)} tasks into {len(task_group)} groups by info_keys"
)

# Update in_datas to reflect the actual group sizes
# Each group gets one entry with (group_size, original_data)
new_in_datas = []
for group_idx, group in enumerate(task_group):
group_size = len(group)
# Use the first task's data configuration, or keep original if within bounds
if group_idx < len(in_datas):
original_data = (
in_datas[group_idx][1] if len(in_datas[group_idx]) > 1 else None
)
else:
original_data = (
in_datas[0][1] if in_datas and len(in_datas[0]) > 1 else None
)
new_in_datas.append((group_size, original_data))

in_datas = new_in_datas
print(
f"[in_datas] Updated to {len(in_datas)} entries with group sizes: {[size for size, _ in in_datas]}"
)
# in_datas already has one entry per shape from the tuner;
# just verify cardinality matches and use it directly.
assert len(task_group) == len(
in_datas
), f"shape_grouped: group count ({len(task_group)}) != in_datas count ({len(in_datas)})"
ref_data_index = list(range(len(task_group)))
else:
task_group = tasks
import numpy as np

# to get index of input data for task_group
import numpy as np

ref_data_index = [i for i in range(len(in_datas))]
if not shape_grouped:
cumulative = np.cumsum([size for size, _ in in_datas])
ref_data_index = np.searchsorted(
cumulative, np.arange(len(task_group)), side="right"
)
else:
# For shape_grouped, each group directly maps to its in_data entry
ref_data_index = list(range(len(task_group)))

print(f"Distributing {len(task_group)} task groups across {mp_num} GPUs")

Expand Down Expand Up @@ -410,7 +384,8 @@ def add_dummy_result(k, results_list):
while remaining_tasks:
completed_this_round = []
dummy_failed_tasks = []
timeout_count_this_round = 0 # Track timeouts in this round
consecutive_timeouts = 0
half_gpu = max(1, (mp_num + 1) // 2)

for k, async_result in remaining_tasks:
try:
Expand All @@ -430,6 +405,7 @@ def add_dummy_result(k, results_list):
# Task completed successfully
result_dict[k] = task_result
completed_this_round.append((k, async_result))
consecutive_timeouts = 0
elapsed = time.time() - task_start_times[k]
if verbose:
print(
Expand All @@ -442,7 +418,7 @@ def add_dummy_result(k, results_list):
elapsed = time.time() - task_start_times[k]

if elapsed > timeout:
timeout_count_this_round += 1
consecutive_timeouts += 1

error_msg = f"[!] Task {k} timed out after {elapsed:.1f}s (limit: {timeout}s) - likely GPU hang or infinite loop"
print(error_msg)
Expand All @@ -459,25 +435,26 @@ def add_dummy_result(k, results_list):
# Trigger pool restart for timeout (similar to crash)
pool_restart_needed = True

# If mp_num tasks timed out, all GPUs are likely stuck - restart immediately
if timeout_count_this_round >= mp_num:
# If half the GPUs worth of consecutive timeouts, pool is in bad shape
if consecutive_timeouts >= half_gpu:
print(
f"\n[!] {timeout_count_this_round} tasks timed out (all {mp_num} GPUs likely stuck)"
f"\n[!] {consecutive_timeouts} consecutive tasks timed out (>= {half_gpu}/{mp_num} GPUs likely stuck)"
)
print("[!] Triggering immediate pool restart...\n")
break
else:
consecutive_timeouts = 0

except Exception as e:
# Check if it's a process crash (segfault, memory fault, etc.)
error_type = type(e).__name__

# Special handling for KeyError (PID mapping issue)
is_mapping_error = error_type == "KeyError"

# not restart as this is not root use
if is_mapping_error:
error_msg = f"[Mapping Error] Task {k} - Process PID not in GPU map (triggering pool restart): {error_type} - {e}"
error_msg = f"[Mapping Error] Task {k} - Process PID not in GPU map: {error_type} - {e}"
dummy_failed_tasks.append((k, "mapping error"))
# pool_restart_needed = True
elif error_type == "AcceleratorError":
# GPU fault (e.g. illegal memory access): worker returns exception instead of
# hanging. Unlike hang->timeout, the faulting worker may stay alive and accept
Expand All @@ -497,7 +474,6 @@ def add_dummy_result(k, results_list):
break
else:
error_msg = f"[Failed] Task {k} failed with {error_type}: {e}"
failed_tasks.append((k, "timeout"))
failed_tasks.append((k, "unknown error"))

# Always record a dummy result so reconstruction never sees an empty list
Expand All @@ -523,16 +499,18 @@ def add_dummy_result(k, results_list):
if pool_restart_needed and remaining_tasks:
if verbose:
print(f"\n{'='*60}")
print("? Pool restart needed due to crash. Restarting pool...")
print(f"Remaining tasks: {len(remaining_tasks)}")
print(f"{'='*60}\n")
print(
"? Pool restart needed due to crash. Restarting pool...", flush=True
)
print(f"Remaining tasks: {len(remaining_tasks)}", flush=True)
print(f"{'='*60}\n", flush=True)

# Terminate old pool
try:
pool.terminate()
pool.join()
except Exception as e:
print(f"Warning: Error during pool termination: {e}")
print(f"Warning: Error during pool termination: {e}", flush=True)
# Create new pool
pool = mp.Pool(processes=parallel_num)

Expand All @@ -554,7 +532,8 @@ def add_dummy_result(k, results_list):
# Reset pool restart flag
pool_restart_needed = False
print(
f"Pool restarted. Continuing with {len(remaining_tasks)} remaining tasks...\n"
f"Pool restarted. Continuing with {len(remaining_tasks)} remaining tasks...\n",
flush=True,
)

# Small sleep to avoid busy waiting
Expand Down
47 changes: 47 additions & 0 deletions csrc/ck_batched_gemm_a8w8/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,53 @@ If you have built batched_gemm_a8w8 kernels before tuning new GEMM shapes, pleas
--all
```

#### `--run_config [TUNED_CSV]`
- **Type**: Optional argument
- **Default**: disabled
- **Description**: Run production-operator benchmark only and exit (no tuning).
- `--run_config /path/to/tuned.csv`: read shapes from that tuned CSV and run tuned kernels from that file.
- `--run_config` (no path): read shapes from `-i/--untune_file` and run default kernels.

**Examples**:
```bash
# benchmark tuned kernels from specified tuned config
python3 csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py \
--run_config aiter/configs/a8w8_tuned_batched_gemm.csv

# benchmark default kernels using shapes from -i
python3 csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py \
-i aiter/configs/a8w8_untuned_batched_gemm.csv --run_config
```

#### `--compare`
- **Type**: Flag (boolean)
- **Default**: `False`
- **Description**: Run the production-operator benchmark before and after tuning, print the comparison results, and keep the candidate CSV produced by the compare run.
- The pre-tune benchmark reads shapes from `-i/--untune_file`.
- The post-tune benchmark uses the configs written to `<tune_file>.candidate.csv` during the compare run.
- The final tuned CSV is updated only when `--update_improved` is also set.
- Shapes with no valid pre-tune baseline can still be updated when the post-tune benchmark passes.

**Example**:
```bash
--compare
```

#### `--update_improved`
- **Type**: Flag (boolean)
- **Default**: `False`
- **Description**: With `--compare`, update the final tuned CSV for shapes improved by at least `--min_improvement_pct`, or for shapes with no valid pre-run baseline when the post-tune benchmark passes.

**Example**:
```bash
--compare --update_improved
```

#### `--min_improvement_pct`
- **Type**: Float
- **Default**: `3.0`
- **Description**: With `--compare --update_improved`, the minimum percentage improvement a compared result must reach before it replaces the final tuned CSV entry, applied when both the pre-tune and post-tune benchmarks are valid. Shapes with no valid pre-tune baseline are still allowed to update as long as their post-tune benchmark passes.

### Profiling Configuration

#### `--warmup`
Expand Down
54 changes: 47 additions & 7 deletions csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,53 @@ class BatchedGemma8W8Tuner(GemmCommonTuner):
"errRatio": 0.05,
"batch": 100,
"profile_file": "",
"config_env_name": "AITER_CONFIG_A8W8_BATCHED_GEMM",
}

def _clear_op_caches(self):
    """Reset the cached state held on ``get_CKBatchedGEMM_config``.

    Calls the function's ``cache_clear()`` and, when the attribute exists,
    removes ``ck_batched_gemm_dict`` from the function object.

    NOTE(review): presumably this forces the next config lookup to re-read
    the tuned CSV -- confirm against ``aiter.ops.batched_gemm_op_a8w8``.
    """
    # Function-scope import of the operator module's config lookup.
    from aiter.ops.batched_gemm_op_a8w8 import (
        get_CKBatchedGEMM_config as config_lookup,
    )

    config_lookup.cache_clear()

    # The dict is stored as an attribute on the function itself; drop it only
    # if it has been set.
    cached_attr = "ck_batched_gemm_dict"
    if hasattr(config_lookup, cached_attr):
        delattr(config_lookup, cached_attr)

def _setup_specific_arguments(self):
pass

def run_config(self, args):
    """Benchmark ``batched_gemm_a8w8`` on every shape in ``self.untunedf``.

    For each row, inputs are built with ``generate_data``, the operator is
    timed with ``run_perftest`` and its output is compared against
    ``run_torch`` through ``checkAllclose``.

    Args:
        args: Parsed options; ``warmup``, ``iters`` and ``errRatio`` are read.

    Returns:
        list[dict]: One entry per row, in row order, with keys ``shape``
        (the string ``"(B, M, N, K)"``), ``e2e_us`` (the timing returned by
        ``run_perftest``, or ``-1`` on failure) and ``status`` (``"ok"``,
        ``"mismatch"`` or ``"error:<message>"``).
    """
    from aiter.ops.batched_gemm_op_a8w8 import batched_gemm_a8w8
    from aiter.test_common import run_perftest, checkAllclose

    untunedf = self.untunedf
    results = []

    # Walk the rows positionally. The previous ``untunedf.loc[i, col]`` with
    # ``i in range(len(untunedf))`` is a label lookup, so it raised KeyError
    # (or read a different row) whenever the index was not exactly 0..n-1,
    # e.g. after filtering or de-duplicating without ``reset_index``.
    shape_rows = zip(untunedf["B"], untunedf["M"], untunedf["N"], untunedf["K"])
    for raw_b, raw_m, raw_n, raw_k in shape_rows:
        B, M, N, K = int(raw_b), int(raw_m), int(raw_n), int(raw_k)
        shape_str = f"({B}, {M}, {N}, {K})"
        try:
            x, weight, x_scale, w_scale, out = generate_data(B, M, N, K)
            out, us = run_perftest(
                batched_gemm_a8w8,
                x,
                weight,
                x_scale,
                w_scale,
                out,
                num_warmup=args.warmup,
                num_iters=args.iters,
            )
            ref = run_torch(x, weight, x_scale, w_scale)
            err_ratio = checkAllclose(out, ref, msg=f"run_config {shape_str}")
            status = "ok" if err_ratio <= args.errRatio else "mismatch"
            results.append({"shape": shape_str, "e2e_us": us, "status": status})
        except Exception as e:
            # Deliberately broad: a failure on one shape is recorded as an
            # error entry so the remaining shapes are still benchmarked.
            results.append(
                {"shape": shape_str, "e2e_us": -1, "status": f"error:{e}"}
            )
    return results

def calculate(self, results, bpes=(1, 1, 2)):
info, time, err_ratio = results
if time == -1:
Expand Down Expand Up @@ -95,10 +137,9 @@ def tune(
tunedf,
args,
):
issorted = args.sort
useSplitK = args.splitK
mp_num = args.mp
shape_grouped = False
shape_grouped = args.shape_grouped
errRatio = args.errRatio
cu_num = self.get_cu_num()
task = []
Expand All @@ -116,8 +157,8 @@ def tune(
)
# kernelId, splitK, time = tune_batched_gemm(B, M, N, K, useSplitK)
total_kernel_nums = 0
for i in range(kernels_num):
kernel = kernels_list[i]
for kid in range(kernels_num):
kernel = kernels_list[kid]
maxsplitK = (
aiter.compute_batched_gemm_SplitK(
M,
Expand All @@ -131,7 +172,7 @@ def tune(
else 0
)
for splitK in range(maxsplitK + 1):
info = ((cu_num, B, M, N, K), i, splitK, "")
info = ((cu_num, B, M, N, K), kid, splitK, "")
task.append(
(
info,
Expand All @@ -140,7 +181,7 @@ def tune(
kernel_instance_test,
(
[0, 1, 2, 3, 4],
i,
kid,
splitK,
), # [0, 1, 2, 3, 4] is index of paramters for kernel_instance_test in generate_data
{
Expand All @@ -160,7 +201,6 @@ def tune(
tasks_data.append((total_kernel_nums, ()))
ret = []
if task:
shape_grouped = False
ret = mp_tuner(
task,
tasks_data,
Expand Down
47 changes: 47 additions & 0 deletions csrc/ck_batched_gemm_bf16/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,53 @@ If you have built batched_gemm_bf16 kernels before tuning new GEMM shapes, pleas
--all
```

#### `--run_config [TUNED_CSV]`
- **Type**: Optional argument
- **Default**: disabled
- **Description**: Run production-operator benchmark only and exit (no tuning).
- `--run_config /path/to/tuned.csv`: read shapes from that tuned CSV and run tuned kernels from that file.
- `--run_config` (no path): read shapes from `-i/--untune_file` and run default kernels.

**Examples**:
```bash
# benchmark tuned kernels from specified tuned config
python3 csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py \
--run_config aiter/configs/bf16_tuned_batched_gemm.csv

# benchmark default kernels using shapes from -i
python3 csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py \
-i aiter/configs/bf16_untuned_batched_gemm.csv --run_config
```

#### `--compare`
- **Type**: Flag (boolean)
- **Default**: `False`
- **Description**: Run the production-operator benchmark before and after tuning, print the comparison results, and keep the candidate CSV produced by the compare run.
- The pre-tune benchmark reads shapes from `-i/--untune_file`.
- The post-tune benchmark uses the configs written to `<tune_file>.candidate.csv` during the compare run.
- The final tuned CSV is updated only when `--update_improved` is also set.
- Shapes with no valid pre-tune baseline can still be updated when the post-tune benchmark passes.

**Example**:
```bash
--compare
```

#### `--update_improved`
- **Type**: Flag (boolean)
- **Default**: `False`
- **Description**: With `--compare`, update the final tuned CSV for shapes improved by at least `--min_improvement_pct`, or for shapes with no valid pre-run baseline when the post-tune benchmark passes.

**Example**:
```bash
--compare --update_improved
```

#### `--min_improvement_pct`
- **Type**: Float
- **Default**: `3.0`
- **Description**: With `--compare --update_improved`, the minimum percentage improvement a compared result must reach before it replaces the final tuned CSV entry, applied when both the pre-tune and post-tune benchmarks are valid. Shapes with no valid pre-tune baseline are still allowed to update as long as their post-tune benchmark passes.

### Profiling Configuration

#### `--warmup`
Expand Down
Loading
Loading