Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/aiter-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:

define-runners:
runs-on: ubuntu-latest
needs: [check-signal]
# needs: [check-signal]
outputs:
standard_runners: ${{ steps.machines.outputs.standard_runners }}
multigpu_runners: ${{ steps.machines.outputs.multigpu_runners }}
Expand Down
2 changes: 1 addition & 1 deletion 3rdparty/composable_kernel
Submodule composable_kernel updated 72 files
+5 −0 .gitignore
+3 −0 CMakeLists.txt
+7 −3 Jenkinsfile
+14 −1 example/26_contraction/run_contraction_bilinear_example.inc
+14 −1 example/26_contraction/run_contraction_scale_example.inc
+1 −0 example/65_gemm_multiply_multiply/CMakeLists.txt
+2 −2 example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp
+535 −0 example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp
+2 −2 example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp
+16 −42 example/ck_tile/03_gemm/gemm_utils.hpp
+24 −19 example/ck_tile/03_gemm/universal_gemm_invoker.hpp
+10 −39 example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
+0 −18 example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp
+3 −36 example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp
+9 −39 example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
+31 −48 experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp
+7 −0 experimental/builder/include/ck_tile/builder/testing/type_traits.hpp
+7 −2 include/ck/host_utility/device_prop.hpp
+52 −0 include/ck/tensor_operation/gpu/device/device_base.hpp
+38 −15 include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
+38 −1 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
+27 −17 include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp
+74 −50 include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp
+24 −0 include/ck_tile/core/config.hpp
+37 −25 include/ck_tile/host/tensor_shuffle_utils.hpp
+33 −5 include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
+1 −1 include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
+1 −1 include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
+2 −2 include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
+22 −0 include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
+6 −3 include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp
+1 −1 include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp
+1 −1 include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp
+33 −14 include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp
+1 −1 include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp
+1 −1 include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp
+1 −1 include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
+232 −0 library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale_splitk.hpp
+7 −1 test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+111 −15 test/ck_tile/gemm_block_scale/CMakeLists.txt
+0 −95 test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp
+42 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_ccr.cpp
+42 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rcr.cpp
+46 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rrr_crr.cpp
+41 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp
+48 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_preshuffle.cpp
+40 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_transpose_c.cpp
+0 −99 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp
+41 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp
+41 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp
+41 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_large_n.cpp
+48 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_medium_n.cpp
+49 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_small_n.cpp
+0 −93 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp
+39 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_1d.cpp
+51 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp
+41 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_1d.cpp
+58 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp
+40 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_tiled_permute.cpp
+53 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp
+11 −62 test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp
+9 −9 test/ck_tile/grouped_gemm_quant/CMakeLists.txt
+246 −0 test/common/csv_test_loader.hpp
+4 −0 test/grouped_convnd_bwd_data/CMakeLists.txt
+317 −0 test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_dataset_xdl.cpp
+4 −0 test/grouped_convnd_bwd_weight/CMakeLists.txt
+258 −0 test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_dataset_xdl.cpp
+125 −290 test/grouped_convnd_fwd/test_grouped_convnd_fwd_dataset_xdl.cpp
+178 −20 test_data/generate_test_dataset.sh
+1,187 −0 test_data/gtest_parallel.py
+0 −39 tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp
+15 −6 tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp
122 changes: 117 additions & 5 deletions aiter/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def fused_moe(
intermediate_pad=0,
bias1=None,
bias2=None,
splitk=0,
):
if not block_size_M:
block_size_M = -1
Expand Down Expand Up @@ -236,6 +237,7 @@ def fused_moe_(
intermediate_pad,
bias1,
bias2,
get_padded_M(M), # only used in 2stage
)

block_size_M = metadata.block_m if block_size_M is None else block_size_M
Expand Down Expand Up @@ -472,6 +474,33 @@ def get_block_size_M(token, topk, expert, inter_dim):
return sorted(tmp, key=lambda x: x[:2])[0][-1]


@functools.lru_cache(maxsize=2048)
def get_ksplit(token, topk, expert, inter_dim, model_dim):
aiter_ksplit = int(os.environ.get("AITER_KSPLIT", "0"))
if aiter_ksplit != 0:
return aiter_ksplit
# only for moe_blk gemm1 a8w8 decode scenario
if token * topk > expert:
return 0
cu_num = get_cu_num()
tileN = 128

tgM = token * topk # decode tile num
tgN = (inter_dim * 2 + tileN - 1) // tileN

