24 changes: 2 additions & 22 deletions tests/model_executor/test_enabled_custom_ops.py
@@ -7,10 +7,8 @@
from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation,
SiluAndMul)
from vllm.model_executor.layers.fused_moe.fused_moe import (
dispatch_fused_experts_func, dispatch_topk_func,
torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
vllm_topk_softmax)
from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
vllm_topk_softmax)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.layernorm import (
@@ -111,24 +109,6 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
assert topk_func == vllm_topk_softmax


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("inplace", [True, False])
def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
monkeypatch):

monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
is_rocm_aiter_moe_enabled.cache_clear()
fused_experts_func = dispatch_fused_experts_func(inplace)
if current_platform.is_rocm() and int(use_rocm_aiter):
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts)
assert fused_experts_func == rocm_aiter_fused_experts
elif inplace:
assert fused_experts_func == torch_vllm_inplace_fused_experts
else:
assert fused_experts_func == torch_vllm_outplace_fused_experts


@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
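With `test_fused_experts_dispatch` removed, the experts implementation is no longer selected in `fused_moe.py`, so only the top-k dispatch stays env-driven in this test file. A minimal sketch of that toggle pattern, reusing the imports the surviving `test_topk_dispatch` already has; the ROCm-branch assertion is an assumption based on `rocm_aiter_topk_softmax` defined later in this PR:

```python
import pytest

from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
                                                             vllm_topk_softmax)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    is_rocm_aiter_moe_enabled)
from vllm.platforms import current_platform


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_topk_dispatch_sketch(use_rocm_aiter: str, monkeypatch):
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
    # the env var is read through a cached helper, so clear the cache first
    is_rocm_aiter_moe_enabled.cache_clear()
    topk_func = dispatch_topk_func()
    if current_platform.is_rocm() and int(use_rocm_aiter):
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            rocm_aiter_topk_softmax)
        assert topk_func == rocm_aiter_topk_softmax
    else:
        assert topk_func == vllm_topk_softmax
```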
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1100,9 +1100,6 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:


def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
if is_rocm_aiter_moe_enabled():
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
return rocm_aiter_fused_experts
if inplace:
return torch_vllm_inplace_fused_experts
return torch_vllm_outplace_fused_experts
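After this change, `dispatch_fused_experts_func` no longer knows about aiter at all; the ROCm/aiter decision moves into `layer.py` (see the hunks below). Reconstructed from the diff, the dispatcher is reduced to:

```python
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
    # only the torch/triton implementations remain; the aiter branch is gone
    if inplace:
        return torch_vllm_inplace_fused_experts
    return torch_vllm_outplace_fused_experts
```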
26 changes: 22 additions & 4 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -27,6 +27,10 @@
from .fused_moe import fused_experts
else:
fused_experts = None # type: ignore
if current_platform.is_rocm():
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
else:
rocm_aiter_fused_experts = None # type: ignore
if current_platform.is_tpu():
# the iterative moe implementation is used until the moe_pallas is fixed
from .moe_torch_iterative import fused_moe as fused_moe_pallas
@@ -119,10 +123,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Lazy import to avoid importing triton.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled, shuffle_weights)
if is_rocm_aiter_moe_enabled():
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data)

self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

if self.rocm_aiter_moe_enabled:
# shuffle weights into the (32, 32) layout used by the 2-stage ck moe kernel
shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
layer.w2_weight.data,
layout=(32, 32))

layer.w13_weight.data = shuffled_w13
layer.w2_weight.data = shuffled_w2
@@ -203,6 +211,16 @@ def forward_cuda(
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)

if self.rocm_aiter_moe_enabled:
return rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)

return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
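Condensed from the two hunks above, the new control flow in `layer.py` is: resolve the aiter flag once while processing weights, shuffle the weights for the 2-stage ck moe kernel, then branch on the cached flag in `forward_cuda`. The excerpt below is a condensation of the diff, not standalone code (`self`, `layer`, and the forward arguments come from the surrounding class):

```python
# in process_weights_after_loading
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
if self.rocm_aiter_moe_enabled:
    # shuffle into the (32, 32) layout used by the 2-stage ck moe kernel
    shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
                                                layer.w2_weight.data,
                                                layout=(32, 32))
    layer.w13_weight.data = shuffled_w13
    layer.w2_weight.data = shuffled_w2

# in forward_cuda, after topk_weights/topk_ids are computed
if self.rocm_aiter_moe_enabled:
    return rocm_aiter_fused_experts(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input)
# ...otherwise fall through to the unchanged fused_experts(...) call shown above
```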
173 changes: 106 additions & 67 deletions vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -20,7 +20,7 @@ def rocm_aiter_asm_moe_tkw1_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
@@ -40,7 +40,7 @@ def rocm_aiter_asm_moe_tkw1_impl(
return asm_moe_tkw1(hidden_states,
w1,
w2,
topk_weight,
topk_weights,
topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
@@ -56,7 +56,7 @@ def rocm_aiter_asm_moe_tkw1_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
@@ -69,23 +69,6 @@ def rocm_aiter_asm_moe_tkw1_fake(
return torch.empty_like(hidden_states)


def rocm_aiter_ck_moe_impl(hidden_states: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor) -> torch.Tensor:
from aiter import ck_moe
return ck_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids)


def rocm_aiter_ck_moe_fake(hidden_states: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor) -> torch.Tensor:
return torch.empty_like(hidden_states)


def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
@@ -151,7 +134,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
@@ -174,7 +157,7 @@ def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weight=topk_weight,
topk_weight=topk_weights,
topk_ids=topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
@@ -187,7 +170,7 @@ def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
@@ -198,6 +181,49 @@ def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
return torch.empty_like(hidden_states)


def rocm_aiter_ck_moe_2stages_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_size: Optional[list[int]] = None,
expert_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from aiter.fused_moe_bf16_asm import ck_moe_2stages
return ck_moe_2stages(a1=hidden_states,
w1=w1,
w2=w2,
topk_weight=topk_weights,
topk_ids=topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_size=block_size,
expert_mask=expert_mask)


def rocm_aiter_ck_moe_2stages_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_size: Optional[list[int]] = None,
expert_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)


def rocm_aiter_topk_softmax_impl(topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
@@ -226,14 +252,6 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
dispatch_key=current_platform.dispatch_key,
)

direct_register_custom_op(
op_name="rocm_aiter_ck_moe",
op_func=rocm_aiter_ck_moe_impl,
mutates_args=[],
fake_impl=rocm_aiter_ck_moe_fake,
dispatch_key=current_platform.dispatch_key,
)

direct_register_custom_op(
op_name="rocm_aiter_fmoe_fp8_blockscale_g1u1",
op_func=rocm_aiter_fmoe_fp8_blockscale_g1u1_impl,
@@ -250,6 +268,14 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
dispatch_key=current_platform.dispatch_key,
)

direct_register_custom_op(
op_name="rocm_aiter_ck_moe_2stages",
op_func=rocm_aiter_ck_moe_2stages_impl,
mutates_args=[],
fake_impl=rocm_aiter_ck_moe_2stages_fake,
dispatch_key=current_platform.dispatch_key,
)

direct_register_custom_op(
op_name="rocm_aiter_topk_softmax",
op_func=rocm_aiter_topk_softmax_impl,
@@ -259,29 +285,21 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
)


def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
def rocm_aiter_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None) -> torch.Tensor:

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
@@ -304,8 +322,8 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
a1, a1_scale = per_token_group_quant_fp8(hidden_states, block_shape[1])

return torch.ops.vllm.rocm_aiter_fmoe_fp8_blockscale_g1u1(
topk_ids, topk_weights, hidden_states.dtype, expert_map, a1, w1,
w2, w1_scale, w2_scale, a1_scale, block_shape, None)
topk_ids, topk_weights, hidden_states.dtype, None, a1, w1, w2,
w1_scale, w2_scale, a1_scale, block_shape, None)

# w8a8 per-channel quantization
elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
@@ -330,17 +348,36 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
fc2_smooth_scale=None,
a16=False,
per_tensor_quant_scale=None,
expert_mask=expert_map,
expert_mask=None,
activation_str=activation)

# w8a8 per-tensor activation per-tensor weight
elif use_fp8_w8a8:
assert not apply_router_weight_on_input, (
"apply_router_weight_on_input is not supported for fp8_w8a8")

# - faster path: static per-tensor activation and static per-tensor weight
#   fp8 quantization (w8a8)
if a1_scale is not None and a2_scale is not None:
return torch.ops.vllm.rocm_aiter_ck_moe_2stages(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
fc1_scale=w1_scale,
fc2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)

# - fallback path: static per-tensor activation and static per-tensor weight
#   fp8 quantization (w8a8)
# - dynamic per-tensor activation with static per-tensor weight
#   fp8 quantization (w8a8)
return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weight=topk_weights,
topk_weights=topk_weights,
topk_ids=topk_ids,
fc1_scale=w1_scale,
fc2_scale=w2_scale,
@@ -360,12 +397,12 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
topk_ids = topk_ids.to(torch.int32)
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)

# w16a16 fallback to the 2-stage ck moe kernel
return torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids)
return torch.ops.vllm.rocm_aiter_ck_moe_2stages(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids)
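For reference, the branch order in `rocm_aiter_fused_experts` after this change can be summarized with a small, self-contained helper. This is a toy stand-in, not vLLM API, and the first condition is paraphrased from the block-scale hunk above:

```python
def select_aiter_moe_op(use_fp8_w8a8: bool,
                        per_channel_quant: bool,
                        apply_router_weight_on_input: bool,
                        has_block_shape: bool,
                        has_static_act_scales: bool) -> str:
    """Return the name of the op rocm_aiter_fused_experts would dispatch to."""
    if use_fp8_w8a8 and has_block_shape:
        return "rocm_aiter_fmoe_fp8_blockscale_g1u1"  # block-scaled fp8
    if use_fp8_w8a8 and per_channel_quant and apply_router_weight_on_input:
        return "rocm_aiter_asm_moe_tkw1"              # per-channel fp8, tkw1
    if use_fp8_w8a8 and has_static_act_scales:
        return "rocm_aiter_ck_moe_2stages"            # static per-tensor fp8 fast path
    if use_fp8_w8a8:
        return "rocm_aiter_asm_moe"                   # per-tensor fp8 fallback
    return "rocm_aiter_ck_moe_2stages"                # w16a16 default
```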


def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
@@ -379,7 +416,8 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
return topk_weights, topk_indices


def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
def shuffle_weights(*tensors: torch.Tensor,
layout: tuple[int, int]) -> tuple[torch.Tensor, ...]:
"""
Applies shuffle_weight function from AITER to each
input tensor and returns them.
@@ -391,7 +429,8 @@ def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
A Tuple of shuffled tensors.
"""
from aiter.ops.shuffle import shuffle_weight
return tuple(shuffle_weight(tensor) for tensor in tensors)

return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
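`shuffle_weights` now requires an explicit `layout`; the layer code above passes `(32, 32)` for the 2-stage ck moe kernel. A hedged usage sketch, with illustrative shapes, assuming a ROCm build with aiter installed:

```python
import torch

from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    shuffle_weights)

w13 = torch.randn(8, 256, 128, dtype=torch.bfloat16)  # (num_experts, 2N, K)
w2 = torch.randn(8, 128, 128, dtype=torch.bfloat16)   # (num_experts, K, N)
shuffled_w13, shuffled_w2 = shuffle_weights(w13, w2, layout=(32, 32))
```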


def expand_weights(*tensors: torch.Tensor,
@@ -413,4 +452,4 @@

return tuple(
tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1))
for tensor, dim in zip(tensors, expansion_dims))
for tensor, dim in zip(tensors, expansion_dims))
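`expand_weights` (unchanged by this PR, shown as context) broadcasts per-expert 1-D scale tensors to `(num_experts, dim, 1)`. A small usage sketch, assuming `expansion_dims` is the keyword-only argument zipped over above:

```python
import torch

from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    expand_weights)

scale_w13 = torch.rand(8)  # one scale per expert
scale_w2 = torch.rand(8)
e_w13, e_w2 = expand_weights(scale_w13, scale_w2, expansion_dims=[256, 128])
assert e_w13.shape == (8, 256, 1)
assert e_w2.shape == (8, 128, 1)
```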