24 changes: 2 additions & 22 deletions tests/model_executor/test_enabled_custom_ops.py
@@ -7,10 +7,8 @@
from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation,
SiluAndMul)
from vllm.model_executor.layers.fused_moe.fused_moe import (
dispatch_fused_experts_func, dispatch_topk_func,
torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
vllm_topk_softmax)
from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
vllm_topk_softmax)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.layernorm import (
@@ -111,24 +109,6 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
assert topk_func == vllm_topk_softmax


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("inplace", [True, False])
def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
monkeypatch):

monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
is_rocm_aiter_moe_enabled.cache_clear()
fused_experts_func = dispatch_fused_experts_func(inplace)
if current_platform.is_rocm() and int(use_rocm_aiter):
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts)
assert fused_experts_func == rocm_aiter_fused_experts
elif inplace:
assert fused_experts_func == torch_vllm_inplace_fused_experts
else:
assert fused_experts_func == torch_vllm_outplace_fused_experts


@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -78,6 +78,7 @@
VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_2STAGE_MOE: bool = True
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
@@ -557,6 +558,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
("true", "1")),

# use aiter ck fused moe op if aiter ops are enabled
"VLLM_ROCM_USE_AITER_2STAGE_MOE":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_2STAGE_MOE", "True").lower() in
("true", "1")),

# use aiter rms norm op if aiter ops are enabled.
"VLLM_ROCM_USE_AITER_RMSNORM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
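Like the existing VLLM_ROCM_USE_AITER_* switches, the new flag is parsed case-insensitively and only "true" or "1" enables it; it defaults to enabled. A small illustrative Python snippet (the variable names come from this diff; setting them programmatically is just one way to toggle the flag and must happen before vllm.envs is first read):

```python
import os

# Keep AITER and its fused-MoE path enabled, but opt out of the two-stage
# CK MoE kernel added in this PR.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_MOE"] = "1"
os.environ["VLLM_ROCM_USE_AITER_2STAGE_MOE"] = "0"
```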
6 changes: 2 additions & 4 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1100,9 +1100,6 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:


def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
if is_rocm_aiter_moe_enabled():
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
return rocm_aiter_fused_experts
if inplace:
return torch_vllm_inplace_fused_experts
return torch_vllm_outplace_fused_experts
@@ -1130,7 +1127,8 @@ def fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
allow_deep_gemm: bool = False,
use_ck_moe_2stages: bool = False) -> torch.Tensor:
Contributor:
Nit: It doesn't look like you need this?

Collaborator:
Agreed, this is not needed.

Collaborator (Author):
I have removed the allow_deep_gemm argument. use_ck_moe_2stages is kept because RFC issue #17067 highlights that accessing envs.ENV is very costly.
Thus, all the envs are read only once and stored as properties of the class during the initialization stage.
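For reference, a minimal sketch of that caching pattern (the class and method names here are illustrative only; the real change is in the layer.py diff further down, which stores is_rocm_aiter_moe_enabled() and envs.VLLM_ROCM_USE_AITER_2STAGE_MOE as attributes during process_weights_after_loading):

```python
import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    is_rocm_aiter_moe_enabled)


class _MoEMethodSketch:
    """Sketch only: read env-backed flags once, not on every forward pass."""

    def __init__(self) -> None:
        # envs.* lookups are costly (RFC #17067), so cache them as plain
        # attributes during initialization.
        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
        self.rocm_aiter_2stage_moe_enabled = envs.VLLM_ROCM_USE_AITER_2STAGE_MOE

    def forward_cuda(self, *args, **kwargs):
        # Hot path branches on the cached attributes only; envs.* is not
        # touched here.
        if self.rocm_aiter_moe_enabled and self.rocm_aiter_2stage_moe_enabled:
            ...  # dispatch to the two-stage CK MoE kernel
        ...
```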

Collaborator (Author):
To pass the pre-commit checks for vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py, we have adjusted how the fused-experts function is assigned: instead of binding fused_experts to fused_experts_func, the AITER path is bound to rocm_aiter_fused_experts_func. This follows the approach in vllm/attention/backends/rocm_flash_attn.py, where the attention functions are assigned to distinct property names: self.fa_attn_func, self.sdpa_attn_func and self.triton_attn_func.

This also allows us to clean up the unused arguments of rocm_aiter_fused_experts (vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py). Note that expert_map has been removed as a bugfix: expert_map (integer IDs of the experts) in vLLM and expert_mask (a boolean mask of the experts active on the current GPU) in the AITER ops are different things. The current rocm_aiter_fused_experts therefore drops the expert_map argument; it should be added back when enabling EP with AITER in a future PR.
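A rough sketch of that binding, with hypothetical class and method names; it only illustrates the idea of giving the AITER kernel its own attribute, mirroring self.fa_attn_func / self.sdpa_attn_func / self.triton_attn_func in rocm_flash_attn.py:

```python
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    is_rocm_aiter_moe_enabled)


class _QuantizedMoEMethodSketch:
    """Sketch only: bind the chosen fused-experts kernel at init time."""

    def __init__(self) -> None:
        self.fused_experts_func = None
        self.rocm_aiter_fused_experts_func = None
        if is_rocm_aiter_moe_enabled():
            # Dedicated attribute for the AITER path, analogous to the
            # attention-function attributes in rocm_flash_attn.py.
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
                rocm_aiter_fused_experts)
            self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
        else:
            from vllm.model_executor.layers.fused_moe.fused_moe import (
                fused_experts)
            self.fused_experts_func = fused_experts

    def apply(self, *args, **kwargs):
        if self.rocm_aiter_fused_experts_func is not None:
            return self.rocm_aiter_fused_experts_func(*args, **kwargs)
        return self.fused_experts_func(*args, **kwargs)
```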

if (allow_deep_gemm and use_fp8_w8a8
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
assert apply_router_weight_on_input is False
28 changes: 25 additions & 3 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -27,6 +27,10 @@
from .fused_moe import fused_experts
else:
fused_experts = None # type: ignore
if current_platform.is_rocm():
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
else:
rocm_aiter_fused_experts = None # type: ignore
if current_platform.is_tpu():
# the iterative moe implementation is used until the moe_pallas is fixed
from .moe_torch_iterative import fused_moe as fused_moe_pallas
@@ -119,10 +123,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Lazy import to avoid importing triton.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled, shuffle_weights)
if is_rocm_aiter_moe_enabled():

self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_2stage_moe_enabled = envs.VLLM_ROCM_USE_AITER_2STAGE_MOE

if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data)
layout = (32, 32) if self.rocm_aiter_2stage_moe_enabled else (16,
16)
shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
layer.w2_weight.data,
layout=layout)

layer.w13_weight.data = shuffled_w13
layer.w2_weight.data = shuffled_w2
@@ -203,6 +214,17 @@ def forward_cuda(
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)

if self.rocm_aiter_moe_enabled:
return rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_ck_moe_2stages=self.rocm_aiter_2stage_moe_enabled)

return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
149 changes: 112 additions & 37 deletions vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -20,7 +20,7 @@ def rocm_aiter_asm_moe_tkw1_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
@@ -40,7 +40,7 @@ def rocm_aiter_asm_moe_tkw1_impl(
return asm_moe_tkw1(hidden_states,
w1,
w2,
topk_weight,
topk_weights,
topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
@@ -56,7 +56,7 @@ def rocm_aiter_asm_moe_tkw1_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
@@ -151,7 +151,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
@@ -174,7 +174,7 @@ def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weight=topk_weight,
topk_weight=topk_weights,
topk_ids=topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
@@ -187,7 +187,7 @@ def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
@@ -198,6 +198,49 @@ def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
return torch.empty_like(hidden_states)


def rocm_aiter_ck_moe_2stages_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_size: Optional[List[int]] = None,
expert_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from aiter.fused_moe_bf16_asm import ck_moe_2stages
return ck_moe_2stages(a1=hidden_states,
w1=w1,
w2=w2,
topk_weight=topk_weights,
topk_ids=topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_size=block_size,
expert_mask=expert_mask)


def rocm_aiter_ck_moe_2stages_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_size: Optional[List[int]] = None,
expert_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)


def rocm_aiter_topk_softmax_impl(topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
Expand Down Expand Up @@ -250,6 +293,14 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
dispatch_key=current_platform.dispatch_key,
)

direct_register_custom_op(
op_name="rocm_aiter_ck_moe_2stages",
op_func=rocm_aiter_ck_moe_2stages_impl,
mutates_args=[],
fake_impl=rocm_aiter_ck_moe_2stages_fake,
dispatch_key=current_platform.dispatch_key,
)

direct_register_custom_op(
op_name="rocm_aiter_topk_softmax",
op_func=rocm_aiter_topk_softmax_impl,
@@ -259,29 +310,23 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
)


def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
def rocm_aiter_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_ck_moe_2stages: bool = False,
Contributor:
Instead of passing this boolean around, can you check the environment variable here?

Collaborator (Author):
@SageMoore RFC issue #17067 highlights that accessing envs.ENV is very costly.
Thus, all the envs are read only once and stored as properties of the class during the initialization stage.

) -> torch.Tensor:

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
@@ -304,8 +349,8 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
a1, a1_scale = per_token_group_quant_fp8(hidden_states, block_shape[1])

return torch.ops.vllm.rocm_aiter_fmoe_fp8_blockscale_g1u1(
topk_ids, topk_weights, hidden_states.dtype, expert_map, a1, w1,
w2, w1_scale, w2_scale, a1_scale, block_shape, None)
topk_ids, topk_weights, hidden_states.dtype, None, a1, w1, w2,
w1_scale, w2_scale, a1_scale, block_shape, None)

# w8a8 per-channel quantization
elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
@@ -330,17 +375,36 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
fc2_smooth_scale=None,
a16=False,
per_tensor_quant_scale=None,
expert_mask=expert_map,
expert_mask=None,
activation_str=activation)

# w8a8 per-tensor activation per-tensor weight
elif use_fp8_w8a8:
assert not apply_router_weight_on_input, (
"apply_router_weight_on_input is not supported for fp8_w8a8")

# - faster static per-tensor-activation static per-tensor-weight
# fp8 quantization w8a8
if use_ck_moe_2stages and a1_scale is not None and a2_scale is not None:
return torch.ops.vllm.rocm_aiter_ck_moe_2stages(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
fc1_scale=w1_scale,
fc2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)

# - fallback static per-tensor-activation static per-tensor-weight
# fp8 quantization w8a8
# - dynamic per-tensor activation static per-tensor-weight
# fp8 quantization w8a8
return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weight=topk_weights,
topk_weights=topk_weights,
topk_ids=topk_ids,
fc1_scale=w1_scale,
fc2_scale=w2_scale,
@@ -360,6 +424,15 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
topk_ids = topk_ids.to(torch.int32)
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)

# faster w16a16
if use_ck_moe_2stages:
return torch.ops.vllm.rocm_aiter_ck_moe_2stages(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids)

# w16a16 fallback to rocm_aiter_ck_moe w16a16
return torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states,
w1=w1,
@@ -379,7 +452,8 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
return topk_weights, topk_indices


def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
def shuffle_weights(*tensors: torch.Tensor,
layout: tuple[int, int]) -> tuple[torch.Tensor, ...]:
"""
Applies shuffle_weight function from AITER to each
input tensor and returns them.
@@ -391,11 +465,12 @@ def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
A Tuple of shuffled tensors.
"""
from aiter.ops.shuffle import shuffle_weight
return tuple(shuffle_weight(tensor) for tensor in tensors)

return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)


def expand_weights(*tensors: torch.Tensor,
expansion_dims: list[int]) -> Tuple[torch.Tensor, ...]:
expansion_dims: List[int]) -> Tuple[torch.Tensor, ...]:
"""
Expands the dimensions of input tensors.

@@ -413,4 +488,4 @@ def expand_weights(*tensors: torch.Tensor,

return tuple(
tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1))
for tensor, dim in zip(tensors, expansion_dims))
for tensor, dim in zip(tensors, expansion_dims))
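Putting the refactored entry point together, a hedged usage sketch of the new rocm_aiter_fused_experts signature; the tensor shapes are invented for illustration, and in the real layer path the weights are first reshuffled by shuffle_weights inside process_weights_after_loading, which is omitted here:

```python
import torch

from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    rocm_aiter_fused_experts)

# Small, made-up shapes purely for illustration: 4 tokens, hidden size 128,
# 8 experts with intermediate size 256, top-2 routing.
hidden_states = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
w13 = torch.randn(8, 2 * 256, 128, dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(8, 128, 256, dtype=torch.bfloat16, device="cuda")
topk_weights = torch.rand(4, 2, dtype=torch.float32, device="cuda")
topk_ids = torch.randint(0, 8, (4, 2), dtype=torch.int32, device="cuda")

out = rocm_aiter_fused_experts(
    hidden_states=hidden_states,
    w1=w13,
    w2=w2,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
    activation="silu",
    # Selects the two-stage CK MoE kernel on the bf16 (w16a16) path.
    use_ck_moe_2stages=True,
)
```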