
Conversation

@charlifu (Contributor) commented Apr 10, 2025

This PR adds a model restriction to the AITER MoE kernels. On ROCm, only Mixtral and DeepSeek models will use the AITER MoE kernels.

  • Add a cache decorator to is_rocm_aiter_moe_enabled().
  • Set the current vllm_config global value around the _process_weights_after_loading call.
  • Use get_current_vllm_config to obtain the model architecture name the first time is_rocm_aiter_moe_enabled is invoked (see the sketch below).
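
A rough sketch of what these bullets add up to, assuming the import paths and architecture names shown here (the real helper also checks the ROCm platform and the AITER environment flags, which are omitted):

from functools import cache

# Import paths are assumptions; the PR itself calls get_architecture_class_name
# and get_current_vllm_config as shown in the diff hunks quoted later on.
from vllm.config import get_current_vllm_config
from vllm.model_executor.models.registry import get_architecture_class_name

# Illustrative allow-list; the PR restricts the AITER MoE path to Mixtral and
# DeepSeek models (the exact architecture strings here are assumptions).
_AITER_MOE_ARCHS = ("MixtralForCausalLM", "DeepseekV2ForCausalLM")


@cache
def is_rocm_aiter_moe_enabled() -> bool:
    # get_current_vllm_config() is only valid while the config context is set,
    # hence the set_current_vllm_config(...) wrapper discussed below. Thanks to
    # the cache decorator, later calls (e.g. during the forward pass) reuse the
    # first result instead of re-reading the config.
    arch = get_architecture_class_name(get_current_vllm_config().model_config)
    return arch in _AITER_MOE_ARCHS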

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which exercises a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@SageMoore (Contributor) left a comment


Looks good. Just one question.

f"checkpoint: {weights_not_loaded}")

_process_weights_after_loading(model, model_config, target_device)
with set_current_vllm_config(vllm_config):
@SageMoore (Contributor)

I don't fully appreciate the implications of setting the config around _process_weights_after_loading. Could you explain a bit why it's necessary?

@charlifu (Contributor, Author) commented Apr 14, 2025

The current_vllm_config value is not set when _process_weights_after_loading is called. Currently, this value is set only when creating the model class, so we do not have the model information when deciding whether to use the AITER MoE kernels.

We do have other options:

  • Add a model config parameter to the _process_weights_after_loading and is_rocm_aiter_moe_enabled functions. We might have to add more such cases here later.
  • Add a private member holding the model config to the fused_moe layer class and set it when the layer is created.

Both require changing the interface. If you are concerned that exposing the current vllm config during the execution of _process_weights_after_loading could cause potential issues, I prefer the second option, since we call the is_rocm_aiter_moe_enabled function during model execution as well.
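
For reference, a minimal sketch of the second option; the constructor parameter and the private member are hypothetical, not the actual FusedMoE signature:

from typing import Optional

import torch


class FusedMoE(torch.nn.Module):
    """Sketch only: capture the model information at layer-creation time,
    when the vLLM config context is guaranteed to be set."""

    def __init__(self, num_experts: int, model_arch: Optional[str] = None):
        super().__init__()
        self.num_experts = num_experts
        # Hypothetical private member: the caller resolves the architecture
        # name (e.g. via get_current_vllm_config()) while building the model.
        self._model_arch = model_arch

    def _use_aiter_moe(self) -> bool:
        # Decided from the stored value, so it stays valid during the forward
        # pass and in _process_weights_after_loading, with no global lookup.
        return self._model_arch in ("MixtralForCausalLM", "DeepseekV2ForCausalLM")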

@tlrmchlsmth (Member) left a comment

Why restrict the AITER kernels to specific models? Is this a temporary performance hack or do you plan to maintain this list of models long term?

Comment on lines +20 to +21
model_cls_name = get_architecture_class_name(
    get_current_vllm_config().model_config)
@tlrmchlsmth (Member)

It looks like is_rocm_aiter_moe_enabled is called during the actual forward pass, when (IIUC) it's not valid to call get_current_vllm_config(), as it's not set (e.g. in dispatch_fused_experts_func). That should be resolved before landing; otherwise users will get spammed with warnings.

@charlifu (Contributor, Author) commented Apr 14, 2025

That's why I added a cache decorator: it will use the cached value during the actual forward pass.

@tlrmchlsmth (Member)

I see. But then is_rocm_aiter_moe_enabled would not be valid when instantiating a second LLM, since the cached result from the first model would be reused.

@charlifu (Contributor, Author)

Oh, yeah. You are right.

@tlrmchlsmth (Member)

Maybe we should just set the current vllm config during the forward pass in the model runner. @youkaichao WDYT?
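
A sketch of that suggestion, assuming a runner that keeps a reference to its vllm_config; the class, method, and import path below are illustrative rather than the actual model runner API:

import torch

from vllm.config import set_current_vllm_config  # import path assumed


class ModelRunnerSketch:

    def __init__(self, vllm_config, model: torch.nn.Module):
        self.vllm_config = vllm_config
        self.model = model

    @torch.inference_mode()
    def execute_model(self, *args, **kwargs):
        # Re-enter the config context on every forward pass so that helpers
        # such as is_rocm_aiter_moe_enabled() can call get_current_vllm_config()
        # safely, even when several LLM instances live in the same process.
        with set_current_vllm_config(self.vllm_config):
            return self.model(*args, **kwargs)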

@charlifu (Contributor, Author)

Why restrict the AITER kernels to specific models? Is this a temporary performance hack or do you plan to maintain this list of models long term?

We plan to maintain this list long term: the AITER assembly MoE is manually tuned to each model's characteristics. We would like to enable it only for models that have been fully vetted, to avoid performance pitfalls.

@charlifu (Contributor, Author)

@SageMoore Thanks to @tlrmchlsmth's point above, the current way of making this change is not valid when creating a second LLM.

@ProExpertProg (Collaborator)

What if we did something similar to FpLinearOp, where we initialize an object that dispatches at init time instead of forward time? I think we're trying to move all kernels in that direction anyway.
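
A rough sketch of that init-time-dispatch pattern; the class and the kernel stand-ins below are illustrative and are not vLLM's actual FpLinearOp API:

from typing import Callable

import torch


def _reference_fused_moe(hidden_states: torch.Tensor) -> torch.Tensor:
    return hidden_states  # stand-in for the default fused-MoE kernel path


def _aiter_fused_moe(hidden_states: torch.Tensor) -> torch.Tensor:
    return hidden_states  # stand-in for the AITER kernel path


class FusedMoEOp:
    """Pick the kernel implementation once, at construction time."""

    def __init__(self, use_aiter: bool):
        # The dispatch decision is frozen here, while the model (and its vLLM
        # config) is being built, so the forward pass never needs to consult
        # cached globals or get_current_vllm_config().
        self._impl: Callable[[torch.Tensor], torch.Tensor] = (
            _aiter_fused_moe if use_aiter else _reference_fused_moe)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self._impl(hidden_states)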

@tjtanaa (Collaborator) commented Apr 22, 2025

@charlifu @ProExpertProg @tlrmchlsmth I think the issue addressed by this PR is being resolved in PR #16727.

The AITER MoE kernel selection logic has been cleaned up, and it is now determined based on the arguments passed into rocm_aiter_fused_experts. There is no need to maintain a list of supported models.

https://github.com/vllm-project/vllm/pull/16727/files#diff-033e75fc2b8c4fd797f15d1ef0d8b6079996e8f0f9b1bd50cdfe53bcc3c593bc

def rocm_aiter_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
    allow_deep_gemm: bool = False,
) -> torch.Tensor:

    import aiter as rocm_aiter
    import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe

    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
        per_token_group_quant_fp8)

    # All AITER Fused MoE kernels are expecting the following datatypes
    topk_weights = topk_weights.to(torch.float32)
    topk_ids = topk_ids.to(torch.int32)

+    if (block_shape is not None) and use_fp8_w8a8:
        assert not apply_router_weight_on_input, (
            "apply_router_weight_on_input is not supported for block scaled moe"
        )

        assert w1_scale is not None
        assert w2_scale is not None

        local_E = E = w1.shape[0]
        if expert_map is not None:
            E = expert_map.numel()

        topk = topk_ids.shape[1]
        model_dim = w1.shape[-1]
        dtype = hidden_states.dtype
        # The default block sizes are 128 in AITER.
        if block_shape is None:
            block_shape = [128, 128]

        scale_blk_k = block_shape[1]

        (
            sorted_token_ids,
            sorted_weight_buf,
            sorted_expert_ids,
            num_valid_ids,
            out_asm,
        ) = rocm_aiter_asm_fmoe.moe_sorting_ck(topk_ids,
                                               topk_weights,
                                               E,
                                               model_dim,
                                               dtype,
                                               expert_mask=expert_map)

        a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
        rocm_aiter.fmoe_fp8_blockscale_g1u1(
            out_asm,
            a1,
            w1,
            w2,
            sorted_token_ids,
            sorted_weight_buf,
            sorted_expert_ids,
            num_valid_ids,
            topk,
            w1_scale.view(local_E, -1),
            w2_scale.view(local_E, -1),
            a1_scale.t().contiguous(),
            block_shape[0],
            block_shape[1],
            None,
        )
        return out_asm

+    elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
        # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
        # This applies topk_weights on the GEMM output of the first FC layer
        #  rather than the second FC.
        assert (topk_weights.dim() == 2
                ), "`topk_weights` should be in shape (num_tokens, topk)"
        assert topk_weights.shape[-1] == 1, (
            "Only support topk=1 when"
            " `apply_router_weight_on_input` is True")

        return rocm_aiter_asm_moe_tkw1(hidden_states,
                                       w1,
                                       w2,
                                       topk_weights,
                                       topk_ids,
                                       fc1_scale=w1_scale,
                                       fc2_scale=w2_scale,
                                       fc1_smooth_scale=None,
                                       fc2_smooth_scale=None,
                                       a16=False,
                                       per_tensor_quant_scale=None,
                                       expert_mask=expert_map,
                                       activation_str=activation)

    elif use_fp8_w8a8:
        assert not apply_router_weight_on_input, (
            "apply_router_weight_on_input is not supported for fp8_w8a8")
        return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
                                           w1=w1,
                                           w2=w2,
                                           topk_weight=topk_weights,
                                           topk_ids=topk_ids,
                                           fc1_scale=w1_scale,
                                           fc2_scale=w2_scale,
                                           fc1_smooth_scale=None,
                                           fc2_smooth_scale=None,
                                           a16=False)

    if apply_router_weight_on_input:
        assert (topk_weights.dim() == 2
                ), "`topk_weights` should be in shape (num_tokens, topk)"
        _, topk = topk_weights.shape
        assert (
            topk == 1
        ), "Only support topk=1 when `apply_router_weight_on_input` is True"

        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
        topk_ids = topk_ids.to(torch.int32)
        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)

    return rocm_aiter.ck_moe(hidden_states=hidden_states,
                             w1=w1,
                             w2=w2,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids)

@charlifu (Contributor, Author)

@tjtanaa Thank you for the information. Will close this.
