Merged

Commits (32 total; changes shown from 18 commits)
4c296ae
add AITER in rocm docker base file
vllmellm Mar 17, 2025
8761424
add AITER fused moe kernels
vllmellm Mar 17, 2025
18e0717
add preprocessing steps required when using AITER moe kernels
vllmellm Mar 17, 2025
19b0cd2
add required ENV variables to enable AITER ops
vllmellm Mar 17, 2025
38d5995
add test for fused moe dispatcher logic
vllmellm Mar 17, 2025
6028eab
bugfix: update aiter moe enable check
vllmellm Mar 17, 2025
fab94ea
add end to end model test when AITER ops are enabled for rocm
vllmellm Mar 17, 2025
8e419df
fix pre-commit errors
vllmellm Mar 17, 2025
d78a2ae
enable AITER for rocm platform in more tests
vllmellm Mar 17, 2025
06c92e6
enable AITER for rocm platform in related tests cases for fp8 quant
vllmellm Mar 17, 2025
8976e55
bugfix: AITER block scaled moe wrong dependency on a wrong envs variable
vllmellm Mar 18, 2025
8109aa0
Merge branch 'vllm-project:main' into aiter-fmoe-integration
vllmellm Mar 18, 2025
4d8d15b
separate out the moe kernels from aiter into different file
vllmellm Mar 18, 2025
4b942b7
Merge branch 'aiter-fmoe-integration' of https://github.com/EmbeddedL…
vllmellm Mar 18, 2025
c069a66
move AITER moe enability check from top of file into function level s…
vllmellm Mar 18, 2025
4047344
fix AITER Fused MoE dispatcher tests
vllmellm Mar 18, 2025
547464d
fix get envs variables in unit tests
vllmellm Mar 18, 2025
b9158ad
Merge remote-tracking branch 'origin/main' into aiter-fmoe-integration
tjtanaa Mar 18, 2025
fab7511
remove cascading logic from vllm.envs
vllmellm Mar 19, 2025
f7fffa0
move out the processing weights required for AITER MoE
vllmellm Mar 19, 2025
aa38d95
refactor aiter unit test flags into decorator
tjtanaa Mar 19, 2025
7d8707b
modify the rocm AITER check tests based on new decorator and include …
vllmellm Mar 19, 2025
fd36f6c
update run-amd-test.sh; fix skip rocm aiter test flag
tjtanaa Mar 19, 2025
0b55c4c
Merge remote-tracking branch 'origin/main' into aiter-fmoe-integration
vllmellm Mar 19, 2025
b8dd58a
bugfix topk softmax functions to return the tensors
vllmellm Mar 20, 2025
d2f86c0
remove unused tests for AITER MoE and keep only mixtral moe unit test
vllmellm Mar 20, 2025
3f230d7
Merge remote-tracking branch 'origin/main' into aiter-fmoe-integration
vllmellm Mar 20, 2025
91d0bda
Merge remote-tracking branch 'origin/main' into aiter-fmoe-integration
vllmellm Mar 24, 2025
05734e4
fix test cases in test_fp8.py to test AITER ops enability for load an…
vllmellm Mar 24, 2025
f242bf2
remove the extra line gaps and revert the test_phimoe.py to its origi…
vllmellm Mar 24, 2025
598dec9
Merge remote-tracking branch 'origin/main' into aiter-fmoe-integration
vllmellm Mar 26, 2025
61edbd4
match the VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE variable in envs t…
vllmellm Mar 26, 2025
13 changes: 13 additions & 0 deletions Dockerfile.rocm_base
@@ -12,6 +12,8 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="b7d29fb"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG AITER_BRANCH="e1ec015"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

FROM ${BASE_IMAGE} AS base

@@ -129,6 +131,15 @@ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
pip install /install/*.whl

ARG AITER_REPO
ARG AITER_BRANCH
RUN git clone --recursive ${AITER_REPO}
RUN cd aiter \
&& git checkout ${AITER_BRANCH} \
&& git submodule update --init --recursive \
&& pip install -r requirements.txt \
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter

ARG BASE_IMAGE
ARG HIPBLASLT_BRANCH
ARG LEGACY_HIPBLASLT_OPTION
@@ -156,3 +167,5 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt
&& echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
27 changes: 22 additions & 5 deletions tests/kernels/test_moe.py
@@ -3,6 +3,8 @@

Run `pytest tests/kernels/test_moe.py`.
"""
import os

import pytest
import torch
from transformers import MixtralConfig
@@ -202,11 +204,18 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,

@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype):
def test_mixtral_moe(dtype: torch.dtype, use_rocm_aiter: bool, monkeypatch):
"""Make sure our Mixtral MoE implementation agrees with the one from
huggingface."""

if use_rocm_aiter:
if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
Contributor commented:
I actually like @DarkLight1337's feedback on #14959 to use pytest custom markers, instead of an environment variable, to selectively enable/disable these tests.

I assume we are disabling these because AITER isn't built in CI? If so, we should change that :). I'm under the impression that CI just uses the ROCm dockerfile, which you've updated to include AITER, but I could be mistaken.

tjtanaa (Collaborator) commented on Mar 19, 2025:
We have tried to introduce a pytest marker for use_rocm_aiter in a minimal way:

# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of HF and vLLM when using greedy sampling.

Run `pytest tests/models/test_models.py`.
"""

import pytest

import os
from tests.utils import maybe_test_rocm_aiter
from vllm.platforms import current_platform

from ...utils import check_logprobs_close

# These have unsupported head_dim for FA. We do
# not have a clean way to fall back, so we fail with
# a clear msg when it happens.
# https://github.com/vllm-project/vllm/issues/14524
REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"]

# @maybe_test_rocm_aiter
@pytest.mark.parametrize(
    "model",
    [
        pytest.param(
            "bigscience/bloom-560m",  # bloom - testing alibi slopes
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param(
            "openai-community/gpt2",  # gpt2
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param("Milos/slovak-gpt-j-405M"),  # gptj
        pytest.param("bigcode/tiny_starcoder_py"),  # gpt_bigcode
        pytest.param("EleutherAI/pythia-70m"),  # gpt_neox
        pytest.param(
            "google/gemma-1.1-2b-it",  # gemma
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param(
            "THUDM/chatglm3-6b",  # chatglm (text-only)
        ),
        pytest.param(
            "meta-llama/Llama-3.2-1B-Instruct",  # llama
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param(
            "openbmb/MiniCPM3-4B",
            # fused_moe not supported on CPU
            marks=[pytest.mark.core_model],
        ),
        pytest.param(
            "facebook/opt-125m",  # opt
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param(
            "microsoft/phi-2",  # phi
            marks=[pytest.mark.core_model],
        ),
        pytest.param(
            "Qwen/Qwen-7B",  # qwen (text-only)
        ),
        pytest.param(
            "Qwen/Qwen2.5-0.5B-Instruct",  # qwen2
            marks=[pytest.mark.core_model],
        ),
        pytest.param("stabilityai/stablelm-3b-4e1t"),  # stablelm
        pytest.param("bigcode/starcoder2-3b"),  # starcoder2
        pytest.param(
            "ehristoforu/Falcon3-MoE-2x7B-Insruct",  # mixtral
            marks=[pytest.mark.cpu_model],
        )
    ])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize(
    "use_rocm_aiter", [
        pytest.param(
            True,
            marks=[pytest.mark.use_rocm_aiter],
        ),
        False
    ])
def test_models(hf_runner, vllm_runner, example_prompts, model: str,
                dtype: str, max_tokens: int, num_logprobs: int,
                use_rocm_aiter: bool, monkeypatch) -> None:

    if model in REQUIRES_V0 or current_platform.is_rocm():
        monkeypatch.setenv("VLLM_USE_V1", "0")

    if use_rocm_aiter:
        if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
            pytest.skip("Skipping test suite for ROCM AITER")
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    print(f"use_rocm_aiter: {use_rocm_aiter}")
    print("VLLM_ROCM_USE_AITER: ", os.getenv("VLLM_ROCM_USE_AITER", None))

    with hf_runner(model, dtype=dtype) as hf_model:
        if model.startswith("THUDM/chatglm3"):
            hf_model.model.get_output_embeddings = lambda: \
                hf_model.model.transformer.output_layer

        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )

Without changing the buildkite command (e.g. pytest -v -s models/decoder_only/language -m 'core_model or quant_model'), it will always run with AITER.

@DarkLight1337 @SageMoore
Do you have a recommendation as to how we should use a pytest marker without affecting the commands in buildkite?
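One possible way to keep the marker idea without touching the buildkite commands would be a shared conftest.py gate; the sketch below is only an illustration, not what this PR implements, and the VLLM_CI_RUN_AITER_TESTS variable is hypothetical.

# Hypothetical conftest.py sketch (not part of this PR): register a
# use_rocm_aiter marker and skip those cases unless explicitly opted in,
# so the existing buildkite -m selections keep running without AITER.
import os

import pytest


def pytest_configure(config):
    config.addinivalue_line(
        "markers", "use_rocm_aiter: test case that enables ROCm AITER kernels")


def pytest_collection_modifyitems(config, items):
    # Opt in via an env var (hypothetical) rather than a -m filter, so
    # buildkite's `-m 'core_model or quant_model'` selection is unaffected.
    if os.getenv("VLLM_CI_RUN_AITER_TESTS") == "1":
        return
    skip_aiter = pytest.mark.skip(reason="AITER test cases disabled in this run")
    for item in items:
        if "use_rocm_aiter" in item.keywords:
            item.add_marker(skip_aiter)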

Collaborator commented:
@SageMoore @DarkLight1337 Since we have been ensuring the unit tests pass on a particular AITER commit, we will enable the AITER kernel tests by default. In this case, we don't need to disable AITER. This also removes the need for a pytest marker or any form of decorator.

The AITER commit is specified in Dockerfile.rocm_base.

So, is it ok to keep it as follows?

...
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_models(hf_runner, vllm_runner, example_prompts, model: str,
                dtype: str, max_tokens: int, num_logprobs: int,
                use_rocm_aiter: bool, monkeypatch) -> None:

    if model in REQUIRES_V0 or current_platform.is_rocm():
        monkeypatch.setenv("VLLM_USE_V1", "0")

    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    with hf_runner(model, dtype=dtype) as hf_model:
        if model.startswith("THUDM/chatglm3"):
            hf_model.model.get_output_embeddings = lambda: \
                hf_model.model.transformer.output_layer

        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )

pytest.skip("Skipping test suite for ROCM AITER")
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# Instantiate our and huggingface's MoE blocks
config = MixtralConfig()
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
@@ -243,10 +252,18 @@ def test_mixtral_moe(dtype: torch.dtype):
torch.bfloat16: 1e-2,
}

torch.testing.assert_close(hf_states.flatten(0, 1),
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])
if use_rocm_aiter:
# The values of rtol and atol are set based on the tests in ROCM AITER package. # noqa: E501
# https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 # noqa: E501
torch.testing.assert_close(hf_states.flatten(0, 1),
vllm_states,
rtol=0.01,
atol=100)
else:
torch.testing.assert_close(hf_states.flatten(0, 1),
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])


@pytest.mark.parametrize("m", [1, 33, 64, 222])
37 changes: 37 additions & 0 deletions tests/model_executor/test_enabled_custom_ops.py
@@ -7,7 +7,12 @@
from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation,
SiluAndMul)
from vllm.model_executor.layers.fused_moe.fused_moe import (
dispatch_fused_experts_func, dispatch_topk_func,
torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
vllm_topk_softmax)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform


# Registered subclass for test
@@ -87,3 +92,35 @@ def test_enabled_ops_invalid(env: str):
custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
RMSNorm(1024).enabled()


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
topk_func = dispatch_topk_func()

if current_platform.is_rocm() and int(use_rocm_aiter):
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_topk_softmax)

assert topk_func == rocm_aiter_topk_softmax
else:
assert topk_func == vllm_topk_softmax


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("inplace", [True, False])
def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
monkeypatch):

monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
fused_experts_func = dispatch_fused_experts_func(inplace)
if current_platform.is_rocm() and int(use_rocm_aiter):
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts)

assert fused_experts_func == rocm_aiter_fused_experts
elif inplace:
assert fused_experts_func == torch_vllm_inplace_fused_experts
else:
assert fused_experts_func == torch_vllm_outplace_fused_experts
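The behaviour these two dispatch tests pin down roughly follows the sketch below; this is a paraphrase rather than the actual fused_moe.py source, and it assumes the branch exposes the new flag as envs.VLLM_ROCM_USE_AITER.

# Rough sketch of the dispatch logic under test (not the actual vLLM
# implementation): AITER kernels are selected only on ROCm with the flag set.
from typing import Callable

import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.fused_moe import (
    torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts)
from vllm.platforms import current_platform


def sketch_dispatch_fused_experts(inplace: bool) -> Callable:
    """Mirror of what dispatch_fused_experts_func(inplace) is expected to return."""
    if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            rocm_aiter_fused_experts)
        return rocm_aiter_fused_experts
    return (torch_vllm_inplace_fused_experts
            if inplace else torch_vllm_outplace_fused_experts)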
81 changes: 49 additions & 32 deletions tests/models/decoder_only/language/test_mistral.py
@@ -5,13 +5,15 @@
"""
import copy
import json
import os

import jsonschema
import jsonschema.exceptions
import pytest

from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
MistralToolParser)
from vllm.platforms import current_platform
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

from ...utils import check_logprobs_close
@@ -174,15 +176,16 @@
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_models(hf_runner, vllm_runner, example_prompts, model: str,
dtype: str, max_tokens: int, num_logprobs: int,
use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
pytest.skip("Skipping test suite for ROCM AITER")
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# TODO(sang): Sliding window should be tested separately.
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
@@ -206,14 +209,16 @@ def test_models(
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_mistral_format(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str,
max_tokens: int, num_logprobs: int,
use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
pytest.skip("Skipping test suite for ROCM AITER")
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

with vllm_runner(
model,
dtype=dtype,
@@ -244,11 +249,15 @@ def test_mistral_format(

@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_mistral_symbolic_languages(
vllm_runner,
model: str,
dtype: str,
) -> None:
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str,
use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
pytest.skip("Skipping test suite for ROCM AITER")
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

with vllm_runner(model,
dtype=dtype,
max_model_len=8192,
@@ -266,11 +275,15 @@ def test_mistral_symbolic_languages(
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("model",
MISTRAL_FORMAT_MODELS) # v1 can't do func calling
def test_mistral_function_calling(
vllm_runner,
model: str,
dtype: str,
) -> None:
@pytest.mark.parametrize(
Contributor commented:
This test is disabled in CI and crashes for an unrelated reason when I try to run it locally. Let's hold off on adding an AITER case here until we reenable the test.

"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_mistral_function_calling(vllm_runner, model: str, dtype: str,
use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
pytest.skip("Skipping test suite for ROCM AITER")
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

with vllm_runner(model,
dtype=dtype,
tokenizer_mode="mistral",
@@ -301,11 +314,15 @@ def test_mistral_function_calling(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("guided_backend",
["outlines", "lm-format-enforcer", "xgrammar"])
def test_mistral_guided_decoding(
vllm_runner,
model: str,
guided_backend: str,
) -> None:
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_mistral_guided_decoding(vllm_runner, model: str, guided_backend: str,
use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
pytest.skip("Skipping test suite for ROCM AITER")
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

with vllm_runner(model, dtype='bfloat16',
tokenizer_mode="mistral") as vllm_model:

12 changes: 12 additions & 0 deletions tests/models/decoder_only/language/test_models.py
@@ -3,8 +3,12 @@

Run `pytest tests/models/test_models.py`.
"""
import os

import pytest

from vllm.platforms import current_platform

from ...utils import check_logprobs_close

# These have unsupported head_dim for FA. We do not
@@ -69,6 +73,8 @@
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_models(
hf_runner,
vllm_runner,
@@ -77,8 +83,14 @@ def test_models(
dtype: str,
max_tokens: int,
num_logprobs: int,
use_rocm_aiter: bool,
monkeypatch,
) -> None:
if use_rocm_aiter:
if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
pytest.skip("Skipping test suite for ROCM AITER")
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

if model in REQUIRES_V0:
monkeypatch.setenv("VLLM_USE_V1", "0")

21 changes: 12 additions & 9 deletions tests/models/decoder_only/language/test_phimoe.py
@@ -3,6 +3,8 @@

Run `pytest tests/models/test_phimoe.py`.
"""
import os

import pytest
import torch

@@ -79,15 +81,16 @@ def test_phimoe_routing_function():
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_models(hf_runner, vllm_runner, example_prompts, model: str,
dtype: str, max_tokens: int, num_logprobs: int,
use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
pytest.skip("Skipping test suite for ROCM AITER")
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
22 changes: 19 additions & 3 deletions tests/quantization/test_fp8.py
@@ -23,11 +23,16 @@
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", MODELS)
@pytest.mark.parametrize("force_marlin", [False, True])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
monkeypatch) -> None:
use_rocm_aiter: bool, monkeypatch) -> None:
if force_marlin:
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

with vllm_runner(model_id) as llm:
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
@@ -47,7 +52,13 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch):
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
use_rocm_aiter: bool, monkeypatch):
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")
with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
@@ -86,8 +97,13 @@ def check_model(model):
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("force_marlin", [False, True])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
monkeypatch) -> None:
use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")
