Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_2STAGE_MOE: bool = True
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
Expand Down Expand Up @@ -555,6 +556,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
("true", "1")),

# use aiter ck 2-stage fused moe op if aiter ops are enabled.
"VLLM_ROCM_USE_AITER_2STAGE_MOE":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_2STAGE_MOE", "True").lower() in
("true", "1")),

# use aiter rms norm op if aiter ops are enabled.
"VLLM_ROCM_USE_AITER_RMSNORM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,8 @@ def fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
allow_deep_gemm: bool = False,
use_ck_moe_2stages: bool = False) -> torch.Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: It doesn't look like you need this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed. this is not needed

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have removed the allow_deep_gemm argument. use_ck_moe_2stages is kept because there is an RFC (Issue #17067) highlighting that accessing envs.ENV is very costly.
Thus, all the env values are read only once and stored as properties of the class during the initialization stage.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To pass the pre-commit tests of file vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py , we have adjusted the logic of assignment of fused_experts function to fused_experts_func to become rocm_aiter_fused_experts_func, following the approach in vllm/attention/backends/rocm_flash_attn.py, where the attention functions are assigned to different property name: self.fa_attn_func , self.sdpa_attn_func and self.triton_attn_func

This also allows us to clean up the unused arguments of the rocm_aiter_fused_experts function (vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py). Note that expert_map has been removed as a bugfix: expert_map (integer IDs of the experts) in vLLM and expert_mask (boolean mask of active experts on the current GPU) in the AITER ops are different. The current rocm_aiter_fused_experts no longer takes the expert_map argument; we encourage adding it back when enabling EP using AITER in a future PR.

if (allow_deep_gemm and use_fp8_w8a8
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
assert apply_router_weight_on_input is False
Expand Down
34 changes: 30 additions & 4 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
from .fused_moe import fused_experts
else:
fused_experts = None # type: ignore
if current_platform.is_rocm():
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
else:
rocm_aiter_fused_experts = None # type: ignore
if current_platform.is_tpu():
# the iterative moe implementation is used until the moe_pallas is fixed
from .moe_torch_iterative import fused_moe as fused_moe_pallas
Expand Down Expand Up @@ -118,11 +122,19 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
# Lazy import to avoid importing triton.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled, shuffle_weights)
if is_rocm_aiter_moe_enabled():
is_rocm_aiter_2stage_moe_enabled, is_rocm_aiter_moe_enabled,
shuffle_weights)

self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_2stage_moe_enabled = is_rocm_aiter_2stage_moe_enabled()

if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data)
layout = (32, 32) if self.rocm_aiter_2stage_moe_enabled else (16,
16)
shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
layer.w2_weight.data,
layout=layout)

layer.w13_weight.data = shuffled_w13
layer.w2_weight.data = shuffled_w2
Expand Down Expand Up @@ -203,6 +215,20 @@ def forward_cuda(
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)

if self.rocm_aiter_moe_enabled:
return rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_ck_moe_2stages=self.rocm_aiter_2stage_moe_enabled)

return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
Expand Down
164 changes: 128 additions & 36 deletions vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,18 @@ def is_rocm_aiter_moe_enabled() -> bool:
and envs.VLLM_ROCM_USE_AITER


def is_rocm_aiter_2stage_moe_enabled() -> bool:
    """Report whether the AITER 2-stage MoE path should be used.

    The result is the short-circuit conjunction of the ROCm platform check
    and the ``VLLM_ROCM_USE_AITER_2STAGE_MOE``, ``VLLM_ROCM_USE_AITER_MOE``
    and ``VLLM_ROCM_USE_AITER`` environment flags, evaluated in that order.
    """
    return (current_platform.is_rocm()
            and envs.VLLM_ROCM_USE_AITER_2STAGE_MOE
            and envs.VLLM_ROCM_USE_AITER_MOE
            and envs.VLLM_ROCM_USE_AITER)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking 3 environment variables seems like too much. envs.VLLM_ROCM_USE_AITER_2STAGE_MOE is enough, as it is only used when the other two are already true.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_rocm_aiter_2stage_moe_enabled() has been removed, since envs.VLLM_ROCM_USE_AITER_2STAGE_MOE is now read in the layer class during initialization only, not in the forward pass.



