diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml
index a8c9e4094387..3a9a5bd40b41 100644
--- a/.buildkite/test-amd.yaml
+++ b/.buildkite/test-amd.yaml
@@ -3004,9 +3004,43 @@ steps:
   - vllm/_aiter_ops.py
   - vllm/platforms/rocm.py
   commands:
-  - pytest -v -s kernels/moe --ignore=kernels/moe/test_modular_oai_triton_moe.py --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
+  - pytest -v -s kernels/moe
+    --ignore=kernels/moe/test_modular_oai_triton_moe.py
+    --ignore=kernels/moe/test_gpt_oss_triton_kernels.py
+    --ignore=kernels/moe/test_moe.py
+    --ignore=kernels/moe/test_block_int8.py
+    --ignore=kernels/moe/test_triton_moe_no_act_mul.py
+    --ignore=kernels/moe/test_triton_moe_ptpc_fp8.py
+    --ignore=kernels/moe/test_deepep_moe.py
+    --ignore=kernels/moe/test_moe_layer.py
+    --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
   - pytest -v -s kernels/moe/test_modular_oai_triton_moe.py --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT

+- label: Kernels FusedMoE Layer Test (2xB200-2xMI355) # TBD
+  timeout_in_minutes: 180
+  mirror_hardwares: [amdexperimental, amdproduction, amdgfx950nightly, amdmi355]
+  agent_pool: mi355_2
+  num_gpus: 2
+  optional: true
+  working_dir: "/vllm-workspace/tests"
+  source_file_dependencies:
+  - csrc/moe/
+  - csrc/rocm/
+  - tests/kernels/moe
+  - vllm/model_executor/layers/fused_moe/
+  - vllm/model_executor/layers/quantization/
+  - vllm/distributed/
+  - vllm/config/
+  - vllm/forward_context.py
+  - vllm/v1/worker/workspace.py
+  - vllm/utils/import_utils.py
+  - vllm/utils/math_utils.py
+  - vllm/utils/torch_utils.py
+  - vllm/platforms/
+  - vllm/_aiter_ops.py
+  commands:
+  - pytest -v -s kernels/moe/test_moe_layer.py
+
 - label: Kernels Quantization Test %N # TBD
   timeout_in_minutes: 180
   mirror_hardwares: [amdexperimental, amdproduction, amdgfx950nightly, amdmi355]
@@ -3023,11 +3057,33 @@ steps:
   - vllm/model_executor/kernels/
   commands:
   - pytest -v -s kernels/quantization --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
+
+- label: Kernels FP8 MoE Test (2xH100-1xMI355) # TBD
+  timeout_in_minutes: 180
+  mirror_hardwares: [amdexperimental, amdproduction, amdgfx950nightly, amdmi355]
+  agent_pool: mi355_1
+  working_dir: "/vllm-workspace/tests"
+  source_file_dependencies:
+  - csrc/moe/
+  - vllm/model_executor/layers/fused_moe/
+  - tests/kernels/moe/test_deepep_moe.py
+  - vllm/_aiter_ops.py
+  - vllm/platforms/rocm.py
+  - vllm/envs.py
+  commands:
+  - pytest -v -s kernels/moe/test_gpt_oss_triton_kernels.py
+  - pytest -v -s kernels/moe/test_modular_oai_triton_moe.py
+  - pytest -v -s kernels/moe/test_moe.py
+  - pytest -v -s kernels/moe/test_block_int8.py
+  - pytest -v -s kernels/moe/test_triton_moe_no_act_mul.py
+  - pytest -v -s kernels/moe/test_triton_moe_ptpc_fp8.py

 - label: Kernels FP8 MoE Test (2xH100-2xMI355) # TBD
   timeout_in_minutes: 180
   mirror_hardwares: [amdexperimental, amdproduction, amdgfx950nightly, amdmi355]
   agent_pool: mi355_2
+  num_gpus: 2
+  optional: true
   working_dir: "/vllm-workspace/tests"
   source_file_dependencies:
   - csrc/moe/
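Note on the sharded invocations above: `--shard-id`/`--num-shards` come from the pytest-shard plugin, fed by Buildkite's parallel-job variables (the `$$` defers expansion to agent runtime). A minimal sketch of deterministic index-based sharding — an illustrative assumption, the plugin's real distribution scheme may differ — showing the invariant the CI relies on: every selected test runs in exactly one parallel job.

```python
# Hypothetical stand-in for pytest-shard's test distribution.
def select_shard(items: list[str], shard_id: int, num_shards: int) -> list[str]:
    # Keep the items whose position maps to this shard.
    return [item for i, item in enumerate(items) if i % num_shards == shard_id]


tests = [f"kernels/moe/test_case_{i}" for i in range(10)]
shards = [select_shard(tests, s, 4) for s in range(4)]
# Disjoint and complete: concatenating all shards recovers the full suite.
assert sorted(sum(shards, [])) == sorted(tests)
```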
diff --git a/tests/kernels/moe/test_modular_oai_triton_moe.py b/tests/kernels/moe/test_modular_oai_triton_moe.py
index 589d90d1eca7..b5ed2a167f48 100644
--- a/tests/kernels/moe/test_modular_oai_triton_moe.py
+++ b/tests/kernels/moe/test_modular_oai_triton_moe.py
@@ -6,6 +6,7 @@

 import pytest
 import torch
+import torch.nn.functional as F

 from tests.utils import wait_for_gpu_memory_to_clear
 from vllm.model_executor.layers.fused_moe.activation import MoEActivation
@@ -35,6 +36,7 @@
 )
 from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
 from vllm.platforms import current_platform
+from vllm.utils.math_utils import round_up
 from vllm.utils.torch_utils import set_random_seed

 from .utils import make_dummy_moe_config, shuffle_weight
@@ -73,13 +75,30 @@ def make_weights(dtype, k, n, e):
     w1_tri = shuffle_weight(w1_tri)
     w1_bias_tri = shuffle_weight(w1_bias_tri)

+    if current_platform.is_rocm():
+        k_align, n2_align = 256, 512
+    else:
+        k_align, n2_align = 64, 128
+
+    w1_bottom_pad = round_up(w1_tri.shape[1], k_align) - w1_tri.shape[1]
+    w1_right_pad = round_up(w1_tri.shape[2], n2_align) - w1_tri.shape[2]
+    w2_bottom_pad = w1_right_pad // 2
+    w2_right_pad = w1_bottom_pad
+
+    w1_tri = F.pad(w1_tri, (0, w1_right_pad, 0, w1_bottom_pad, 0, 0))
+    w2_tri = F.pad(w2_tri, (0, w2_right_pad, 0, w2_bottom_pad, 0, 0))
+    w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0))
+    w2_bias_tri = F.pad(w2_bias_tri, (0, w2_right_pad, 0, 0))
+
     # quant triton_weights
     w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
     w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, dtype, axis=1)
+    w1 = w1[..., :k, : 2 * n]
     w1 = unshuffle_weight(w1)
     w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
     w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, dtype, axis=1)
+    w2 = w2[..., :n, :k]

     num_warps = 8
     w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
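For reference, the padding arithmetic added to `make_weights`, as one self-contained snippet (the `round_up` here is a local stand-in for `vllm.utils.math_utils.round_up`, and the alignments are the ROCm values from the hunk above). `F.pad` consumes the pad tuple from the last dimension backwards, so `(0, right, 0, bottom, 0, 0)` widens the last dim, deepens the second-to-last, and leaves the expert dim alone. The cross-wiring in the hunk (`w2_bottom_pad = w1_right_pad // 2`, `w2_right_pad = w1_bottom_pad`) keeps w2's `(N, K)` dims consistent with w1's padded `(K, 2N)`.

```python
import torch
import torch.nn.functional as F


def round_up(x: int, align: int) -> int:
    # Assumed equivalent to vllm.utils.math_utils.round_up.
    return (x + align - 1) // align * align


e, k, n2 = 4, 300, 700                   # w1_tri is (E, K, 2N)
k_align, n2_align = 256, 512             # ROCm alignment from the hunk above
w1_tri = torch.randn(e, k, n2)
bottom_pad = round_up(k, k_align) - k    # 512 - 300 = 212 rows of zeros
right_pad = round_up(n2, n2_align) - n2  # 1024 - 700 = 324 cols of zeros
w1_tri = F.pad(w1_tri, (0, right_pad, 0, bottom_pad, 0, 0))
assert w1_tri.shape == (4, 512, 1024)
```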
@@ -119,6 +138,7 @@ def make_weights(dtype, k, n, e):
         w2_bias_tri,
         w1_precision_config,
         w2_precision_config,
+        w1_bottom_pad,
     )


@@ -207,7 +227,7 @@ def oai_triton_moe_impl(


 @pytest.mark.skipif(
-    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
+    not current_platform.is_cuda_alike(), reason="Requires CUDA-alike platform."
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("m,n,k", MNK)
@@ -226,6 +246,7 @@ def test_oai_triton_moe(
 ):
     wait_for_gpu_memory_to_clear(devices=[0], threshold_ratio=0.1)
     set_random_seed(0)
+
     (
         w1,
         w2,
@@ -237,9 +258,11 @@ def test_oai_triton_moe(
         w2_bias_tri,
         w1_precision_config,
         w2_precision_config,
+        x_pad,
     ) = make_weights(dtype, k, n, num_experts)

     x = torch.randn((m, k), dtype=dtype, device="cuda")
+    x_tri = F.pad(x, (0, x_pad, 0, 0))
     router_logits = torch.randn(m, num_experts, device="cuda", dtype=dtype)
     topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1, sorted=True)
     topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
@@ -248,7 +271,7 @@ def test_oai_triton_moe(
     out_ref = torch_moe_impl(x, w1, w2, w1_bias, w2_bias, topk_weights, topk_ids)

     out = oai_triton_moe_impl(
-        x,
+        x_tri,
         w1_tri,
         w2_tri,
         w1_precision_config,
@@ -260,5 +283,6 @@ def test_oai_triton_moe(
         topk_ids,
         unfused,
     )
+    out = out[..., :k]

     assert_close(ref=out_ref, tri=out, maxtol=0.025, rmstol=0.005)
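Why slicing the padded result back with `out = out[..., :k]` is sound: zero rows and columns contribute nothing to a matmul, so the first `k` output features of the padded computation match the unpadded reference. A tiny sanity check of that identity (in the real test the padded weights additionally pass through mxfp4 quantization, hence the `maxtol`/`rmstol` comparison above rather than exact equality):

```python
import torch
import torch.nn.functional as F

m, k, k_pad = 8, 30, 34
x = torch.randn(m, k)
w = torch.randn(k, k)
x_tri = F.pad(x, (0, k_pad - k))                # zero-pad activations
w_tri = F.pad(w, (0, k_pad - k, 0, k_pad - k))  # zero-pad weight rows/cols
out = (x_tri @ w_tri)[..., :k]                  # mirrors `out = out[..., :k]`
torch.testing.assert_close(out, x @ w)          # padding is a no-op here
```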
diff --git a/tests/kernels/moe/test_moe_layer.py b/tests/kernels/moe/test_moe_layer.py
index 89e28d950f9d..ac2020a80ba5 100644
--- a/tests/kernels/moe/test_moe_layer.py
+++ b/tests/kernels/moe/test_moe_layer.py
@@ -7,9 +7,11 @@

 import functools
 import os
+import tempfile
 import traceback
 import types
 from collections.abc import Callable
+from contextlib import suppress
 from dataclasses import astuple, dataclass, fields
 from itertools import product
 from typing import get_args
@@ -409,12 +411,10 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]:
     if config.use_gate and not config.use_shared_experts:
         return False, "gate requires shared_experts (use_overlapped mode)"

-    # Skip modelopt_fp4 if not on B100+ (compute capability 10.0+)
-    if (
-        config.quantization == "modelopt_fp4"
-        and not current_platform.has_device_capability(100)
+    if config.quantization == "modelopt_fp4" and not (
+        current_platform.is_rocm() or current_platform.has_device_capability(90)
     ):
-        return False, "modelopt_fp4 not supported on H100+ GPUs"
+        return False, "modelopt_fp4 requires native NVFP4 or emulation"

     # Skip flashinfer_nvlink if not on H100+ (compute capability 10.0+)
     if (
@@ -433,6 +433,17 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]:
             f"{config.backend} does not support quantization={config.quantization}",
         )

+    if config.backend == "mori":
+        if os.environ.get("VLLM_TEST_ENABLE_MORI_MOE_LAYER") != "1":
+            return False, "mori MoE layer matrix is opt-in"
+
+        from vllm._aiter_ops import rocm_aiter_ops
+
+        if not rocm_aiter_ops.is_fused_moe_enabled():
+            return False, "mori requires AITER fused MoE"
+        if rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
+            return False, "mori does not support AITER shared expert fusion"
+
     if config.backend == "deepep_low_latency":
         from vllm.model_executor.layers.fused_moe.prepare_finalize.deepep_ll import (  # noqa: E501
             DeepEPLLPrepareAndFinalize,
@@ -1470,9 +1481,15 @@ def _run_one_config(
 @pytest.mark.parametrize("num_experts", NUM_EXPERTS)
 @pytest.mark.parametrize("top_k", TOP_KS)
 @pytest.mark.parametrize("quantization", QUANT_METHODS)
-@pytest.mark.parametrize("use_shared_experts", [False, True])
-@pytest.mark.parametrize("use_gate", [False, True])
-@pytest.mark.parametrize("use_routed_input_transform", [False, True])
+@pytest.mark.parametrize(
+    "use_routed_input_transform,use_gate,use_shared_experts",
+    [
+        (False, False, False),
+        (False, False, True),
+        (False, True, True),
+        (True, False, True),
+    ],
+)
 def test_moe_layer_no_parallel(
     m: int,
     n: int,
@@ -1559,14 +1576,16 @@ def _parallel_worker(
     cpu_group,
     test_configs: list[MoETestConfig],
     verbosity: int,
+    failure_report_path: str | None = None,
     **kwargs,
 ) -> None:
     set_random_seed(7)

+    is_logging_rank = pgi.rank == 0
     total = 0
     passed = 0
     failed = 0
-    fail_ids = []
+    failure_details = []

     dp_rank = vllm_config.parallel_config.data_parallel_rank
@@ -1581,9 +1600,11 @@ def _parallel_worker(

         tp_rank = pgi.rank % test_config.tp_size

-        if verbosity > 0:
+        if verbosity > 0 and is_logging_rank:
             print(f"subtest: {test_config.id()}", end="")

+        local_failed = False
+        local_error: str | None = None
         try:
             _run_one_config(
                 vllm_config,
@@ -1606,19 +1627,9 @@ def _parallel_worker(
                 use_gate=test_config.use_gate,
                 use_routed_input_transform=test_config.use_routed_input_transform,
             )
-            if verbosity > 0:
-                print(" PASSED")
-            else:
-                print(".", end="")
-            passed = passed + 1
-        except Exception as ex:
-            fail_ids.append(test_config.id())
-            failed = failed + 1
-            if verbosity > 0:
-                traceback.print_exc()
-                print(f"\n{str(ex)}\nFAILED")
-            else:
-                print("F", end="")
+        except Exception:
+            local_failed = True
+            local_error = traceback.format_exc()
         finally:
             # DeepEP managers are not reliably reusable across many subtests in
             # a single worker process. Tear them down after each DeepEP case so
@@ -1634,6 +1645,40 @@ def _parallel_worker(
         total = total + 1
         torch.distributed.barrier()

+        any_failed_tensor = torch.tensor(
+            [int(local_failed)], device=pgi.device, dtype=torch.int32
+        )
+        torch.distributed.all_reduce(
+            any_failed_tensor, op=torch.distributed.ReduceOp.MAX
+        )
+        any_failed = bool(any_failed_tensor.item())
+
+        if any_failed:
+            failed = failed + 1
+
+            gathered_errors = [None] * pgi.world_size
+            torch.distributed.all_gather_object(
+                gathered_errors, local_error, group=cpu_group
+            )
+            first_error = next(
+                (error for error in gathered_errors if error is not None),
+                "unknown distributed failure",
+            )
+            assert first_error is not None
+            failure_details.append(f"{test_config.id()}\n{first_error.rstrip()}")
+
+            if verbosity > 0 and is_logging_rank:
+                print(" FAILED")
+                print(first_error.rstrip())
+            elif is_logging_rank:
+                print("F", end="")
+        else:
+            passed = passed + 1
+            if verbosity > 0 and is_logging_rank:
+                print(" PASSED")
+            elif is_logging_rank:
+                print(".", end="")
+
     skipped = total - (passed + failed)

     fails = f"{failed} failed" if failed > 0 else ""
@@ -1643,17 +1688,24 @@ def _parallel_worker(
     passes = f"{sep}{passed} passed" if passed > 0 else ""

     report = (
-        f"============= {fails}{skips}{passes} of {total} total tests ============="
+        f"============= {fails}{skips}{passes} of {total} total subcases ============="
     )

-    sep = "\n" if verbosity == 0 else ""
-    print(f"{sep}{report}")
+    if is_logging_rank:
+        sep = "\n" if verbosity == 0 else ""
+        print(f"{sep}{report}")

     if failed > 0:
-        fail_ids_str = "\n".join(fail_ids)
-        raise RuntimeError(
-            f"\n============= Failed subtests =============\n{fail_ids_str}\n{report}"
+        failure_details_str = "\n\n".join(failure_details)
+        failure_report = (
+            f"\n============= Failed subcases =============\n"
+            f"{failure_details_str}\n{report}"
         )
+        if is_logging_rank and failure_report_path is not None:
+            with open(failure_report_path, "w", encoding="utf-8") as report_file:
+                report_file.write(failure_report)
+        if is_logging_rank:
+            raise RuntimeError(failure_report)


 # TODO: add cudagraphs/torch.compile tests
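The `_parallel_worker` changes replace per-rank pass/fail bookkeeping with a collective decision: every rank contributes a failure bit via `all_reduce(MAX)`, so all ranks agree whether a subtest failed even when only one rank raised, and the tracebacks are then gathered over the CPU group. The core pattern, as a sketch (assumes a process group is already initialized; `subtest_failed` is a hypothetical helper, not part of the diff):

```python
import torch
import torch.distributed as dist


def subtest_failed(local_failed: bool, device: torch.device) -> bool:
    # MAX-reduce a 0/1 flag: the result is 1 on every rank if any rank failed,
    # keeping the ranks' pass/fail counters (and barriers) in lockstep.
    flag = torch.tensor([int(local_failed)], device=device, dtype=torch.int32)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())
```

Writing the report to a file and raising only on the logging rank avoids `world_size` duplicated tracebacks, while the launching test still fails via `pytest.fail` on the collected report (see the `test_moe_layer` hunk below).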
monkeypatch.setenv("VLLM_LOGGING_LEVEL", "ERROR") + if backend == "mori" and os.environ.get("VLLM_TEST_ENABLE_MORI_MOE_LAYER") == "1": + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + monkeypatch.setenv("VLLM_ROCM_USE_AITER_MOE", "1") + monkeypatch.setenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "0") + from vllm._aiter_ops import rocm_aiter_ops + + rocm_aiter_ops.refresh_env_variables() + # TODO # VLLM_FLASHINFER_MOE_BACKEND=latency # VLLM_USE_FLASHINFER_MOE_FP16=1 @@ -1739,6 +1799,11 @@ def test_moe_layer( if len(test_configs) == 0: pytest.skip("No supported configs found for this testpoint.") + with tempfile.NamedTemporaryFile( + prefix="moe-layer-failures-", delete=False + ) as failure_report_file: + failure_report_path = failure_report_file.name + try: parallel_launch_with_config( world_size, @@ -1747,7 +1812,13 @@ def test_moe_layer( None, test_configs, verbosity, + failure_report_path=failure_report_path, ) + if os.path.getsize(failure_report_path) > 0: + with open(failure_report_path, encoding="utf-8") as report_file: + pytest.fail(report_file.read()) finally: + with suppress(FileNotFoundError): + os.remove(failure_report_path) torch.accelerator.synchronize() # TODO: Is this needed? torch.accelerator.empty_cache() diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index d4b2350f5c23..cd31eaf8bff9 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -31,6 +31,10 @@ ) from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( + ref_nvfp4_quant, +) +from vllm.platforms import current_platform from vllm.utils.deep_gemm import per_block_cast_to_fp8 from vllm.utils.math_utils import round_up @@ -289,13 +293,34 @@ def moe_quantize_weights_2d( assert not per_token_quant w_amax = torch.abs(w).max().to(torch.float32) w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax - w, w_s = ops.scaled_fp4_quant(w, w_gs) + if current_platform.is_rocm(): + w, w_s = _scaled_fp4_quant_emulated(w, w_gs) + else: + w, w_s = ops.scaled_fp4_quant(w, w_gs) else: raise RuntimeError(f"Unsupported quant type {quant_dtype}") return w, w_s, w_gs +def _pack_e2m1_fp4(fp4_values: torch.Tensor) -> torch.Tensor: + assert fp4_values.shape[-1] % 2 == 0 + + abs_values = fp4_values.abs() + codes = torch.empty_like(abs_values, dtype=torch.uint8) + for code, value in enumerate((0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)): + codes[abs_values == value] = code + codes = codes | ((fp4_values < 0).to(torch.uint8) << 3) + return codes[..., 0::2] | (codes[..., 1::2] << 4) + + +def _scaled_fp4_quant_emulated( + w: torch.Tensor, w_gs: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + fp4_values, w_s = ref_nvfp4_quant(w, w_gs, block_size=16) + return _pack_e2m1_fp4(fp4_values), w_s.to(torch.float8_e4m3fn) + + def moe_quantize_weights( w: torch.Tensor, w_s: torch.Tensor | None, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 19fdb1ec884d..c0547b16c82d 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -1143,4 +1143,7 @@ def process_fp8_input_tensor_strategy_moe( "for each layer." ) - return w13_input_scale.max(), w2_input_scale.max() + # Triton MoE kernels load tensor-wise activation scales through pointers. 
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 19fdb1ec884d..c0547b16c82d 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -1143,4 +1143,7 @@ def process_fp8_input_tensor_strategy_moe(
             "for each layer."
         )

-    return w13_input_scale.max(), w2_input_scale.max()
+    # Triton MoE kernels load tensor-wise activation scales through pointers.
+    # Keep the reduced scales as rank-1 tensors so they are not treated as
+    # compile-time scalar constants.
+    return w13_input_scale.max().reshape(1), w2_input_scale.max().reshape(1)
diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
index af5c6f2a7ab5..ab39e8cc59f7 100644
--- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
@@ -33,7 +33,7 @@ def break_fp4_bytes(a, dtype):
     signs = (combined & 0x08).to(torch.bool)  # Sign bits
     abs_vals = (combined & 0x07).to(torch.long)

-    kE2M1 = kE2M1ToFloat_handle.val
+    kE2M1 = kE2M1ToFloat_handle.val.to(device=a.device)
     # Device-aware lookup and sign application
     values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
     # Reshape to final form
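The `fp8_utils` change hinges on tensor rank: `Tensor.max()` yields a rank-0 tensor which, per the new comment, risks being treated as a compile-time scalar constant by the Triton MoE path, whereas a one-element rank-1 tensor is an addressable buffer the kernel reads through a pointer at run time. A quick illustration of the distinction:

```python
import torch

w13_input_scale = torch.rand(8)

scale = w13_input_scale.max()
assert scale.ndim == 0 and scale.shape == ()    # rank-0 "scalar" tensor

scale = w13_input_scale.max().reshape(1)
assert scale.ndim == 1 and scale.shape == (1,)  # rank-1: same value, but now
# an addressable one-element buffer a kernel can load at run time.
```

The companion `nvfp4_emulation_utils` fix is the standard device-mismatch correction: indexing requires the E2M1 lookup table to live on the same device as the indices derived from `a`, hence `.to(device=a.device)`.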