Merged
2 changes: 1 addition & 1 deletion docker/Dockerfile.rocm_base
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="8970b25b"
ARG AITER_BRANCH="5a77249"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

FROM ${BASE_IMAGE} AS base
8 changes: 0 additions & 8 deletions vllm/envs.py
@@ -77,7 +77,6 @@
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
@@ -546,13 +545,6 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
("true", "1")),

# Whether to use aiter block scaled moe kernel.
# By default this is disabled.
"VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE":
lambda:
(os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in
("true", "1")),

# use aiter rms norm op if aiter ops are enabled.
"VLLM_ROCM_USE_AITER_RMSNORM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
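The remaining AITER environment variables keep the lazy getter pattern shown above: each flag is read from the environment only when accessed, and the fused-MoE path now needs only the master AITER switch plus the MoE switch. A minimal standalone sketch of that parsing convention (everything here other than the env-var names and defaults is illustrative, not vLLM's own module):

import os

# Same convention as the entries above: unset falls back to the default,
# and "1"/"true" (case-insensitive) enables the flag.
ENV_READERS = {
    "VLLM_ROCM_USE_AITER":
    lambda: os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1"),
    "VLLM_ROCM_USE_AITER_MOE":
    lambda: os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in ("true", "1"),
}

# With the dedicated block-scaled flag removed, the AITER fused-MoE path is
# gated on the master AITER switch and the MoE switch only.
aiter_moe_enabled = (ENV_READERS["VLLM_ROCM_USE_AITER"]()
                     and ENV_READERS["VLLM_ROCM_USE_AITER_MOE"]())
print(aiter_moe_enabled)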
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -23,9 +23,7 @@
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op

from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
rocm_aiter_fused_experts,
rocm_aiter_topk_softmax)
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled

logger = init_logger(__name__)

@@ -846,6 +844,7 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,

def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
if is_rocm_aiter_moe_enabled():
from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
return rocm_aiter_topk_softmax
return vllm_topk_softmax

@@ -1102,6 +1101,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:

def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
if is_rocm_aiter_moe_enabled():
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
return rocm_aiter_fused_experts
if inplace:
return torch_vllm_inplace_fused_experts
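Moving the rocm_aiter_* imports from module scope into dispatch_topk_func and dispatch_fused_experts_func defers importing the AITER-backed module until its branch is actually taken. A minimal sketch of that deferred-dispatch pattern, with a simplified signature and a hypothetical backend module standing in for .rocm_aiter_fused_moe:

from typing import Callable

import torch


def _aiter_enabled() -> bool:
    # Stand-in for is_rocm_aiter_moe_enabled(); the real check consults
    # current_platform plus the VLLM_ROCM_USE_AITER* flags.
    return False


def _torch_topk_softmax(gating_output: torch.Tensor,
                        topk: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Simplified reference path: softmax over experts, then top-k selection.
    scores = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    return topk_weights, topk_ids


def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
    if _aiter_enabled():
        # Deferred import, as in the diff: the backend is only imported when
        # this branch runs ("fast_backend" is a hypothetical placeholder).
        from fast_backend import fast_topk_softmax
        return fast_topk_softmax
    return _torch_topk_softmax


weights, ids = dispatch_topk_func()(torch.randn(4, 8), topk=2)
print(weights.shape, ids.shape)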
140 changes: 107 additions & 33 deletions vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -10,28 +10,68 @@
def is_rocm_aiter_moe_enabled() -> bool:
return current_platform.is_rocm() \
and envs.VLLM_ROCM_USE_AITER_MOE \
and envs.VLLM_ROCM_USE_AITER \


def is_rocm_aiter_block_scaled_moe_enabled() -> bool:
return is_rocm_aiter_moe_enabled() and \
envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
and envs.VLLM_ROCM_USE_AITER


def rocm_aiter_asm_moe_tkw1(hidden_states,
w1,
w2,
topk_weight,
topk_ids,
fc1_scale=None,
fc2_scale=None,
fc1_smooth_scale=None,
fc2_smooth_scale=None,
a16=False,
per_tensor_quant_scale=None,
expert_mask=None,
activation_str: str = "silu") -> None:

from aiter import ActivationType
from aiter.fused_moe_bf16_asm import asm_moe_tkw1

activation = \
ActivationType.Gelu if activation_str == "gelu" else ActivationType.Silu

return asm_moe_tkw1(hidden_states,
w1,
w2,
topk_weight,
topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
fc1_smooth_scale=fc1_smooth_scale,
fc2_smooth_scale=fc2_smooth_scale,
a16=a16,
per_tensor_quant_scale=per_tensor_quant_scale,
expert_mask=expert_mask,
activation=activation)


def rocm_aiter_fused_experts(
*,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
apply_router_weight_on_input: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
expert_mask: Optional[torch.Tensor] = None,
**kwargs  # Ignore additional keyword arguments
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False,
) -> torch.Tensor:

import aiter as rocm_aiter
@@ -40,25 +80,21 @@ def rocm_aiter_fused_experts(
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)

if apply_router_weight_on_input:
assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"
# All AITER Fused MoE kernels are expecting the following datatypes
topk_weights = topk_weights.to(torch.float32)
topk_ids = topk_ids.to(torch.int32)

hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
topk_ids = topk_ids.to(torch.int32)
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
if (block_shape is not None) and use_fp8_w8a8:
assert not apply_router_weight_on_input, (
"apply_router_weight_on_input is not supported for block scaled moe"
)

if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
assert w1_scale is not None
assert w2_scale is not None

local_E = E = w1.shape[0]
if expert_mask is not None:
E = expert_mask.numel()
if expert_map is not None:
E = expert_map.numel()

topk = topk_ids.shape[1]
model_dim = w1.shape[-1]
@@ -80,7 +116,7 @@
E,
model_dim,
dtype,
expert_mask=expert_mask)
expert_mask=expert_map)

a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
rocm_aiter.fmoe_fp8_blockscale_g1u1(
@@ -102,7 +138,33 @@
)
return out_asm

elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
# This applies topk_weights on the GEMM output of the first FC layer
# rather than the second FC.
assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
assert topk_weights.shape[-1] == 1, (
"Only support topk=1 when"
" `apply_router_weight_on_input` is True")

return rocm_aiter_asm_moe_tkw1(hidden_states,
[Contributor review comment] Let's assert apply_router_weight_on_input=True, or do the if-branch check, when calling the _tkw1 kernel. Also, we should add comments illustrating the difference between the _tkw1 kernel and the other AITER kernels: the difference is whether topk_weights is applied to the output of the first GEMM or of the second GEMM.

w1,
w2,
topk_weights,
topk_ids,
fc1_scale=w1_scale,
fc2_scale=w2_scale,
fc1_smooth_scale=None,
fc2_smooth_scale=None,
a16=False,
per_tensor_quant_scale=None,
expert_mask=expert_map,
activation_str=activation)

elif use_fp8_w8a8:
assert not apply_router_weight_on_input, (
"apply_router_weight_on_input is not supported for fp8_w8a8")
return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
Expand All @@ -114,6 +176,18 @@ def rocm_aiter_fused_experts(
fc2_smooth_scale=None,
a16=False)

if apply_router_weight_on_input:
assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"

hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
topk_ids = topk_ids.to(torch.int32)
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)

return rocm_aiter.ck_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
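The comments and the review note above describe the tkw1 kernel as applying topk_weights to the output of the first FC layer rather than the second. A toy sketch (plain matmuls with topk = 1, not the AITER kernels themselves) of why that placement is part of the kernel's contract:

import torch

torch.manual_seed(0)
x = torch.randn(4, 16)            # token activations
w1 = torch.randn(32, 16)          # first FC (up projection)
w2 = torch.randn(16, 32)          # second FC (down projection)
topk_weight = torch.rand(4, 1)    # one routing weight per token (topk == 1)

# With nothing between the two projections, weighting after the first FC or
# after the second FC gives the same result (a per-token scalar commutes
# through a linear layer).
h = x @ w1.t()
out_weight_on_fc1 = (h * topk_weight) @ w2.t()   # tkw1-style placement
out_weight_on_fc2 = (h @ w2.t()) * topk_weight   # conventional placement
assert torch.allclose(out_weight_on_fc1, out_weight_on_fc2, atol=1e-4)

# Once an activation sits between the two FCs, the placement matters, so the
# tkw1 variant has to be a distinct kernel rather than a post-scaling step.
act = torch.nn.functional.silu
out_fc1_then_act = act(h * topk_weight) @ w2.t()
out_act_then_fc2 = (act(h) @ w2.t()) * topk_weight
assert not torch.allclose(out_fc1_then_act, out_act_then_fc2, atol=1e-4)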
@@ -250,6 +250,28 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)

from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)

# Property to determine if AITER is used
if is_rocm_aiter_moe_enabled():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
rocm_aiter_fused_experts, shuffle_weights)

# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data)

layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)

self.fused_experts_func = rocm_aiter_fused_experts
else:
from vllm.model_executor.layers.fused_moe import fused_experts
self.fused_experts_func = fused_experts

def apply(
self,
layer: torch.nn.Module,
@@ -268,7 +290,6 @@ def apply(
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts

topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
@@ -282,7 +303,7 @@
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)

return fused_experts(
return self.fused_experts_func(
x,
layer.w13_weight,
layer.w2_weight,
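The diff above picks the fused-experts implementation once in process_weights_after_loading and stores it on the method object, so apply() no longer branches or imports per forward pass. A minimal sketch of that caching pattern, with placeholder functions standing in for fused_experts and rocm_aiter_fused_experts:

import torch


def _reference_fused_experts(x: torch.Tensor, w13: torch.Tensor,
                             w2: torch.Tensor) -> torch.Tensor:
    # Placeholder for vllm's fused_experts; just echoes x so the sketch runs.
    return x


def _aiter_fused_experts(x: torch.Tensor, w13: torch.Tensor,
                         w2: torch.Tensor) -> torch.Tensor:
    # Placeholder for rocm_aiter_fused_experts (never reached in this sketch).
    raise NotImplementedError


class MoEMethodSketch:
    """Choose the MoE kernel once, after the weights are loaded."""

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if self._aiter_enabled():
            # The real code also shuffles w13/w2 into the layout the AITER
            # kernel expects before swapping in rocm_aiter_fused_experts.
            self.fused_experts_func = _aiter_fused_experts
        else:
            self.fused_experts_func = _reference_fused_experts

    def apply(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        # No per-call branching: the decision was cached at load time.
        return self.fused_experts_func(x, layer.w13_weight, layer.w2_weight)

    @staticmethod
    def _aiter_enabled() -> bool:
        return False  # stand-in for is_rocm_aiter_moe_enabled()


layer = torch.nn.Module()
layer.w13_weight = torch.nn.Parameter(torch.randn(8, 4), requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.randn(4, 8), requires_grad=False)

method = MoEMethodSketch()
method.process_weights_after_loading(layer)
print(method.apply(layer, torch.randn(2, 4)).shape)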
5 changes: 2 additions & 3 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -575,8 +575,7 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
def process_weights_after_loading(self, layer: Module) -> None:
# Lazy import to avoid importing triton too early.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
is_rocm_aiter_moe_enabled, shuffle_weights)
expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights)

# TODO (rob): refactor block quant into separate class.
if self.block_quant:
@@ -603,7 +602,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
requires_grad=False)
if is_rocm_aiter_block_scaled_moe_enabled():
if is_rocm_aiter_moe_enabled():
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data)