Closed

Changes from 9 commits

Commits (22)
fdd9cbd
cherry pick 09133e9833811778240b3c2cc4de2390fd08e470; and only add AI…
vllmellm Feb 26, 2025
668ec2f
cherry pick acc27ffa94e677b8f6fce0f5b593430ce6acbfe4; and only add AI…
vllmellm Mar 5, 2025
8d49d6e
bug fixes and pass unit tests
tjtanaa Mar 17, 2025
43af6c0
add AITER setup steps in Dockerfile.rocm_base
tjtanaa Mar 17, 2025
0c30ce9
remove AITER setup steps in Dockerfile.rocm
tjtanaa Mar 17, 2025
ab73f97
Merge remote-tracking branch 'origin/main' into aiter-linear
tjtanaa Mar 17, 2025
e952b2d
fix missing property from Platform
tjtanaa Mar 17, 2025
6a632ac
skip AITER in AMD CI
tjtanaa Mar 17, 2025
61c92a9
Merge remote-tracking branch 'origin/main' into aiter-linear
tjtanaa Mar 20, 2025
0224eff
merge with main
tjtanaa Apr 16, 2025
d2ed934
revert run-amd-test.sh; update Dockerfile.rocm_base aiter version, re…
tjtanaa Apr 16, 2025
3fec588
clean up spaces and newline; fix typo
tjtanaa Apr 16, 2025
3558099
clean up spaces and newline;
tjtanaa Apr 16, 2025
2bf7206
fix typo
tjtanaa Apr 16, 2025
1f979fa
untested refactoring
tjtanaa Apr 17, 2025
f13746c
fix bug; validated to work V1 AITER unquantized and quantized
tjtanaa Apr 19, 2025
20139af
relocate the linear helper function into aiter_ops and fix unittest
tjtanaa Apr 19, 2025
700ac73
add test_aiter_ops.py to unit test if the ops are registered correctl…
tjtanaa Apr 19, 2025
7dd2812
fix the test to test fake tensor implementation
tjtanaa Apr 20, 2025
d9f0e7b
use current_platform.fp8_dtype(); update aiter commit
tjtanaa Apr 21, 2025
dde9157
merge with main; fix dispatcher and unit tests
tjtanaa Apr 22, 2025
e34712c
remove is_rocm_aiter_xxxx_enabled flag from _aiter_ops.py
tjtanaa Apr 22, 2025
4 changes: 4 additions & 0 deletions .buildkite/run-amd-test.sh
@@ -72,6 +72,10 @@ HF_CACHE="$(realpath ~)/huggingface"
mkdir -p "${HF_CACHE}"
HF_MOUNT="/root/.cache/huggingface"

# environment variables
SKIP_ROCM_ATIER_MODEL_TEST_CASES="True"
echo $SKIP_ROCM_ATIER_MODEL_TEST_CASES

commands=$@
echo "Commands:$commands"
#ignore certain kernels tests
14 changes: 14 additions & 0 deletions Dockerfile.rocm_base
@@ -12,6 +12,8 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="b7d29fb"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG AITER_BRANCH="e1ec015"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

FROM ${BASE_IMAGE} AS base

@@ -129,8 +131,18 @@ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
pip install /install/*.whl

ARG AITER_REPO
ARG AITER_BRANCH
RUN git clone --recursive ${AITER_REPO}
RUN cd aiter \
&& git checkout ${AITER_BRANCH} \
&& git submodule update --init --recursive \
&& pip install -r requirements.txt \
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter

ARG BASE_IMAGE
ARG HIPBLASLT_BRANCH
ARG HIPBLAS_COMMON_BRANCH
ARG LEGACY_HIPBLASLT_OPTION
ARG RCCL_BRANCH
ARG RCCL_REPO
@@ -156,3 +168,5 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt
&& echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
20 changes: 20 additions & 0 deletions tests/model_executor/test_enabled_custom_ops.py
@@ -1,13 +1,17 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch.nn.functional as F

from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation,
SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
dipsatch_unquantized_linear_func, rocm_aiter_tgemm_mm)
from vllm.platforms import current_platform


# Registered subclass for test
@@ -87,3 +91,19 @@ def test_enabled_ops_invalid(env: str):
custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
RMSNorm(1024).enabled()


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_linear", ["0", "1"])
def test_unquantized_linear_dispatch(use_rocm_aiter: str,
use_rocm_aiter_linear: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", use_rocm_aiter_linear)
linear_func = dipsatch_unquantized_linear_func()
print(f"use_rocm_aiter: {use_rocm_aiter}, " +
f"use_rocm_aiter_linear: {use_rocm_aiter_linear}")
if current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_linear):
assert linear_func == rocm_aiter_tgemm_mm
else:
assert linear_func == F.linear
24 changes: 14 additions & 10 deletions tests/models/decoder_only/language/test_models.py
@@ -3,8 +3,12 @@

Run `pytest tests/models/test_models.py`.
"""
import os

import pytest

from vllm.platforms import current_platform

from ...utils import check_logprobs_close

# These have unsupported head_dim for FA. We do not
@@ -69,16 +73,16 @@
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
monkeypatch,
) -> None:
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_models(hf_runner, vllm_runner, example_prompts, model: str,
dtype: str, max_tokens: int, num_logprobs: int,
use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
pytest.skip("Skipping test suite for ROCM AITER")
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

if model in REQUIRES_V0:
monkeypatch.setenv("VLLM_USE_V1", "0")

16 changes: 14 additions & 2 deletions tests/quantization/test_fp8.py
@@ -23,8 +23,13 @@
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", MODELS)
@pytest.mark.parametrize("force_marlin", [False, True])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
monkeypatch) -> None:
use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

if force_marlin:
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

@@ -47,7 +52,14 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch):
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
use_rocm_aiter: bool, monkeypatch):

if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")
with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
12 changes: 12 additions & 0 deletions vllm/envs.py
@@ -75,6 +75,8 @@
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False
@@ -522,6 +524,16 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))),

# use aiter ops when explicitly enabled (disabled by default)
"VLLM_ROCM_USE_AITER":
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1")),

# use aiter linear op if aiter ops are enabled
"VLLM_ROCM_USE_AITER_LINEAR":
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1") and os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True"
).lower() in ("true", "1")),

# Pad the fp8 weights to 256 bytes for ROCm
"VLLM_ROCM_FP8_PADDING":
lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
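
For context (not part of the diff): VLLM_ROCM_USE_AITER acts as a parent switch that is off by default, while VLLM_ROCM_USE_AITER_LINEAR defaults to on but only takes effect once the parent is enabled. A minimal standalone sketch of that behaviour, using hypothetical helper names rather than vLLM's envs module:

import os

def aiter_enabled() -> bool:
    # Parent switch: AITER ops stay off unless explicitly enabled.
    return os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1")

def aiter_linear_enabled() -> bool:
    # Child switch: defaults to on, but only applies when the parent is on too.
    return (aiter_enabled() and os.getenv("VLLM_ROCM_USE_AITER_LINEAR",
                                          "True").lower() in ("true", "1"))

os.environ["VLLM_ROCM_USE_AITER"] = "1"
assert aiter_linear_enabled()       # linear path follows the parent switch

os.environ["VLLM_ROCM_USE_AITER_LINEAR"] = "0"
assert not aiter_linear_enabled()   # but it can still be opted out individually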
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1568,4 +1568,4 @@ def fused_moe(
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
block_shape=block_shape)
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
@@ -889,4 +889,4 @@ def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
mutates_args=[],
fake_impl=moe_forward_fake,
dispatch_key=current_platform.dispatch_key,
)
)
18 changes: 15 additions & 3 deletions vllm/model_executor/layers/linear.py
@@ -2,7 +2,7 @@

import itertools
from abc import abstractmethod
from typing import Any, Literal, Optional, Union
from typing import Any, Callable, Literal, Optional, Union

import torch
import torch.nn as nn
@@ -26,6 +26,7 @@
RowvLLMParameter)
# yapf: enable
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

logger = init_logger(__name__)

@@ -50,6 +51,18 @@
]


def rocm_aiter_tgemm_mm(x: torch.Tensor, weight: torch.Tensor,
bias: torch.Tensor) -> torch.Tensor:
from aiter.tuned_gemm import tgemm
return tgemm.mm(x, weight, bias)


def dipsatch_unquantized_linear_func() -> Callable[..., torch.Tensor]:
Collaborator
I think it'd be good to specify the exact signature here

Collaborator Author
@ProExpertProg We have fixed the typo and specified the exact signature:


def dispatch_unquantized_linear_func(
) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
    from vllm._aiter_ops import is_rocm_aiter_linear_enabled
    if is_rocm_aiter_linear_enabled():
        return aiter_ops.rocm_aiter_tuned_gemm
    return F.linear

if current_platform.is_rocm_aiter_linear_enabled():
return rocm_aiter_tgemm_mm
return F.linear


def adjust_marlin_shard(param, shard_size, shard_offset):
marlin_tile_size = getattr(param, "marlin_tile_size", None)
if marlin_tile_size is None:
@@ -187,8 +200,7 @@ def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

return F.linear(x, layer.weight, bias)
return dipsatch_unquantized_linear_func()(x, layer.weight, bias)


class LinearBase(torch.nn.Module):
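
For context (not part of the diff): whatever the dispatcher returns is called with the same (input, weight, bias) arguments that F.linear takes, which is why UnquantizedLinearMethod.apply above needs no other changes. A rough sketch of that contract, with the ROCm/AITER branch stubbed out because aiter.tuned_gemm is only importable on a ROCm build with AITER installed:

import torch
import torch.nn.functional as F

def dispatch_unquantized_linear_func():
    # Hypothetical stand-in for the dispatcher in the diff: in this sketch the
    # AITER tuned-GEMM path is never taken, so it always falls back to F.linear.
    rocm_aiter_linear_available = False  # would come from the platform check
    if rocm_aiter_linear_available:
        from aiter.tuned_gemm import tgemm
        return lambda x, w, b: tgemm.mm(x, w, b)
    return F.linear

x = torch.randn(4, 16)
weight = torch.randn(32, 16)   # (out_features, in_features), as F.linear expects
bias = torch.randn(32)
out = dispatch_unquantized_linear_func()(x, weight, bias)
assert out.shape == (4, 32)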
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/fp8.py
@@ -711,4 +711,4 @@ class Fp8KVCacheMethod(BaseKVCacheMethod):
"""

def __init__(self, quant_config: Fp8Config):
super().__init__(quant_config)
super().__init__(quant_config)