
Conversation

@TianQiLin666666
Collaborator

Follow-up to #7278: support cutlass_moe_fp8 for the DeepEPMoE layer.

Motivation

Modifications

Checklist


@gemini-code-assist (bot) left a comment


Summary of Changes

Hello @TianQiLin666666, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces experimental support for 8-bit floating point (FP8) quantization within the DeepEP Mixture of Experts (MoE) layer. The goal is to leverage CUTLASS kernels to provide a highly optimized execution path for MoE computations, potentially yielding significant performance improvements and a reduced memory footprint on compatible NVIDIA GPUs.

Highlights

  • New FP8 MoE Kernel: I've added a new function cutlass_moe_fp8 in python/sglang/srt/layers/moe/cutlass_moe.py. This function is designed to perform efficient Mixture of Experts (MoE) computations using 8-bit floating point (FP8) quantization, leveraging CUTLASS-like kernels. It supports per-block FP8 quantized weights and per-token-group FP8 quantized activations.
  • DeepEP MoE Layer Integration: The DeepEPMoELayer in python/sglang/srt/layers/moe/ep_moe/layer.py has been updated to integrate this new FP8 MoE path. A new forward_cutlass_moe method orchestrates the FP8 computation flow, including input scattering, two calls to cutlass_moe_fp8 for the expert layers, activation, and output gathering.
  • Conditional Execution and Compatibility: The new FP8 MoE path is conditionally enabled. It requires the SGLANG_CUTLASS_MOE environment variable to be set and performs hardware compatibility checks, specifically requiring NVIDIA GPUs with Compute Capability 9.0 (SM90) or higher and CUDA version 12.3 or newer. This ensures the optimized kernel is only used on supported configurations; a minimal sketch of this gating check follows these highlights.
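
A minimal sketch of this gating check, using an assumed helper name (_use_cutlass_moe_fp8) rather than the PR's actual code:

    import os

    import torch


    def _use_cutlass_moe_fp8() -> bool:
        # Opt-in via the SGLANG_CUTLASS_MOE environment variable.
        if not os.environ.get("SGLANG_CUTLASS_MOE"):
            return False
        if not torch.cuda.is_available() or torch.version.cuda is None:
            return False
        # Require SM90 (Hopper) or newer and CUDA 12.3 or newer; the string
        # comparison mirrors the check used in the reviewed code.
        major, _ = torch.cuda.get_device_capability()
        return major >= 9 and torch.version.cuda >= "12.3"
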
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces support for cutlass_moe_fp8 in the DeepEPMoE layer, which is a significant performance enhancement for FP8-capable hardware. The changes are well-structured, adding a new cutlass_moe_fp8 kernel wrapper and integrating it into the DeepEPMoE forward pass.

I've identified a few critical issues related to correctness and memory management that need to be addressed, along with some suggestions for improving code clarity. Specifically, there's a potential for an OOM error due to a large hardcoded workspace allocation, and a correctness issue with how m_indices is being passed to the new kernel. Please review the detailed comments.

out_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
a_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
b_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)


critical

Allocating a 1GB workspace tensor with torch.empty is risky and can easily lead to Out-of-Memory (OOM) errors, especially in environments with limited memory. This size is hardcoded and might be excessive for the actual needs of the fp8_blockwise_scaled_grouped_mm kernel.

Consider calculating the required workspace size dynamically based on the problem size or using a much smaller, more reasonable default size. For reference, the test file test_cutlass_moe.py allocates a workspace of about 7MB, which is significantly smaller. A smaller buffer or dynamic allocation would be safer and more memory-efficient.
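
As an illustration of this suggestion (the 7 MB figure simply follows the test file mentioned above; the real requirement of fp8_blockwise_scaled_grouped_mm may differ), the allocation could look like:

    import torch

    # Smaller default workspace; the size is illustrative, not a measured requirement.
    WORKSPACE_BYTES = 7 * 1024 * 1024
    device = torch.device("cuda")  # assumed to match the surrounding code
    workspace = torch.empty((WORKSPACE_BYTES,), device=device, dtype=torch.uint8)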

w=self.w13_weight_fp8[0],
w_scale=self.w13_weight_fp8[1],
c=gateup_output,
m_indices=m_indices)


critical

The m_indices tensor passed to cutlass_moe_fp8 is incorrect. The cutlass_moe_fp8 function expects m_indices to be the cumulative sum of token counts per expert to calculate problem sizes. However, the m_indices tensor passed here is the output of ep_scatter, which contains expert IDs for each token.

This will result in incorrect calculations within cutlass_moe_fp8 and likely lead to errors or wrong results.

You should compute the cumulative sum of tokens from num_recv_tokens_per_expert_gpu and pass that to cutlass_moe_fp8. For example:

m_indices_for_cutlass = torch.nn.functional.pad(
    torch.cumsum(num_recv_tokens_per_expert_gpu, dim=0, dtype=torch.int32), (1, 0)
)

Then, pass m_indices_for_cutlass to the function. The same applies to the second call to cutlass_moe_fp8.

        m_indices = torch.empty(
            all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
        )
        output_index = torch.empty_like(topk_idx)

        num_recv_tokens_per_expert_gpu = torch.tensor(
            num_recv_tokens_per_expert,
            dtype=torch.int32,
            pin_memory=True,
            device="cpu",
        ).cuda(non_blocking=True)
        expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

        ep_scatter(
            hidden_states_fp8,
            hidden_states_scale,
            topk_idx,
            num_recv_tokens_per_expert_gpu,
            expert_start_loc,
            gateup_input_fp8,
            gateup_input_scale,
            m_indices,
            output_index,
            scale_ue8m0=False,
        )
        
        m_indices_for_cutlass = torch.nn.functional.pad(
            torch.cumsum(num_recv_tokens_per_expert_gpu, dim=0, dtype=torch.int32), (1, 0)
        )

        cutlass_moe_fp8(a=gateup_input_fp8,
                        a_scale=gateup_input_scale,
                        w=self.w13_weight_fp8[0],
                        w_scale=self.w13_weight_fp8[1],
                        c=gateup_output,
                        m_indices=m_indices_for_cutlass)

a_scale,
w_scale,
a_strides,
a_strides,


high

The stride_b argument to fp8_blockwise_scaled_grouped_mm is incorrectly set to a_strides. The stride for the weight tensor w should be used instead, as a and w have different shapes and layouts.

You should define b_strides based on w's dimensions and pass it here. For example:

b_strides = torch.full((num_experts,), w.stride(1), device=device, dtype=torch.int64)

This should be defined before the fp8_blockwise_scaled_grouped_mm call.

        b_strides = torch.full((num_experts,), w.stride(1), device=device, dtype=torch.int64)
        fp8_blockwise_scaled_grouped_mm(
            ...
            a_strides,
            b_strides,

else self.w2_weight_scale
),
)
self.cutlass_moe_fp8_supported = cutlass_fp8_supported() and (torch.cuda.get_device_capability(torch.cuda.current_device())[0] == 9) and (torch.version.cuda >= "12.3")


medium

This condition is long and hard to read, and it can be simplified.

The expression cutlass_fp8_supported() and (torch.cuda.get_device_capability(torch.cuda.current_device())[0] == 9) and (torch.version.cuda >= "12.3") is also partly redundant: the cutlass_fp8_supported() function already checks device capability.

You could simplify this by extracting the capability and version checks and combining them into a more readable expression. For example:

major, _ = torch.cuda.get_device_capability()
self.cutlass_moe_fp8_supported = (
    cutlass_fp8_supported() and major == 9 and torch.version.cuda >= "12.3"
)

This assumes you only want to support Hopper (SM 90) with CUDA 12.3+, which seems to be the intent.

        major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
        self.cutlass_moe_fp8_supported = (
            cutlass_fp8_supported() and major == 9 and torch.version.cuda >= "12.3"
        )

)

def forward_cutlass_moe(self,
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],


medium

The parameter hidden_states_fp8 is a tuple of two tensors, but its name suggests it's a single tensor. This is confusing because the next line unpacks it into hidden_states_fp8, hidden_states_scale, reusing the name hidden_states_fp8.

To improve clarity, I suggest renaming the parameter to reflect that it's a tuple, for example, hidden_states_fp8_and_scale.

        hidden_states_fp8_and_scale: Tuple[torch.Tensor, torch.Tensor],
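
A hypothetical sketch of this rename (the class stub and elided parameters are placeholders, not the PR's real signature); only the parameter name and the unpacking line change:

    from typing import Tuple

    import torch


    class DeepEPMoE:  # illustrative stub only
        def forward_cutlass_moe(
            self,
            hidden_states_fp8_and_scale: Tuple[torch.Tensor, torch.Tensor],
            # ... remaining parameters unchanged ...
        ):
            # Unpacking no longer shadows the parameter name.
            hidden_states_fp8, hidden_states_scale = hidden_states_fp8_and_scale
            # ... rest of the FP8 forward pass proceeds as before ...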
