Skip to content
Merged
25 changes: 15 additions & 10 deletions aiter/utility/mp_tuner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
import torch
import multiprocessing as mp
import time
Expand Down Expand Up @@ -37,7 +37,7 @@ def worker(
us = round(us, 4)

except RuntimeError as e:
print(f"run gpu func error: info:{info}\t {e}")
print(f"run gpu func warning: info:{info}\t {e}", flush=True)
us = -1 # not support or error
max_err_ratio = 1.0
max_retries = 3
Expand Down Expand Up @@ -82,24 +82,28 @@ def worker(
max_err_ratio = max(max_err_ratio, err_ratio)
except RuntimeError as e:
if "CUDA" in str(e) or "HIP" in str(e) or "out of memory" in str(e).lower():
print(f"GPU Runtime Error in process:{pid} info:{info}: {e}")
if printLog:
print(f"GPU Runtime Error in process:{pid} info:{info}: {e}")
# Try to recover GPU state
try:
torch.cuda.empty_cache()
torch.cuda.synchronize()
except Exception as e:
print(f"Error in process:{pid} info:{info}: {e}")
if printLog:
print(f"Error in process:{pid} info:{info}: {e}")
pass
else:
print(f"Runtime Error in process:{pid} info:{info}: {e}")
us = -1 # float("inf")
max_err_ratio = 1.0
except TimeoutError as e:
print(f"Timeout in process:{pid} info:{info}: {e}")
if printLog:
print(f"Timeout in process:{pid} info:{info}: {e}")
us = float("inf")
max_err_ratio = 1.0
except Exception as e:
print(f"Unexpected Error in process:{pid} info:{info}: {e}")
if printLog:
print(f"Unexpected Error in process:{pid} info:{info}: {e}")
# import traceback

# traceback.print_exc()
Expand All @@ -109,7 +113,7 @@ def worker(
return info, us, round(max_err_ratio, 4)


def work_group(GPUIDMap, fast_mode, err_ratio, in_data, tasks, printLog=False):
def work_group(GPUIDMap, fast_mode, err_ratio, in_data, tasks, verbose=False):
"""Work group that processes a batch of related tasks."""
group_task = [tasks] if not isinstance(tasks, list) else tasks
kernels_num, (input_data) = in_data
Expand Down Expand Up @@ -204,7 +208,7 @@ def work_group(GPUIDMap, fast_mode, err_ratio, in_data, tasks, printLog=False):
)

# Run worker with explicit GPU ID
ret = worker(*work_args, tol_err_ratio=err_ratio)
ret = worker(*work_args, printLog=verbose, tol_err_ratio=err_ratio)
rets.append(ret)
return rets

Expand Down Expand Up @@ -458,7 +462,8 @@ def add_dummy_result(k, results_list):
# pool_restart_needed = True
else:
error_msg = f"[Failed] Task {k} failed with {error_type}: {e}"
# pool_restart_needed = True
failed_tasks.append((k, "timeout"))
completed_this_round.append((k, async_result))

# Only log error once per error type
if error_type not in logged_error_types:
Expand Down Expand Up @@ -515,7 +520,7 @@ def add_dummy_result(k, results_list):
# Reconstruct results in original task order
result = []
for k in range(len(rets)):
task_result = result_dict[k]
task_result = result_dict.get(k, [])
if shape_grouped:
result.extend(task_result)
else:
Expand Down
140 changes: 134 additions & 6 deletions csrc/ck_batched_gemm_a8w8/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# CK batched_gemm a8w8 tune
# CK Batched GEMM A8W8 Tune

1. Install aiter:
`cd $aiter_path`
Expand All @@ -10,15 +10,143 @@
|16 |128 |1536 |7168 |

3. Start tuning:
Run the following cmd to start tuning, run the following cmd to start tuning, please wait a few minutes as it will build batched_gemm_a8w8_tune via jit:
Run the following cmd to start tuning, please wait a few minutes as it will build batched_gemm_a8w8_tune via jit:
`python3 csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py -i aiter/configs/a8w8_untuned_batched_gemm.csv -o aiter/configs/a8w8_tuned_batched_gemm.csv`
You can find the results of the tuning in `aiter/configs/a8w8_tuned_batched_gemm.csv`.
You can find the results of the tuning in `aiter/configs/a8w8_tuned_batched_gemm.csv`, like this:
|**cu_num**|**B**|**M**|**N**|**K**|**kernelId**|**splitK**|**us**|**kernelName**|**tflops**|**bw**|**errRatio**|
|----------|-----|-----|-----|-----|------------|----------|------|--------------|----------|------|------------|
|80 |16 |128 |1536 |7168 |23 |0 |32.99 |xxxxxxxx |125.4 |89.5 |0.01 |

`cu_num` means the number of compute units, and it is used to distinguish between different GPU models.

4. Build tuned kernels and test:
Test the performance, modify the test instance in `op_tests/test_batched_gemm_a8w8.py` and run it, please wait a few minutes as it will build batched_gemm_a8w8 tuned kernels in `aiter/configs/a8w8_tuned_batched_gemm.csv` via jit:
`python3 op_tests/test_batched_gemm_a8w8.py`
If you have built batched_gemm_a8w8 kernels brefore tuning new GEMM shapes, please add `AITER_REBUILD=1` before your test cmd, such as `AITER_REBUILD=1 python3 op_tests/test_batched_gemm_a8w8.py`. It will rebuild kernels from `AITER_CONFIG_A8W8_BATCHED_GEMM` the default one will be `aiter/configs/a8w8_tuned_batched_gemm.csv`.
If you have built batched_gemm_a8w8 kernels before tuning new GEMM shapes, please add `AITER_REBUILD=1` before your test cmd, such as `AITER_REBUILD=1 python3 op_tests/test_batched_gemm_a8w8.py`. It will rebuild kernels from `AITER_CONFIG_A8W8_BATCHED_GEMM`; by default this is the result of merging `aiter/configs/a8w8_tuned_batched_gemm.csv` with the tuned fmoe csv files under `aiter/configs/model_configs/xx_a8w8_tuned_batched_gemm_xx.csv`, and the merged result is stored in `/tmp/aiter_configs/a8w8_tuned_batched_gemm.csv`.

## More Options

### Output Configuration

#### `-o2, --profile_file`
- **Type**: String
- **Default**: `""` (empty string)
- **Description**: Optional output file to store **all** tuning results (not just the best ones). Useful for profiling and analyzing all kernel candidates.

**Example**:
```bash
--profile_file aiter/configs/profile_a8w8_batched_all.csv
```

#### `--sort`
- **Type**: Flag (boolean)
- **Default**: `False`
- **Description**: Sort the output file according to the key columns (e.g., `cu_num`, `N`, `M`, `K` for GEMM). Useful for maintaining consistent ordering in result files.


**Example**:
```bash
--sort
```

### Tuning Configuration

#### `--errRatio`
- **Type**: Float
- **Default**: `0.05` (5%)
- **Description**: Tolerable error ratio threshold. Only kernels with error ratios below this threshold will be considered valid candidates.

**Example**:
```bash
--errRatio 0.01
```

#### `--mp`
- **Type**: Integer
- **Default**: Number of available GPUs
- **Description**: Number of parallel processes to use for tuning across multiple GPUs.

**Example**:
```bash
--mp 4
```

#### `--batch`
- **Type**: Integer
- **Default**: `100`
- **Description**: Number of shapes to tune in each batch.

**Example**:
```bash
--batch 50
```

#### `-k, --splitK`
- **Type**: Flag (boolean)
- **Default**: `False`
- **Description**: Enable split-K optimization for GEMM kernels. Split-K divides the K dimension across multiple workgroups to improve parallelism and performance for certain shapes.

**Example**:
```bash
-k
--splitK
```

#### `--all`
- **Type**: Flag (boolean)
- **Default**: `False`
- **Description**: Retune all shapes, based on the relationship between the tune file and the untune file:
- If `tune_file` == `untune_file`: Retune all shapes in the tune file
- If `tune_file` != `untune_file`: Retune shapes that exist in untuned file


**Example**:
```bash
--all
```

### Profiling Configuration

#### `--warmup`
- **Type**: Integer
- **Default**: `5`
- **Description**: Number of warmup iterations before profiling.

**Example**:
```bash
--warmup 10
```

#### `--iters`
- **Type**: Integer
- **Default**: `101`
- **Description**: Number of profiling iterations to run for performance measurement.

**Example**:
```bash
--iters 200
```

#### `--timeout`
- **Type**: Integer
- **Default**: `None`
- **Description**: Timeout in seconds for each task group.

**Example**:
```bash
--timeout 300
```

### Debugging and Verbose Output

#### `-v, --verbose`
- **Type**: Flag (boolean)
- **Default**: `False`
- **Description**: Enable verbose output with detailed logging information.

## More
If you use flag `PREBUILD_KERNELS=1` when you install aiter, it will build gemm a8w8 kernels in tuned gemm csv by default. If you want to use the new result of gemm_a8w8_tune, please remove `build` and `*.so` in `aiter/jit` first, then re-intall aiter after finishing tune. This can take a lot of time and is not recommended.
**Example**:
```bash
-v
```
## Notes
If you use flag `PREBUILD_KERNELS=1` when you install aiter, it will build batched_gemm_a8w8 kernels in tuned gemm csv by default. If you want to use the new result of batched_gemm_a8w8_tune, please remove `build` and `*.so` in `aiter/jit` first, then re-install aiter after finishing tune. This can take a lot of time and is not recommended.
5 changes: 1 addition & 4 deletions csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
import os
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
import aiter
import pandas as pd
import torch
import torch.nn.functional as F
from aiter import dtypes
Expand Down Expand Up @@ -123,7 +121,6 @@ def tune(
kernel = kernels_list[i]
maxsplitK = (
aiter.compute_batched_gemm_SplitK(
B,
M,
N,
K,
Expand Down
Loading