Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 166 additions & 10 deletions aiter/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,12 @@ def fused_moe(
num_local_tokens: Optional[torch.tensor] = None,
moe_sorting_dispatch_policy=0,
dtype=None,
):
# following for cktile support
hidden_pad=0,
intermediate_pad=0,
bias1=None,
bias2=None,
):
if not block_size_M:
block_size_M = -1
return fused_moe_(
Expand All @@ -128,6 +133,10 @@ def fused_moe(
num_local_tokens=num_local_tokens,
moe_sorting_dispatch_policy=moe_sorting_dispatch_policy,
dtype=dtype,
hidden_pad=hidden_pad,
intermediate_pad=intermediate_pad,
bias1=bias1,
bias2=bias2,
)


Expand Down Expand Up @@ -181,6 +190,10 @@ def fused_moe_(
num_local_tokens: Optional[torch.Tensor] = None,
moe_sorting_dispatch_policy: bool = 0,
dtype: Optional[torch.dtype] = None,
hidden_pad: int = 0,
intermediate_pad: int = 0,
bias1: Optional[torch.Tensor] = None,
bias2: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# We do such convert since custom_op schema restriction on block_size_M, and Enum type
activation = ActivationType(activation)
Expand Down Expand Up @@ -223,6 +236,10 @@ def fused_moe_(
isG1U1,
activation,
doweight_stage1,
hidden_pad,
intermediate_pad,
bias1,
bias2,
)

block_size_M = metadata.block_m if block_size_M is None else block_size_M
Expand Down Expand Up @@ -255,6 +272,8 @@ def fused_moe_(
moe_buf,
isG1U1,
block_size_M,
# activation=activation,
# quant_type=quant_type,
q_dtype_a=q_dtype_a,
q_dtype_w=q_dtype_w,
w1_scale=w1_scale,
Expand Down Expand Up @@ -286,6 +305,11 @@ def fused_moe_(
a1_scale=a1_scale,
a2_scale=a2_scale,
num_local_tokens=num_local_tokens,
# following for cktile support
hidden_pad=hidden_pad,
intermediate_pad=intermediate_pad,
bias1=bias1,
bias2=bias2,
)


Expand Down Expand Up @@ -494,14 +518,17 @@ def get_2stage_cfgs(
use_g1u1,
activation,
doweight_stage1,
hidden_pad,
intermediate_pad,
bias1,
bias2,
):
def get_cfg_2stages(tune_file):
import pandas as pd

cfg_2stages = pd.read_csv(tune_file)
cfg_2stages = cfg_2stages.set_index(
[
"cu_num",
"token",
"model_dim",
"inter_dim",
Expand Down Expand Up @@ -548,7 +575,6 @@ def MainFunc():
f.write(
"token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1"
)

q_dtype_ws = q_dtype_w if q_dtype_w != torch.uint32 else "torch.int4"
f.write(
f"\n{token},{model_dim},{inter_dim},{expert},{topk},{activation},{dtype},{q_dtype_a},{q_dtype_ws},{q_type},{int(use_g1u1)},{int(doweight_stage1)}"
Expand Down Expand Up @@ -627,6 +653,24 @@ def FinalFunc():
ksplit,
run_1stage,
)
if dtype in [dtypes.bf16, dtypes.fp16] and q_type == QuantType.per_1x32 and activation == ActivationType.Swiglu:
return MOEMetadata(
functools.partial(
cktile_moe_stage1,
n_pad_zeros=intermediate_pad // 64 * 64 * (2 if use_g1u1 else 1),
k_pad_zeros=hidden_pad // 128 * 128,
bias1=bias1,
),
functools.partial(
cktile_moe_stage2,
n_pad_zeros=hidden_pad // 64 * 64,
k_pad_zeros=intermediate_pad // 128 * 128,
bias2=bias2,
),
16 if token < 2048 else 32,
ksplit,
False,
)
if (
"ck2stages" in kernelName1
or (q_type == QuantType.per_1x128 and doweight_stage1)
Expand Down Expand Up @@ -704,6 +748,11 @@ def fused_moe_2stages(
a1_scale=None, # [expert(local_expert:EP), 1, model_dim]
a2_scale=None, # [expert(local_expert:EP), 1, inter_dim]
num_local_tokens: Optional[torch.tensor] = None,
# following for cktile support
hidden_pad=0,
intermediate_pad=0,
bias1=None,
bias2=None,
):
quant_func = get_quant(quant_type)

Expand All @@ -725,9 +774,18 @@ def fused_moe_2stages(
isG1U1,
activation,
doweight_stage1,
hidden_pad,
intermediate_pad,
bias1,
bias2,
)

if quant_type == QuantType.per_1x32:
if quant_type == QuantType.per_1x32 \
and dtype in [dtypes.bf16, dtypes.fp16] \
and w1.dtype == dtypes.fp4x2 \
and activation == ActivationType.Swiglu:
a1 = hidden_states.to(dtype)
a1_scale = None
elif quant_type == QuantType.per_1x32:
a1, a1_scale = quant_func(
hidden_states,
scale=a1_scale,
Expand Down Expand Up @@ -768,7 +826,7 @@ def fused_moe_2stages(
dtype=dtype,
device=device,
)

a2 = metadata.stage1(
a1,
w1,
Expand All @@ -784,7 +842,11 @@ def fused_moe_2stages(
sorted_weights=sorted_weights if doweight_stage1 else None,
)

if quant_type == QuantType.per_1x32:
if quant_type == QuantType.per_1x32 \
and dtype in [dtypes.bf16, dtypes.fp16] \
and w1.dtype == dtypes.fp4x2:
a2_scale = None
elif quant_type == QuantType.per_1x32:
a2 = a2.view(-1, inter_dim)
a2, a2_scale = quant_func(
a2,
Expand Down Expand Up @@ -975,7 +1037,7 @@ def torch_moe(
return (out * topk_weight.view(B, -1, 1)).sum(dim=1).to(dtype)


# temp workaround for swiglu
#temp workaround for swiglu
def swiglu(x_glu, x_linear, alpha: float = 1.702, limit: float = 7.0):
# Clamp the input values
x_glu = x_glu.clamp(min=None, max=limit)
Expand Down Expand Up @@ -1103,7 +1165,6 @@ def torch_moe_stage2(
w2_bias=None,
doweight=True,
):
quant_type = quant_remap.get(quant_type, quant_type)
ctype = dtypes.fp32 # compute type
E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape)
if quant_type == QuantType.per_1x32:
Expand Down Expand Up @@ -1172,6 +1233,101 @@ def torch_moe_stage2(
return out.sum(1).to(dtype)


def cktile_moe_stage1(
    hidden_states,
    w1,
    w2,
    sorted_token_ids,
    sorted_expert_ids,
    num_valid_ids,
    out,
    topk,
    block_m,
    a1_scale,
    w1_scale,
    sorted_weights=None,
    n_pad_zeros=0,
    k_pad_zeros=0,
    bias1=None,
):
    """Run stage-1 of the cktile two-stage fused-MoE pipeline.

    Dispatches to the ``aiter.moe_cktile2stages_gemm1`` kernel and returns
    the per-token, per-selected-expert intermediate activations.

    Args:
        hidden_states: input activations; only ``shape[0]`` (token count),
            dtype and device are read here.
        w1: stage-1 expert weights, 3-D.
        w2: stage-2 expert weights, 3-D; only its shape is inspected to
            infer the stage-1 output width.
        sorted_token_ids, sorted_expert_ids, num_valid_ids: MoE-sorting
            metadata, passed through to the kernel.
        out: ignored — a fresh output buffer is always allocated so the
            returned width matches the (possibly packed) weight layout.
        topk: number of experts selected per token.
        block_m: kernel tile size along the token dimension.
        a1_scale, w1_scale: optional activation / weight quant scales.
        sorted_weights: optional routing weights applied in-kernel.
        n_pad_zeros, k_pad_zeros: zero-padding sizes along N / K for the
            cktile kernel.
        bias1: optional stage-1 bias tensor.

    Returns:
        Tensor of shape (token_num, topk, D) in ``hidden_states.dtype``.
    """
    token_num = hidden_states.shape[0]
    _, _, k1 = w1.shape
    _, k2, n2 = w2.shape
    # When w2's reduction dim does not equal w1's reduction dim, the
    # weights are in packed bit4 format, so the logical width doubles.
    D = n2 if k2 == k1 else n2 * 2
    if w1.dtype is torch.uint32:
        # presumably int4 packed 8-per-uint32 word (matches the *8 factor)
        # — TODO confirm against the kernel's weight layout.
        D = D * 8
    out = torch.empty(
        (token_num, topk, D),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    aiter.moe_cktile2stages_gemm1(
        hidden_states,
        w1,
        out,
        sorted_token_ids,
        sorted_expert_ids,
        num_valid_ids,
        topk,
        n_pad_zeros,
        k_pad_zeros,
        sorted_weights,
        a1_scale,
        w1_scale,
        bias1,
        block_m,
    )
    return out


def cktile_moe_stage2(
    a2,
    w1,
    w2,
    sorted_token_ids,
    sorted_expert_ids,
    num_valid_ids,
    out,
    topk,
    w2_scale,
    a2_scale,
    block_m,
    sorted_weights=None,
    zeros_out=False,
    n_pad_zeros=0,
    k_pad_zeros=0,
    bias2=None,
):
    """Run stage-2 of the cktile two-stage fused-MoE pipeline.

    Dispatches to the ``aiter.moe_cktile2stages_gemm2`` kernel, which
    writes the final combined expert outputs into the caller-provided
    ``out`` buffer.

    Args:
        a2: stage-1 intermediate activations.
        w1: stage-1 expert weights; unused here, kept so the signature
            mirrors ``cktile_moe_stage1``.
        w2: stage-2 expert weights.
        sorted_token_ids, sorted_expert_ids, num_valid_ids: MoE-sorting
            metadata, passed through to the kernel.
        out: pre-allocated output buffer; written in place and returned.
        topk: number of experts selected per token.
        w2_scale, a2_scale: optional weight / activation quant scales.
        block_m: kernel tile size along the token dimension.
        sorted_weights: optional routing weights applied in-kernel.
        zeros_out: accepted but currently ignored — the zero-fill path it
            once controlled is not implemented here; the kernel is assumed
            to fully overwrite ``out``. TODO confirm with kernel authors.
        n_pad_zeros, k_pad_zeros: zero-padding sizes along N / K for the
            cktile kernel.
        bias2: optional stage-2 bias tensor.

    Returns:
        ``out``, after the kernel has written the stage-2 results.
    """
    aiter.moe_cktile2stages_gemm2(
        a2,
        w2,
        out,
        sorted_token_ids,
        sorted_expert_ids,
        num_valid_ids,
        topk,
        n_pad_zeros,
        k_pad_zeros,
        sorted_weights,
        a2_scale,
        w2_scale,
        bias2,
        block_m,
    )
    return out


def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
Expand Down Expand Up @@ -1233,4 +1389,4 @@ def fused_topk(
# if renormalize:
# topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

return topk_weights, topk_ids
return topk_weights, topk_ids
3 changes: 2 additions & 1 deletion csrc/include/aiter_enum.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ enum class ActivationType : int
{
No = -1,
Silu = 0,
Gelu
Gelu = 1,
Swiglu = 2,
};
enum class QuantType : int
{
Expand Down
1 change: 1 addition & 0 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,7 @@ namespace py = pybind11;
.value("No", ActivationType::No) \
.value("Silu", ActivationType::Silu) \
.value("Gelu", ActivationType::Gelu) \
.value("Swiglu", ActivationType::Swiglu) \
.export_values(); \
pybind11::implicitly_convertible<int, QuantType>(); \
pybind11::implicitly_convertible<int, ActivationType>();
Expand Down
Loading