Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions aiter/ops/triton/utils/gemm_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,13 @@ def compute_splitk_params(config: dict, K: int) -> dict:
config["SPLITK_BLOCK_SIZE"] = triton.cdiv(K, config["NUM_KSPLIT"])

if "BLOCK_SIZE_K" in config:
# If NUM_KSPLIT leaves each split with too little of K (SPLITK_BLOCK_SIZE < BLOCK_SIZE_K),
# BLOCK_K would have to shrink below GROUP_K, so halve NUM_KSPLIT until BLOCK_SIZE_K fits.
while config["NUM_KSPLIT"] > 1 and config["BLOCK_SIZE_K"] > config["SPLITK_BLOCK_SIZE"]:
config["NUM_KSPLIT"] = max(config["NUM_KSPLIT"] // 2, 1)
config["SPLITK_BLOCK_SIZE"] = triton.cdiv(K, config["NUM_KSPLIT"])

# If BLOCK_SIZE_K is still too large with NUM_KSPLIT=1, shrink it to the next power of 2 >= the K dim.
if config["BLOCK_SIZE_K"] > config["SPLITK_BLOCK_SIZE"]:
config["BLOCK_SIZE_K"] = triton.next_power_of_2(config["SPLITK_BLOCK_SIZE"])

Expand Down
Loading