def rocm_aiter_asm_moe_tkw1_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
Expand All @@ -40,7 +47,7 @@ def rocm_aiter_asm_moe_tkw1_impl(
return asm_moe_tkw1(hidden_states,
w1,
w2,
topk_weight,
topk_weights,
topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
Expand All @@ -56,7 +63,7 @@ def rocm_aiter_asm_moe_tkw1_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -145,13 +152,13 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
block_shape: List[int],
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:

return torch.empty_like(a1, dtype=torch.bf16)
return torch.empty_like(a1, dtype=hidden_states_dtype)


def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
Expand All @@ -174,7 +181,7 @@ def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weight=topk_weight,
topk_weight=topk_weights,
topk_ids=topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
Expand All @@ -187,7 +194,7 @@ def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
Expand All @@ -198,6 +205,50 @@ def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
return torch.empty_like(hidden_states)


def rocm_aiter_ck_moe_2stages_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: Optional[torch.Tensor] = None,
    fc2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_size: Optional[List[int]] = None,
    expert_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward all arguments to AITER's ``ck_moe_2stages``.

    ``hidden_states`` is handed to the kernel as ``a1`` and ``topk_weights``
    as ``topk_weight``; every other argument is passed through under its
    own name. The kernel's return value is returned unchanged.
    """
    # Imported inside the function, so aiter is only imported when this op
    # is actually called.
    from aiter.fused_moe_bf16_asm import ck_moe_2stages

    output = ck_moe_2stages(
        a1=hidden_states,
        w1=w1,
        w2=w2,
        topk_weight=topk_weights,
        topk_ids=topk_ids,
        fc1_scale=fc1_scale,
        fc2_scale=fc2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_size=block_size,
        expert_mask=expert_mask,
    )
    return output


def rocm_aiter_ck_moe_2stages_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: Optional[torch.Tensor] = None,
    fc2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_size: Optional[List[int]] = None,
    expert_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Fake counterpart of ``rocm_aiter_ck_moe_2stages_impl``.

    Accepts the same arguments as the real implementation but only uses
    ``hidden_states``: it returns a new, uninitialized tensor created with
    ``torch.empty_like(hidden_states)``.
    """
    placeholder = torch.empty_like(hidden_states)
    return placeholder


def rocm_aiter_topk_softmax_impl(topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
Expand Down Expand Up @@ -250,6 +301,14 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
dispatch_key=current_platform.dispatch_key,
)

# Register the CK 2-stage MoE implementation as the custom op
# "rocm_aiter_ck_moe_2stages" (invoked elsewhere in this file as
# torch.ops.vllm.rocm_aiter_ck_moe_2stages), paired with its fake
# implementation. No arguments are declared as mutated.
direct_register_custom_op(
op_name="rocm_aiter_ck_moe_2stages",
op_func=rocm_aiter_ck_moe_2stages_impl,
mutates_args=[],
fake_impl=rocm_aiter_ck_moe_2stages_fake,
dispatch_key=current_platform.dispatch_key,
)

direct_register_custom_op(
op_name="rocm_aiter_topk_softmax",
op_func=rocm_aiter_topk_softmax_impl,
Expand All @@ -259,29 +318,32 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
)


def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
def rocm_aiter_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False,
use_ck_moe_2stages: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of passing this boolean around, can you check the environment variable here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SageMoore There is an RFC (Issue #17067) highlighting that accessing envs.ENV is very costly.
Thus, all the env values are read only once and stored as properties of the class during the initialization stage.

) -> torch.Tensor:

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
Expand Down Expand Up @@ -330,17 +392,36 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
fc2_smooth_scale=None,
a16=False,
per_tensor_quant_scale=None,
expert_mask=expert_map,
expert_mask=None,
activation_str=activation)

# w8a8 per-tensor activation per-tensor weight
elif use_fp8_w8a8:
assert not apply_router_weight_on_input, (
"apply_router_weight_on_input is not supported for fp8_w8a8")

# - faster static per-tensor-activation static per-tensor-weight
# fp8 quantization w8a8
if use_ck_moe_2stages and a1_scale is not None and a2_scale is not None:
return torch.ops.vllm.rocm_aiter_ck_moe_2stages(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
fc1_scale=w1_scale,
fc2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)

# - fallback static per-tensor-activation static per-tensor-weight
# fp8 quantization w8a8
# - dynamic per-tensor activation static per-tensor-weight
# fp8 quantization w8a8
return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weight=topk_weights,
topk_weights=topk_weights,
topk_ids=topk_ids,
fc1_scale=w1_scale,
fc2_scale=w2_scale,
Expand All @@ -360,6 +441,15 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
topk_ids = topk_ids.to(torch.int32)
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)

# faster w16a16
if use_ck_moe_2stages:
return torch.ops.vllm.rocm_aiter_ck_moe_2stages(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids)

# w16a16 fallback to rocm_aiter_ck_moe w16a16
return torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states,
w1=w1,
Expand All @@ -379,7 +469,8 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
return topk_weights, topk_indices


def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
def shuffle_weights(*tensors: torch.Tensor,
layout: tuple[int, int]) -> tuple[torch.Tensor, ...]:
"""
Applies shuffle_weight function from AITER to each
input tensor and returns them.
Expand All @@ -391,11 +482,12 @@ def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
A Tuple of shuffled tensors.
"""
from aiter.ops.shuffle import shuffle_weight
return tuple(shuffle_weight(tensor) for tensor in tensors)

return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)


def expand_weights(*tensors: torch.Tensor,
expansion_dims: list[int]) -> Tuple[torch.Tensor, ...]:
expansion_dims: List[int]) -> Tuple[torch.Tensor, ...]:
"""
Expands the dimensions of input tensors.

Expand All @@ -413,4 +505,4 @@ def expand_weights(*tensors: torch.Tensor,

return tuple(
tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1))
for tensor, dim in zip(tensors, expansion_dims))
for tensor, dim in zip(tensors, expansion_dims))
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
rocm_aiter_fused_experts, shuffle_weights)

# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data)
shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
layer.w2_weight.data,
layout=(16, 16))

layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False)
Expand Down
Loading