Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
File renamed without changes.
125 changes: 65 additions & 60 deletions aiter/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,49 +68,8 @@ def mp_lock(
AITER_LOG_MORE = int(os.getenv("AITER_LOG_MORE", 0))
AITER_LOG_TUNED_CONFIG = int(os.getenv("AITER_LOG_TUNED_CONFIG", 0))

# config_env start here
AITER_CONFIG_GEMM_A4W4 = os.getenv(
"AITER_CONFIG_GEMM_A4W4",
f"{AITER_ROOT_DIR}/aiter/configs/a4w4_blockscale_tuned_gemm.csv",
)
AITER_CONFIG_GEMM_A8W8 = os.getenv(
"AITER_CONFIG_GEMM_A8W8",
f"{AITER_ROOT_DIR}/aiter/configs/a8w8_tuned_gemm.csv",
)
AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE = os.getenv(
"AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE",
f"{AITER_ROOT_DIR}/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv",
)
AITER_CONFIG_GEMM_A8W8_BLOCKSCALE = os.getenv(
"AITER_CONFIG_GEMM_A8W8_BLOCKSCALE",
f"{AITER_ROOT_DIR}/aiter/configs/a8w8_blockscale_tuned_gemm.csv",
)
AITER_CONFIG_FMOE = os.getenv(
"AITER_CONFIG_FMOE",
f"{AITER_ROOT_DIR}/aiter/configs/tuned_fmoe.csv",
)

AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE = os.getenv(
"AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE",
f"{AITER_ROOT_DIR}/aiter/configs/a8w8_blockscale_bpreshuffle_tuned_gemm.csv",
)

AITER_CONFIG_A8W8_BATCHED_GEMM = os.getenv(
"AITER_CONFIG_A8W8_BATCHED_GEMM",
f"{AITER_ROOT_DIR}/aiter/configs/a8w8_tuned_batched_gemm.csv",
)

AITER_CONFIG_BF16_BATCHED_GEMM = os.getenv(
"AITER_CONFIG_BATCHED_GEMM_BF16",
f"{AITER_ROOT_DIR}/aiter/configs/bf16_tuned_batched_gemm.csv",
)

AITER_CONFIG_GEMM_BF16 = os.getenv(
"AITER_CONFIG_GEMM_BF16",
f"{AITER_ROOT_DIR}/aiter/configs/tuned_gemm.csv",
)


# config_env start here
def update_config_files(file_path: str, merge_name: str):
path_list = file_path.split(os.pathsep) if file_path else []
if len(path_list) <= 1:
Expand All @@ -133,38 +92,84 @@ def update_config_files(file_path: str, merge_name: str):
else:
print(f"path {i+1}: {path} (not exist)")
merge_df = pd.concat(df_list, ignore_index=True) if df_list else pd.DataFrame()
merge_df = merge_df.drop_duplicates(keep="last")
## get keys from untuned file to drop_duplicates
untuned_name = (
re.sub(r"(?:_)?tuned$", r"\1untuned", merge_name)
if re.search(r"(?:_)?tuned$", merge_name)
else merge_name.replace("tuned", "untuned")
)
untuned_path = f"{AITER_ROOT_DIR}/aiter/configs/{untuned_name}.csv"
if os.path.exists(untuned_path):
untunedf = pd.read_csv(untuned_path)
keys = untunedf.columns
merge_df = merge_df.drop_duplicates(subset=keys, keep="last")
else:
logger.warning(
f"Untuned config file not found: {untuned_path}. Using all columns for deduplication."
)
new_file_path = f"/tmp/{merge_name}.csv"
merge_df.to_csv(new_file_path, index=False)
return new_file_path


AITER_CONFIG_GEMM_A4W4_FILE = update_config_files(
AITER_CONFIG_GEMM_A4W4, "a4w4_blockscale_tuned_gemm"
def get_config_file(env_name: str, tuned_file_name: str) -> str:
    """Resolve the tuned-config CSV path for one operator.

    Resolution order:

    1. If the environment variable ``env_name`` is set and non-empty, its
       value (one path, or several joined with ``os.pathsep``) is handed to
       ``update_config_files`` and the resulting path is returned.
    2. Otherwise the default file
       ``{AITER_ROOT_DIR}/aiter/configs/{tuned_file_name}.csv`` is combined
       with every regular file in ``aiter/configs/model_configs/`` whose name
       contains ``tuned_file_name`` and does not contain ``"untuned"``; the
       combined list is passed to ``update_config_files``.
    3. If no such model-specific file exists, the default path is returned
       unchanged.

    Args:
        env_name: Name of the environment variable that may override the
            config location.
        tuned_file_name: Base name (without ``.csv``) of the tuned config,
            also used as the merge name for ``update_config_files``.

    Returns:
        Path of the config file to load.
    """
    from pathlib import Path

    config_env_file = os.getenv(env_name)
    if config_env_file:
        config_file = update_config_files(config_env_file, tuned_file_name)
        print(f"get {env_name} from environment ", config_file)
        return config_file

    default_file = f"{AITER_ROOT_DIR}/aiter/configs/{tuned_file_name}.csv"
    model_config_dir = Path(f"{AITER_ROOT_DIR}/aiter/configs/model_configs/")
    # Path.glob yields entries in arbitrary order; sort so that the merge
    # precedence (later files win on duplicates) is reproducible.
    # The "untuned" filter looks at the file name only, so an install
    # directory that happens to contain "untuned" does not hide every file.
    op_tuned_file_list = sorted(
        p
        for p in model_config_dir.glob(f"*{tuned_file_name}*")
        if p.is_file() and "untuned" not in p.name
    )
    if not op_tuned_file_list:
        return default_file

    # update_config_files splits its argument on os.pathsep, so the list must
    # be joined with the same separator (":" is only correct on POSIX).
    tuned_files = os.pathsep.join(
        [default_file, *(str(p) for p in op_tuned_file_list)]
    )
    print("merge tuned file under model_configs/ and configs/")
    return update_config_files(tuned_files, tuned_file_name)


AITER_CONFIG_GEMM_A4W4_FILE = get_config_file(
"AITER_CONFIG_GEMM_A4W4", "a4w4_blockscale_tuned_gemm"
)
AITER_CONFIG_GEMM_A8W8_FILE = update_config_files(
AITER_CONFIG_GEMM_A8W8, "a8w8_tuned_gemm"

AITER_CONFIG_GEMM_A8W8_FILE = get_config_file(
"AITER_CONFIG_GEMM_A8W8", "a8w8_tuned_gemm"
)
AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE = update_config_files(
AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE, "a8w8_bpreshuffle_tuned_gemm"
AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE = get_config_file(
"AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE", "a8w8_bpreshuffle_tuned_gemm"
)
AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_FILE = update_config_files(
AITER_CONFIG_GEMM_A8W8_BLOCKSCALE, "a8w8_blockscale_tuned_gemm"
AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_FILE = get_config_file(
"AITER_CONFIG_GEMM_A8W8_BLOCKSCALE", "a8w8_blockscale_tuned_gemm"
)
AITER_CONFIG_FMOE_FILE = update_config_files(AITER_CONFIG_FMOE, "tuned_fmoe")
AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE_FILE = update_config_files(
AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE,
AITER_CONFIG_FMOE_FILE = get_config_file("AITER_CONFIG_FMOE", "tuned_fmoe")

AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE_FILE = get_config_file(
"AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE",
"a8w8_blockscale_bpreshuffle_tuned_gemm",
)
AITER_CONFIG_A8W8_BATCHED_GEMM_FILE = update_config_files(
AITER_CONFIG_A8W8_BATCHED_GEMM, "a8w8_tuned_batched_gemm"

AITER_CONFIG_A8W8_BATCHED_GEMM_FILE = get_config_file(
"AITER_CONFIG_A8W8_BATCHED_GEMM", "a8w8_tuned_batched_gemm"
)
AITER_CONFIG_BF16_BATCHED_GEMM_FILE = update_config_files(
AITER_CONFIG_BF16_BATCHED_GEMM, "bf16_tuned_batched_gemm"

AITER_CONFIG_BF16_BATCHED_GEMM_FILE = get_config_file(
"AITER_CONFIG_BATCHED_GEMM_BF16", "bf16_tuned_batched_gemm"
)
AITER_CONFIG_GEMM_BF16_FILE = update_config_files(
AITER_CONFIG_GEMM_BF16, "bf16_tuned_gemm"

AITER_CONFIG_GEMM_BF16_FILE = get_config_file(
"AITER_CONFIG_GEMM_BF16", "bf16_tuned_gemm"
)

# config_env end here

find_aiter = importlib.util.find_spec("aiter")
Expand Down
2 changes: 1 addition & 1 deletion aiter/tuned_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ class TunedGemm:
def __init__(self):
    """Set up the tuning state and the CSV locations used by this instance."""
    self.extensions_created = False
    # AITER_TUNE_GEMM is read once here; it defaults to 0 when unset.
    self.save_gemm = int(os.environ.get("AITER_TUNE_GEMM", 0))
    # Single assignment: the earlier "untuned_gemm.csv" value was overwritten
    # immediately and never observed.
    self.untune_path = f"{this_dir}/configs/bf16_untuned_gemm.csv"
    self.tune_path = AITER_CONFIG_GEMM_BF16_FILE

def mm(
Expand Down
8 changes: 4 additions & 4 deletions gradlib/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ By gradlib, we can confirm the parameter of GEMMs with best performance in the s
AITER_TUNE_GEMM=1 python {workload_tests}
`

then shapes will be captured in aiter/configs/untuned_gemm.csv
2. to tune GEMMs in aiter/configs/untuned_gemm.csv,
You can find the results of this tuning in `aiter/configs/tuned_gemm.csv`.
The shapes will then be captured in `aiter/configs/bf16_untuned_gemm.csv`.
2. To tune the GEMMs in `aiter/configs/bf16_untuned_gemm.csv`, run the command shown below.
You can find the results of this tuning in `aiter/configs/bf16_tuned_gemm.csv`.
|**cu_num**|**M**|**N**|**K**|**bias**| **dtype** | **outdtype** |**scaleAB**|**libtype**|**solidx**|**splitK**|**soltimes**|**kernelName**|**tflops**|**bw**|
|----------|-----|-----|-----|--------|--------------|--------------|-----------|-----------|----------|----------|------------|--------------|----------|------|
|80 |128 |1536 |7168 | False |torch.bfloat16|torch.float32 | False | hipblaslt |667788 |0 | 10.6 | xxxxxxx | xx | xx |
Expand All @@ -37,6 +37,6 @@ By gradlib, we can confirm the parameter of GEMMs with best performance in the s
run

`
python3 gradlib/gradlib/gemm_tuner.py --tuned_file aiter/configs/tuned_gemm.csv --input_file aiter/configs/untuned_gemm.csv
python3 gradlib/gradlib/gemm_tuner.py --tuned_file aiter/configs/bf16_tuned_gemm.csv --input_file aiter/configs/bf16_untuned_gemm.csv
`
3. Then run your tests as normal.