Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion .buildkite/test-amd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3004,9 +3004,43 @@ steps:
- vllm/_aiter_ops.py
- vllm/platforms/rocm.py
commands:
- pytest -v -s kernels/moe --ignore=kernels/moe/test_modular_oai_triton_moe.py --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
- pytest -v -s kernels/moe
--ignore=kernels/moe/test_modular_oai_triton_moe.py
--ignore=kernels/moe/test_gpt_oss_triton_kernels.py
--ignore=kernels/moe/test_moe.py
--ignore=kernels/moe/test_block_int8.py
--ignore=kernels/moe/test_triton_moe_no_act_mul.py
--ignore=kernels/moe/test_triton_moe_ptpc_fp8.py
--ignore=kernels/moe/test_deepep_moe.py
--ignore=kernels/moe/test_moe_layer.py
--shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
- pytest -v -s kernels/moe/test_modular_oai_triton_moe.py --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT

- label: Kernels FusedMoE Layer Test (2xB200-2xMI355) # TBD
timeout_in_minutes: 180
mirror_hardwares: [amdexperimental, amdproduction, amdgfx950nightly, amdmi355]
agent_pool: mi355_2
num_gpus: 2
optional: true
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- csrc/moe/
- csrc/rocm/
- tests/kernels/moe
- vllm/model_executor/layers/fused_moe/
- vllm/model_executor/layers/quantization/
- vllm/distributed/
- vllm/config/
- vllm/forward_context.py
- vllm/v1/worker/workspace.py
- vllm/utils/import_utils.py
- vllm/utils/math_utils.py
- vllm/utils/torch_utils.py
- vllm/platforms/
- vllm/_aiter_ops.py
commands:
- pytest -v -s kernels/moe/test_moe_layer.py

- label: Kernels Quantization Test %N # TBD
timeout_in_minutes: 180
mirror_hardwares: [amdexperimental, amdproduction, amdgfx950nightly, amdmi355]
Expand All @@ -3023,11 +3057,33 @@ steps:
- vllm/model_executor/kernels/
commands:
- pytest -v -s kernels/quantization --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT

- label: Kernels FP8 MoE Test (2xH100-1xMI355) # TBD
timeout_in_minutes: 180
mirror_hardwares: [amdexperimental, amdproduction, amdgfx950nightly, amdmi355]
agent_pool: mi355_1
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- csrc/moe/
- vllm/model_executor/layers/fused_moe/
- tests/kernels/moe/test_deepep_moe.py
- vllm/_aiter_ops.py
- vllm/platforms/rocm.py
- vllm/envs.py
commands:
- pytest -v -s kernels/moe/test_gpt_oss_triton_kernels.py
- pytest -v -s kernels/moe/test_modular_oai_triton_moe.py
- pytest -v -s kernels/moe/test_moe.py
- pytest -v -s kernels/moe/test_block_int8.py
- pytest -v -s kernels/moe/test_triton_moe_no_act_mul.py
- pytest -v -s kernels/moe/test_triton_moe_ptpc_fp8.py

- label: Kernels FP8 MoE Test (2xH100-2xMI355) # TBD
timeout_in_minutes: 180
mirror_hardwares: [amdexperimental, amdproduction, amdgfx950nightly, amdmi355]
agent_pool: mi355_2
num_gpus: 2
optional: true
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- csrc/moe/
Expand Down
28 changes: 26 additions & 2 deletions tests/kernels/moe/test_modular_oai_triton_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pytest
import torch
import torch.nn.functional as F

from tests.utils import wait_for_gpu_memory_to_clear
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
Expand Down Expand Up @@ -35,6 +36,7 @@
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import set_random_seed

from .utils import make_dummy_moe_config, shuffle_weight
Expand Down Expand Up @@ -73,13 +75,30 @@ def make_weights(dtype, k, n, e):
w1_tri = shuffle_weight(w1_tri)
w1_bias_tri = shuffle_weight(w1_bias_tri)

if current_platform.is_rocm():
k_align, n2_align = 256, 512
else:
k_align, n2_align = 64, 128

w1_bottom_pad = round_up(w1_tri.shape[1], k_align) - w1_tri.shape[1]
w1_right_pad = round_up(w1_tri.shape[2], n2_align) - w1_tri.shape[2]
w2_bottom_pad = w1_right_pad // 2
w2_right_pad = w1_bottom_pad

w1_tri = F.pad(w1_tri, (0, w1_right_pad, 0, w1_bottom_pad, 0, 0))
w2_tri = F.pad(w2_tri, (0, w2_right_pad, 0, w2_bottom_pad, 0, 0))
w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0))
w2_bias_tri = F.pad(w2_bias_tri, (0, w2_right_pad, 0, 0))

# quant triton_weights
w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, dtype, axis=1)
w1 = w1[..., :k, : 2 * n]
w1 = unshuffle_weight(w1)

w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, dtype, axis=1)
w2 = w2[..., :n, :k]

num_warps = 8
w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
Expand Down Expand Up @@ -119,6 +138,7 @@ def make_weights(dtype, k, n, e):
w2_bias_tri,
w1_precision_config,
w2_precision_config,
w1_bottom_pad,
)


Expand Down Expand Up @@ -207,7 +227,7 @@ def oai_triton_moe_impl(


@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
not current_platform.is_cuda_alike(), reason="Requires CUDA-alike platform."
)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("m,n,k", MNK)
Expand All @@ -226,6 +246,7 @@ def test_oai_triton_moe(
):
wait_for_gpu_memory_to_clear(devices=[0], threshold_ratio=0.1)
set_random_seed(0)

(
w1,
w2,
Expand All @@ -237,9 +258,11 @@ def test_oai_triton_moe(
w2_bias_tri,
w1_precision_config,
w2_precision_config,
x_pad,
) = make_weights(dtype, k, n, num_experts)

x = torch.randn((m, k), dtype=dtype, device="cuda")
x_tri = F.pad(x, (0, x_pad, 0, 0))
router_logits = torch.randn(m, num_experts, device="cuda", dtype=dtype)
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1, sorted=True)
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
Expand All @@ -248,7 +271,7 @@ def test_oai_triton_moe(
out_ref = torch_moe_impl(x, w1, w2, w1_bias, w2_bias, topk_weights, topk_ids)

out = oai_triton_moe_impl(
x,
x_tri,
w1_tri,
w2_tri,
w1_precision_config,
Expand All @@ -260,5 +283,6 @@ def test_oai_triton_moe(
topk_ids,
unfused,
)
out = out[..., :k]

assert_close(ref=out_ref, tri=out, maxtol=0.025, rmstol=0.005)
Loading
Loading