tg_num = tgN * tgM
# if all cu already active
if tg_num >= cu_num:
return 0
tilek = 256
split_max = (cu_num + tg_num - 1) // tg_num
# at least split = 2
for i in reversed(range(2, split_max + 1)):
if (model_dim % i == 0) and ((model_dim // i) % tilek == 0):
return i
return 0


cfg_2stages = None
# fmt: off
fused_moe_1stage_dict = {
Expand Down Expand Up @@ -551,6 +580,7 @@ def get_2stage_cfgs(
intermediate_pad,
bias1,
bias2,
token_real,
):
def get_cfg_2stages(tune_file):
import pandas as pd
Expand Down Expand Up @@ -620,8 +650,22 @@ def FinalFunc():
)
logger.info("\033[0m")

def use_cfg():
problem_type = (activation, dtype, q_dtype_a, q_dtype_w, q_type)
bypass_type = (
ActivationType.Silu,
dtypes.bf16,
dtypes.fp8,
dtypes.fp8,
QuantType.per_1x128,
)
if problem_type == bypass_type and (token_real * topk) <= 128: # bypass tuned
aiter.logger.info("bypass tuned results for fp8 blockscale")
return False
return True

# cfg = cfg_2stages.get(keys, None)
cfg = cfg_2stages.get(keys, None) if cfg_2stages else None
cfg = cfg_2stages.get(keys, None) if cfg_2stages and use_cfg() else None
if cfg is None and os.environ.get("AITER_ONLINE_TUNE", "0") == "1":
lock_path = os.path.join(bd_dir, f"lock_fmoe_tune_{keys}")
mp_lock(lock_path, MainFunc=MainFunc, FinalFunc=FinalFunc)
Expand All @@ -630,7 +674,7 @@ def FinalFunc():
cfg = cfg_2stages.get(keys, None) if cfg_2stages else None
if cfg is None:
logger.warning(f"Fmoe tuning not support for {keys}")
if cfg is None:
if cfg is None or int(os.environ.get("AITER_HEURISTIC_ONLY", "0")):
ksplit = 0
kernelName1 = ""
kernelName2 = ""
Expand All @@ -645,7 +689,7 @@ def FinalFunc():
doweight_stage1,
) in fused_moe_1stage_dict[get_gfx()]:
if q_type == QuantType.per_1x128:
run_1stage = True and (inter_dim % 256 == 0)
run_1stage = token > 32 and (inter_dim % 256 == 0)
elif q_type == QuantType.per_Token and q_dtype_w == dtypes.i8:
run_1stage = token > 32
elif q_type == QuantType.per_Token and q_dtype_w == dtypes.fp8:
Expand All @@ -657,11 +701,23 @@ def FinalFunc():
BLOCK_SIZE_M
if run_1stage
else (
64
(64 if token > 32 else 16)
if q_type == QuantType.per_1x128
else get_block_size_M(token, topk, expert, inter_dim)
)
)
ksplit = (
ksplit
if (run_1stage)
else (
get_ksplit(token_real, topk, expert, inter_dim, model_dim)
if q_type == QuantType.per_1x128
else ksplit
)
)
aiter.logger.info(
f"run_1stage = {run_1stage}, ksplit = {ksplit} q_type = {q_type}"
)
else:
block_m = cfg["block_m"]
ksplit = cfg["ksplit"]
Expand Down Expand Up @@ -717,14 +773,16 @@ def FinalFunc():
dtypes.fp16,
torch.uint32,
dtypes.fp4x2,
dtypes.fp8,
]
):
return MOEMetadata(
functools.partial(
aiter.ck_moe_stage1_fwd,
ck_moe_stage1,
kernelName=kernelName1,
activation=activation,
quant_type=q_type,
splitk=ksplit,
),
functools.partial(
aiter.ck_moe_stage2_fwd,
Expand Down Expand Up @@ -814,6 +872,7 @@ def fused_moe_2stages(
intermediate_pad,
bias1,
bias2,
token_num,
)
if (
quant_type == QuantType.per_1x32
Expand Down Expand Up @@ -1293,6 +1352,59 @@ def torch_moe_stage2(
return out.sum(1).to(dtype)


def ck_moe_stage1(
hidden_states,
w1, # [E, inter_dim*2, model_dim]
w2, # [E, model_dim, inter_dim]
sorted_token_ids, # [max_num_tokens_padded]
sorted_expert_ids, # [max_num_m_blocks]
num_valid_ids, # [1]
out,
topk,
block_m,
a1_scale,
w1_scale,
kernelName="",
sorted_weights=None,
quant_type=aiter.QuantType.No,
activation=ActivationType.Gelu,
splitk=1,
):
token_num = hidden_states.shape[0]
tmp_out = (
torch.zeros(
(token_num, topk, w1.shape[1]), dtype=dtypes.fp32, device=out.device
)
if splitk > 1
else out
)
aiter.ck_moe_stage1_fwd(
hidden_states,
w1,
w2,
sorted_token_ids,
sorted_expert_ids,
num_valid_ids,
tmp_out,
topk,
kernelName,
w1_scale,
a1_scale,
block_m,
sorted_weights,
quant_type,
activation,
splitk,
out.dtype,
)
if splitk > 1:
if activation == ActivationType.Silu:
aiter.silu_and_mul(out, tmp_out.view(dtypes.fp32).to(out.dtype))
else:
aiter.gelu_and_mul(out, tmp_out.view(dtypes.fp32).to(out.dtype))
return out


def cktile_moe_stage1(
hidden_states,
w1,
Expand Down
23 changes: 21 additions & 2 deletions aiter/ops/moe_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,17 +223,22 @@ def cmdGenFunc_ck_moe_stage(
sorted_weights: Optional[Tensor] = None,
quant_type: int = 0,
activation: int = 0,
splitk: int = 1,
dst_type: Optional[str] = None,
):

mul_routed_weight_stage = 2 if sorted_weights is None else 1
is_splitk = splitk > 1
outtype = str2dtype_dict[dst_type] if is_splitk else out.dtype
md_name, blob_gen_cmd = get_moe_stage_module(
hidden_states.dtype,
w1.dtype,
out.dtype,
outtype,
activation,
quant_type,
mul_routed_weight_stage,
getattr(w1, "is_shuffled", False),
is_splitk,
)
return {
"md_name": md_name,
Expand Down Expand Up @@ -292,6 +297,8 @@ def ck_moe_stage1(
sorted_weights: Optional[Tensor] = None,
quant_type: int = 0,
activation: int = 0,
splitk: int = 1,
dst_type: Optional[str] = None,
) -> None: ...


Expand Down Expand Up @@ -431,6 +438,11 @@ def moe_cktile2stages_gemm2(
torch.int4: "i4",
}

str2dtype_dict = {
"f16": dtypes.fp16,
"b16": dtypes.bf16,
}


@functools.lru_cache(maxsize=1024)
def get_moe_stage_module(
Expand All @@ -441,6 +453,7 @@ def get_moe_stage_module(
quant_type,
mul_routed_weight_stage,
preshuffle_mode=False,
is_splitk=False,
):
if isinstance(activation, int):
activation = ActivationType(activation)
Expand All @@ -455,6 +468,7 @@ def get_moe_stage_module(
if preshuffle_mode and weight_dtype == dtypes.fp4x2:
preshuffle_str = "--preshuffle"

splitk_str = "--issplitk" if is_splitk else ""
quant_type = (
QuantType.per_1x128 if quant_type == QuantType.per_128x128 else quant_type
)
Expand All @@ -471,10 +485,11 @@ def get_moe_stage_module(
act,
quant_type,
f"mulWeightStage{mul_routed_weight_stage}",
"splitk" if is_splitk else "",
]
)
blob_gen_cmd = [
f"{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py -a {Adtype} -b {Bdtype} -c {Cdtype} -q {quant_type} -act {act} -m {mul_routed_weight_stage} {preshuffle_str} -w {{}}"
f"{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py -a {Adtype} -b {Bdtype} -c {Cdtype} -q {quant_type} -act {act} -m {mul_routed_weight_stage} {preshuffle_str} {splitk_str} -w {{}}"
]

return md_name, blob_gen_cmd
Expand All @@ -496,6 +511,8 @@ def ck_moe_stage1_fwd(
sorted_weights: Optional[Tensor] = None,
quant_type: QuantType = QuantType.No,
activation: ActivationType = ActivationType.Silu,
splitk: Optional[int] = 1,
dst_type: Optional[torch.dtype] = None,
):
ck_moe_stage1(
hidden_states,
Expand All @@ -513,6 +530,8 @@ def ck_moe_stage1_fwd(
sorted_weights,
quant_type.value,
activation.value,
splitk,
dtype2str_dict[dst_type],
)
return out

Expand Down
24 changes: 18 additions & 6 deletions csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,23 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token
std::optional<int> block_m = 32,
std::optional<torch::Tensor> sorted_weights = std::nullopt,
int quant_type = 0,
int activation = 0)
int activation = 0,
int splitk = 1,
std::optional<std::string> dst_type = std::nullopt)
{
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(out));
at::hip::getCurrentHIPStream();

TORCH_CHECK(out.dtype() == at::ScalarType::BFloat16 || out.dtype() == at::ScalarType::Half,
"Out dtype only support BFloat16/Float16!")
if (splitk > 1)
{
TORCH_CHECK(out.dtype() == at::ScalarType::Float,
"Out dtype only support Float when splitk > 1!")
}
else
{
TORCH_CHECK(out.dtype() == at::ScalarType::BFloat16 || out.dtype() == at::ScalarType::Half,
"Out dtype only support BFloat16/Float16!")
}

int tokens = hidden_states.size(0);
int sorted_size = std::min(int64_t(tokens * topk * block_m.value()), sorted_token_ids.size(0));
Expand Down Expand Up @@ -99,7 +109,7 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token

kernel(at::hip::getCurrentHIPStream(),
tokens, sorted_size, N, K, topk,
hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr);
hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr, splitk);
}

void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token
Expand All @@ -116,7 +126,9 @@ void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token
std::optional<int> block_m = 32,
std::optional<torch::Tensor> sorted_weights = std::nullopt,
int quant_type = 0,
int activation = 0)
int activation = 0,
int splitk = 1,
std::optional<std::string> dst_type = std::nullopt)
{
TORCH_CHECK(out.dtype() == at::ScalarType::BFloat16 || out.dtype() == at::ScalarType::Half,
"Out dtype only support BFloat16/Float16!")
Expand Down Expand Up @@ -155,5 +167,5 @@ void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token

kernel(at::hip::getCurrentHIPStream(),
tokens, sorted_size, N, K, topk,
inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr);
inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr, splitk);
}
Loading
Loading