Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/composable_kernel
Submodule composable_kernel updated 361 files
172 changes: 161 additions & 11 deletions aiter/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def fused_moe(
intermediate_pad=0,
bias1=None,
bias2=None,
splitk=0,
):
if not block_size_M:
block_size_M = -1
Expand Down Expand Up @@ -217,7 +218,15 @@ def fused_moe_(
quant_type = quant_remap.get(quant_type, quant_type)
q_dtype_w = w1.dtype
q_dtype_a = w1.dtype if w1.dtype != torch.uint32 else dtypes.fp8
q_dtype_a = dtypes.fp4x2 if quant_type == QuantType.per_1x32 else q_dtype_a
bf16_fp8_bound = 512
if quant_type == QuantType.per_1x32:
if activation == ActivationType.Swiglu:
if get_gfx() != "gfx950" or M < bf16_fp8_bound:
q_dtype_a = dtypes.bf16
elif M >= bf16_fp8_bound:
q_dtype_a = dtypes.fp8
else:
q_dtype_a = dtypes.fp4x2

metadata = get_2stage_cfgs(
get_padded_M(M), # consider token_num > 1024 as prefill
Expand Down Expand Up @@ -472,6 +481,33 @@ def get_block_size_M(token, topk, expert, inter_dim):
return sorted(tmp, key=lambda x: x[:2])[0][-1]


@functools.lru_cache(maxsize=2048)
def get_ksplit(token, topk, expert, inter_dim, model_dim):
    """Heuristically choose a K-split factor, or 0 when no split should be used.

    The ``AITER_KSPLIT`` environment variable, when set to a non-zero integer,
    overrides the heuristic and is returned as-is.  Note that this function is
    memoized with ``lru_cache``: the environment variable is only re-read for
    argument tuples that have not been seen before.

    Otherwise the heuristic only applies when ``token * topk <= expert`` and
    the number of output tiles (``token * topk`` rows times
    ``ceil(inter_dim * 2 / 128)`` columns) is smaller than the value returned
    by ``get_cu_num()``.  In that case the largest ``i`` in
    ``[2, ceil(cu_num / tile_count)]`` is returned such that ``model_dim`` is
    divisible by ``i`` and ``model_dim // i`` is a multiple of 256.

    Args:
        token: Number of tokens.
        topk: Number of experts selected per token.
        expert: Total number of experts.
        inter_dim: Intermediate dimension (doubled when counting N tiles).
        model_dim: Model dimension, i.e. the K extent that would be split.

    Returns:
        The chosen split factor (>= 2 from the heuristic, or whatever
        ``AITER_KSPLIT`` holds), or 0 when no split applies.
    """
    forced_ksplit = int(os.environ.get("AITER_KSPLIT", "0"))
    if forced_ksplit != 0:
        return forced_ksplit
    # only for moe_blk gemm1 a8w8 decode scenario
    if token * topk > expert:
        return 0

    tileN = 128
    tgM = token * topk  # decode tile num
    tgN = (inter_dim * 2 + tileN - 1) // tileN
    tg_num = tgN * tgM
    # An empty problem (no tokens, topk == 0, or inter_dim == 0) has no tiles;
    # bail out here so the ceil-division below never divides by zero.
    if tg_num <= 0:
        return 0

    cu_num = get_cu_num()
    # if all cu already active
    if tg_num >= cu_num:
        return 0

    tilek = 256
    split_max = (cu_num + tg_num - 1) // tg_num
    # at least split = 2; prefer the largest factor that keeps K tile-aligned
    for i in reversed(range(2, split_max + 1)):
        if (model_dim % i == 0) and ((model_dim // i) % tilek == 0):
            return i
    return 0


cfg_2stages = None
# fmt: off
fused_moe_1stage_dict = {
Expand Down Expand Up @@ -512,7 +548,7 @@ def nextPow2(n):
def get_padded_M(M):
padded_m = M
if M >= 1 and M <= 16:
padded_m = 16
return padded_m
Comment thread
valarLip marked this conversation as resolved.
Outdated
elif M < 1024:
padded_m = nextPow2(padded_m)
elif M < 2048:
Expand Down Expand Up @@ -620,8 +656,22 @@ def FinalFunc():
)
logger.info("\033[0m")

def use_cfg():
    """Tell whether the tuned config table should be consulted.

    Returns False (and logs the bypass) only for the Silu / bf16 output /
    fp8 activation / fp8 weight / per_1x128 combination when
    ``token * topk`` is at most 128; returns True in every other case.
    """
    bypass_type = (
        ActivationType.Silu,
        dtypes.bf16,
        dtypes.fp8,
        dtypes.fp8,
        QuantType.per_1x128,
    )
    problem_type = (activation, dtype, q_dtype_a, q_dtype_w, q_type)

    # Anything that is not the small fp8-blockscale case keeps using tuned results.
    if problem_type != bypass_type:
        return True
    if (token * topk) > 128:
        return True

    # bypass tuned
    aiter.logger.info("bypass tuned results for fp8 blockscale")
    return False

# cfg = cfg_2stages.get(keys, None)
cfg = cfg_2stages.get(keys, None) if cfg_2stages else None
cfg = cfg_2stages.get(keys, None) if cfg_2stages and use_cfg() else None
if cfg is None and os.environ.get("AITER_ONLINE_TUNE", "0") == "1":
lock_path = os.path.join(bd_dir, f"lock_fmoe_tune_{keys}")
mp_lock(lock_path, MainFunc=MainFunc, FinalFunc=FinalFunc)
Expand All @@ -630,7 +680,7 @@ def FinalFunc():
cfg = cfg_2stages.get(keys, None) if cfg_2stages else None
if cfg is None:
logger.warning(f"Fmoe tuning not support for {keys}")
if cfg is None:
if cfg is None or int(os.environ.get("AITER_HEURISTIC_ONLY", "0")):
Comment thread
valarLip marked this conversation as resolved.
Outdated
ksplit = 0
kernelName1 = ""
kernelName2 = ""
Expand All @@ -645,7 +695,7 @@ def FinalFunc():
doweight_stage1,
) in fused_moe_1stage_dict[get_gfx()]:
if q_type == QuantType.per_1x128:
run_1stage = True and (inter_dim % 256 == 0)
run_1stage = token > 32 and (inter_dim % 256 == 0)
Comment thread
valarLip marked this conversation as resolved.
elif q_type == QuantType.per_Token and q_dtype_w == dtypes.i8:
run_1stage = token > 32
elif q_type == QuantType.per_Token and q_dtype_w == dtypes.fp8:
Expand All @@ -657,11 +707,23 @@ def FinalFunc():
BLOCK_SIZE_M
if run_1stage
else (
64
(64 if token > 32 else 16)
if q_type == QuantType.per_1x128
else get_block_size_M(token, topk, expert, inter_dim)
)
)
ksplit = (
ksplit
if (run_1stage)
else (
get_ksplit(token, topk, expert, inter_dim, model_dim)
if q_type == QuantType.per_1x128
else ksplit
)
)
aiter.logger.info(
f"run_1stage = {run_1stage}, ksplit = {ksplit} q_type = {q_type}"
)
else:
block_m = cfg["block_m"]
ksplit = cfg["ksplit"]
Expand All @@ -673,6 +735,13 @@ def FinalFunc():
logger.info(
f"[fused_moe] using {'1stage' if run_1stage else '2stage'} {'default' if cfg is None else tag} for {keys} "
)

def get_block_m() -> int:
    """Return the M block size from the activation dtype and token count.

    fp8 activations always use 32; otherwise the size steps up with the
    token count: 16 below 2048 tokens, 32 below 16384 tokens, 64 beyond that.
    """
    if q_dtype_a == dtypes.fp8:
        return 32
    if token < 2048:
        return 16
    if token < 16384:
        return 32
    return 64

if run_1stage:
return MOEMetadata(
functools.partial(
Expand Down Expand Up @@ -704,7 +773,7 @@ def FinalFunc():
k_pad_zeros=intermediate_pad // 128 * 128,
bias2=bias2,
),
16 if token < 2048 else 32 if token < 16384 else 64,
get_block_m(),
ksplit,
False,
)
Expand All @@ -717,14 +786,16 @@ def FinalFunc():
dtypes.fp16,
torch.uint32,
dtypes.fp4x2,
dtypes.fp8,
]
):
return MOEMetadata(
functools.partial(
aiter.ck_moe_stage1_fwd,
ck_moe_stage1,
kernelName=kernelName1,
activation=activation,
quant_type=q_type,
splitk=ksplit,
),
functools.partial(
aiter.ck_moe_stage2_fwd,
Expand Down Expand Up @@ -818,11 +889,23 @@ def fused_moe_2stages(
if (
quant_type == QuantType.per_1x32
and dtype in [dtypes.bf16, dtypes.fp16]
and q_dtype_a in [dtypes.bf16, dtypes.fp16]
and w1.dtype == dtypes.fp4x2
and activation == ActivationType.Swiglu
):
a1 = hidden_states.to(dtype)
a1_scale = None
elif (
quant_type == aiter.QuantType.per_1x32
and dtype in [dtypes.bf16, dtypes.fp16]
and q_dtype_a == dtypes.fp8
and w1.dtype == dtypes.fp4x2
and activation == aiter.ActivationType.Swiglu
):
a1 = hidden_states.to(dtypes.fp8)
M = sorted_ids.shape[0]
N = a1.shape[-1]
a1_scale = torch.ones([M, N // 32], dtype=dtypes.fp8_e8m0, device=a1.device)
elif quant_type == QuantType.per_1x32:
if token_num <= token_num_quant_moe_sort_switch:
a1, a1_scale = fused_dynamic_mxfp4_quant_moe_sort(
Expand Down Expand Up @@ -886,17 +969,29 @@ def fused_moe_2stages(
topk,
block_m=block_size_M,
a1_scale=a1_scale,
w1_scale=w1_scale,
w1_scale=(
w1_scale.view(dtypes.fp8_e8m0) if w1.dtype == dtypes.fp4x2 else w1_scale
),
sorted_weights=sorted_weights if doweight_stage1 else None,
)

if (
quant_type == QuantType.per_1x32
and dtype in [dtypes.bf16, dtypes.fp16]
and q_dtype_a in [dtypes.bf16, dtypes.fp16]
and w1.dtype == dtypes.fp4x2
and activation == ActivationType.Swiglu
):
a2_scale = None
elif (
quant_type == aiter.QuantType.per_1x32
and dtype in [dtypes.bf16]
and q_dtype_a == dtypes.fp8
and w1.dtype == dtypes.fp4x2
and activation == aiter.ActivationType.Swiglu
):
a2 = a2.to(dtypes.fp8)
a2_scale = a1_scale
elif quant_type == QuantType.per_1x32:
a2 = a2.view(-1, inter_dim)
if token_num <= token_num_quant_moe_sort_switch:
Expand Down Expand Up @@ -952,7 +1047,9 @@ def fused_moe_2stages(
num_valid_ids,
moe_out,
topk,
w2_scale=w2_scale,
w2_scale=(
w2_scale.view(dtypes.fp8_e8m0) if w2.dtype == dtypes.fp4x2 else w2_scale
),
a2_scale=a2_scale,
block_m=block_size_M,
sorted_weights=sorted_weights if not doweight_stage1 else None,
Expand Down Expand Up @@ -1293,6 +1390,59 @@ def torch_moe_stage2(
return out.sum(1).to(dtype)


def ck_moe_stage1(
    hidden_states,
    w1,  # [E, inter_dim*2, model_dim]
    w2,  # [E, model_dim, inter_dim]
    sorted_token_ids,  # [max_num_tokens_padded]
    sorted_expert_ids,  # [max_num_m_blocks]
    num_valid_ids,  # [1]
    out,
    topk,
    block_m,
    a1_scale,
    w1_scale,
    kernelName="",
    sorted_weights=None,
    quant_type=aiter.QuantType.No,
    activation=ActivationType.Gelu,
    splitk=1,
):
    """Run ``aiter.ck_moe_stage1_fwd`` and return ``out``.

    With ``splitk <= 1`` the kernel writes straight into ``out``.

    With ``splitk > 1`` the kernel instead writes into a zero-initialised
    fp32 scratch tensor of shape ``(token_num, topk, w1.shape[1])``; the
    scratch is then cast to ``out.dtype`` and the gated activation is applied
    into ``out`` (``silu_and_mul`` for ``ActivationType.Silu``,
    ``gelu_and_mul`` for anything else).
    NOTE(review): presumably the kernel skips its fused activation when
    split-K is active so partial sums can be accumulated first — confirm
    against the kernel implementation.
    """
    token_num = hidden_states.shape[0]
    split_enabled = splitk > 1

    if split_enabled:
        # Zero-filled fp32 scratch buffer on the same device as the output.
        stage1_out = torch.zeros(
            (token_num, topk, w1.shape[1]), dtype=dtypes.fp32, device=out.device
        )
    else:
        stage1_out = out

    aiter.ck_moe_stage1_fwd(
        hidden_states,
        w1,
        w2,
        sorted_token_ids,
        sorted_expert_ids,
        num_valid_ids,
        stage1_out,
        topk,
        kernelName,
        w1_scale,
        a1_scale,
        block_m,
        sorted_weights,
        quant_type,
        activation,
        splitk,
        out.dtype,
    )

    if not split_enabled:
        return out

    # Apply the gated activation on the scratch result, writing into ``out``.
    if activation == ActivationType.Silu:
        act_and_mul = aiter.silu_and_mul
    else:
        act_and_mul = aiter.gelu_and_mul
    act_and_mul(out, stage1_out.view(dtypes.fp32).to(out.dtype))
    return out


def cktile_moe_stage1(
hidden_states,
w1,
Expand All @@ -1319,7 +1469,7 @@ def cktile_moe_stage1(
if w1.dtype is torch.uint32:
D = D * 8
out = torch.empty(
(token_num, topk, D), dtype=hidden_states.dtype, device=hidden_states.device
(token_num, topk, D), dtype=dtypes.bf16, device=hidden_states.device
Comment thread
coderfeli marked this conversation as resolved.
Outdated
)
# print("Run cktile_moe_stage1: M=%d, N(N*2)=%d, K=%d, topk=%d, expert=%d"%(token_num, w1.shape[1], hidden_states.shape[1], topk, w1.shape[0]))
aiter.moe_cktile2stages_gemm1(
Expand Down
Loading
Loading