Merged
32 commits
4c296ae
add AITER in rocm docker base file
vllmellm Mar 17, 2025
8761424
add AITER fused moe kernels
vllmellm Mar 17, 2025
18e0717
add preprocessing steps required when using AITER moe kernels
vllmellm Mar 17, 2025
19b0cd2
add required ENV variables to enable AITER ops
vllmellm Mar 17, 2025
38d5995
add test for fused moe dispatcher logic
vllmellm Mar 17, 2025
6028eab
bugfix: update aiter moe enable check
vllmellm Mar 17, 2025
fab94ea
add end to end model test when AITER ops are enabled for rocm
vllmellm Mar 17, 2025
8e419df
fix pre-commit errors
vllmellm Mar 17, 2025
d78a2ae
enable AITER for rocm platform in more tests
vllmellm Mar 17, 2025
06c92e6
enable AITER for rocm platform in related tests cases for fp8 quant
vllmellm Mar 17, 2025
8976e55
bugfix AITER block scaled moe dependency on the wrong envs variable
vllmellm Mar 18, 2025
8109aa0
Merge branch 'vllm-project:main' into aiter-fmoe-integration
vllmellm Mar 18, 2025
4d8d15b
separate out the moe kernels from aiter into different file
vllmellm Mar 18, 2025
4b942b7
Merge branch 'aiter-fmoe-integration' of https://github.com/EmbeddedL…
vllmellm Mar 18, 2025
c069a66
move AITER moe enablement check from top of file into function level s…
vllmellm Mar 18, 2025
4047344
fix AITER Fused MoE dispatcher tests
vllmellm Mar 18, 2025
547464d
fix get envs variables in unit tests
vllmellm Mar 18, 2025
b9158ad
Merge remote-tracking branch 'origin/main' into aiter-fmoe-integration
tjtanaa Mar 18, 2025
fab7511
remove cascading logic from vllm.envs
vllmellm Mar 19, 2025
f7fffa0
move out the processing weights required for AITER MoE
vllmellm Mar 19, 2025
aa38d95
refactor aiter unit test flags into decorator
tjtanaa Mar 19, 2025
7d8707b
modify the rocm AITER check tests based on new decorator and include …
vllmellm Mar 19, 2025
fd36f6c
update run-amd-test.sh; fix skip rocm aiter test flag
tjtanaa Mar 19, 2025
0b55c4c
Merge remote-tracking branch 'origin/main' into aiter-fmoe-integration
vllmellm Mar 19, 2025
b8dd58a
bugfix topk softmax functions to return the tensors
vllmellm Mar 20, 2025
d2f86c0
remove unused tests for AITER MoE and keep only mixtral moe unit test
vllmellm Mar 20, 2025
3f230d7
Merge remote-tracking branch 'origin/main' into aiter-fmoe-integration
vllmellm Mar 20, 2025
91d0bda
Merge remote-tracking branch 'origin/main' into aiter-fmoe-integration
vllmellm Mar 24, 2025
05734e4
fix test cases in test_fp8.py to test AITER ops enablement for load an…
vllmellm Mar 24, 2025
f242bf2
remove the extra line gaps and revert the test_phimoe.py to its origi…
vllmellm Mar 24, 2025
598dec9
Merge remote-tracking branch 'origin/main' into aiter-fmoe-integration
vllmellm Mar 26, 2025
61edbd4
match the VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE variable in envs t…
vllmellm Mar 26, 2025
23 changes: 18 additions & 5 deletions tests/kernels/test_moe.py
@@ -202,11 +202,16 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,

@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype):
def test_mixtral_moe(dtype: torch.dtype, use_rocm_aiter: bool, monkeypatch):
"""Make sure our Mixtral MoE implementation agrees with the one from
huggingface."""

if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# Instantiate our and huggingface's MoE blocks
config = MixtralConfig()
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
@@ -243,10 +248,18 @@ def test_mixtral_moe(dtype: torch.dtype):
torch.bfloat16: 1e-2,
}

torch.testing.assert_close(hf_states.flatten(0, 1),
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])
if use_rocm_aiter:
# The values of rtol and atol are set based on the tests in ROCM AITER package. # noqa: E501
# https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 # noqa: E501
torch.testing.assert_close(hf_states.flatten(0, 1),
vllm_states,
rtol=0.01,
atol=100)
else:
torch.testing.assert_close(hf_states.flatten(0, 1),
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])


@pytest.mark.parametrize("m", [1, 33, 64, 222])
36 changes: 36 additions & 0 deletions tests/model_executor/test_enabled_custom_ops.py
@@ -7,6 +7,10 @@
from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation,
SiluAndMul)
from vllm.model_executor.layers.fused_moe.fused_moe import (
dispatch_fused_experts_func, dispatch_topk_func,
torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
vllm_topk_softmax)
from vllm.model_executor.layers.layernorm import (
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
@@ -92,6 +96,38 @@ def test_enabled_ops_invalid(env: str):
RMSNorm(1024).enabled()


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
topk_func = dispatch_topk_func()

if current_platform.is_rocm() and int(use_rocm_aiter):
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_topk_softmax)

assert topk_func == rocm_aiter_topk_softmax
else:
assert topk_func == vllm_topk_softmax


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("inplace", [True, False])
def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
monkeypatch):

monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
fused_experts_func = dispatch_fused_experts_func(inplace)
if current_platform.is_rocm() and int(use_rocm_aiter):
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts)

assert fused_experts_func == rocm_aiter_fused_experts
elif inplace:
assert fused_experts_func == torch_vllm_inplace_fused_experts
else:
assert fused_experts_func == torch_vllm_outplace_fused_experts


@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
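These dispatcher tests pin down the selection logic added in fused_moe.py. A minimal sketch, assuming a build where AITER is not enabled (non-ROCm, or VLLM_ROCM_USE_AITER left unset), of the default selections the tests expect:

    from vllm.model_executor.layers.fused_moe.fused_moe import (
        dispatch_fused_experts_func, dispatch_topk_func,
        torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
        vllm_topk_softmax)

    # With AITER disabled, the dispatchers hand back the existing vLLM paths.
    assert dispatch_topk_func() == vllm_topk_softmax
    assert dispatch_fused_experts_func(inplace=True) == torch_vllm_inplace_fused_experts
    assert dispatch_fused_experts_func(inplace=False) == torch_vllm_outplace_fused_experts
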
41 changes: 9 additions & 32 deletions tests/models/decoder_only/language/test_mistral.py
@@ -174,15 +174,8 @@
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
def test_models(hf_runner, vllm_runner, example_prompts, model: str,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
# TODO(sang): Sliding window should be tested separately.
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
@@ -206,14 +199,8 @@ def test_models(
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_mistral_format(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str,
max_tokens: int, num_logprobs: int) -> None:
with vllm_runner(
model,
dtype=dtype,
@@ -244,11 +231,8 @@ def test_mistral_format(

@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_mistral_symbolic_languages(
vllm_runner,
model: str,
dtype: str,
) -> None:
def test_mistral_symbolic_languages(vllm_runner, model: str,
dtype: str) -> None:
with vllm_runner(model,
dtype=dtype,
max_model_len=8192,
@@ -266,11 +250,7 @@ def test_mistral_symbolic_languages(
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("model",
MISTRAL_FORMAT_MODELS) # v1 can't do func calling
def test_mistral_function_calling(
vllm_runner,
model: str,
dtype: str,
) -> None:
def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
with vllm_runner(model,
dtype=dtype,
tokenizer_mode="mistral",
@@ -301,11 +281,8 @@ def test_mistral_function_calling(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("guided_backend",
["outlines", "lm-format-enforcer", "xgrammar"])
def test_mistral_guided_decoding(
vllm_runner,
model: str,
guided_backend: str,
) -> None:
def test_mistral_guided_decoding(vllm_runner, model: str,
guided_backend: str) -> None:
with vllm_runner(model, dtype='bfloat16',
tokenizer_mode="mistral") as vllm_model:

23 changes: 20 additions & 3 deletions tests/quantization/test_fp8.py
@@ -23,8 +23,14 @@
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", MODELS)
@pytest.mark.parametrize("force_marlin", [False, True])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
monkeypatch) -> None:
use_rocm_aiter: bool, monkeypatch) -> None:

if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

if force_marlin:
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

@@ -47,7 +53,13 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch):
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
use_rocm_aiter: bool, monkeypatch):
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")
with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
@@ -86,8 +98,13 @@ def check_model(model):
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("force_marlin", [False, True])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
monkeypatch) -> None:
use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")

15 changes: 15 additions & 0 deletions vllm/envs.py
@@ -73,6 +73,8 @@
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = True
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
@@ -511,6 +513,19 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1")),

# Whether to use aiter moe ops.
# By default this is enabled.
"VLLM_ROCM_USE_AITER_MOE":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
("true", "1")),

# Whether to use aiter block scaled moe kernel.
# By default this is disabled.
"VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE":
lambda:
(os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in
("true", "1")),

# use aiter rms norm op if aiter ops are enabled.
"VLLM_ROCM_USE_AITER_RMSNORM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
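Each entry above is a lambda over os.getenv, so the flags are read lazily at attribute access. A minimal sketch, not part of this diff, of how the master switch and the per-op flags combine (values shown assume the defaults introduced here):

    import os

    # Master switch for AITER ops: off unless set explicitly.
    os.environ["VLLM_ROCM_USE_AITER"] = "1"
    # The MoE kernels default to on once the master switch is enabled, so
    # VLLM_ROCM_USE_AITER_MOE is optional; the block scaled FP8 MoE path stays
    # off unless VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE is set to "1".

    import vllm.envs as envs

    print(envs.VLLM_ROCM_USE_AITER)                        # True
    print(envs.VLLM_ROCM_USE_AITER_MOE)                    # True (default)
    print(envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE)   # False (default)
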
92 changes: 69 additions & 23 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -17,6 +17,10 @@
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op

from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
rocm_aiter_fused_experts,
rocm_aiter_topk_softmax)

logger = init_logger(__name__)


@@ -1035,6 +1039,28 @@ def try_get_optimal_moe_config(
return config


def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool) -> tuple[torch.Tensor, ...]:
ops.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

return topk_weights, topk_indices


def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
if is_rocm_aiter_moe_enabled():
return rocm_aiter_topk_softmax
return vllm_topk_softmax


def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
@@ -1059,17 +1085,14 @@ def fused_topk(
dtype=torch.int32,
device=hidden_states.device)

ops.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
gating_output_float = gating_output.float() # TODO(woosuk): Optimize this.

if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_func = dispatch_topk_func()
topk_weights, topk_ids = topk_func(topk_weights, topk_ids,
token_expert_indicies,
gating_output_float, renormalize)

del token_expert_indicies # Not used. Will be used in the future.
return topk_weights, topk_ids


@@ -1259,6 +1282,24 @@ def outplace_fused_experts_fake(
)


def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor:
torch.ops.vllm.inplace_fused_experts(**kwargs)
hidden_states = kwargs['hidden_states']
return hidden_states


def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
return torch.ops.vllm.outplace_fused_experts(**kwargs)


def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
if is_rocm_aiter_moe_enabled():
return rocm_aiter_fused_experts
if inplace:
return torch_vllm_inplace_fused_experts
return torch_vllm_outplace_fused_experts


def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@@ -1278,20 +1319,25 @@ def fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:

if inplace:
torch.ops.vllm.inplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, activation,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape)
return hidden_states
else:
return torch.ops.vllm.outplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, activation,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape)
return dispatch_fused_experts_func(inplace)(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)


def fused_experts_impl(hidden_states: torch.Tensor,
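Because fused_topk and fused_experts now route through the dispatchers, the AITER kernels are picked up without any changes at the call sites. A minimal sketch, assuming rocm_aiter_fused_moe defers its aiter imports to call time, that mocks the enablement check purely to show the routing:

    from unittest import mock

    import vllm.model_executor.layers.fused_moe.fused_moe as fused_moe
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        rocm_aiter_fused_experts, rocm_aiter_topk_softmax)

    # Force is_rocm_aiter_moe_enabled() to report True, as it would on ROCm
    # with VLLM_ROCM_USE_AITER=1 and VLLM_ROCM_USE_AITER_MOE left at its default.
    with mock.patch.object(fused_moe, "is_rocm_aiter_moe_enabled",
                           return_value=True):
        assert fused_moe.dispatch_topk_func() == rocm_aiter_topk_softmax
        assert fused_moe.dispatch_fused_experts_func(
            inplace=True) == rocm_aiter_fused_experts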