
Conversation


@ruizhang1230 ruizhang1230 commented Jun 17, 2025

Motivation

#5010
Adapted from vllm-project/vllm#15956

Modifications

Checklist

@gemini-code-assist gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @ruizhang1230, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Mixture of Experts (MoE) capabilities by integrating PPLX kernels, specifically targeting Deepseek models. It introduces a new, optimized distributed MoE backend that leverages NVSHMEM for efficient inter-GPU communication. The changes involve a modular refactoring of the MoE architecture, allowing for a clear separation of concerns between data preparation/finalization and expert computation, and includes new Triton-based kernels for batched MoE operations. This integration provides a high-performance alternative for MoE inference in distributed settings.

Highlights

  • Integration of PPLX Kernels for MoE: Introduces support for PPLX kernels to accelerate Mixture of Experts (MoE) computations, particularly for Deepseek models. This provides a new, optimized distributed MoE backend.
  • Distributed MoE Infrastructure: Implements pplx_init and pplx_finalize functions for initializing and cleaning up NVSHMEM, a critical component for efficient inter-GPU communication required by PPLX kernels in distributed environments.
  • Modular MoE Architecture: Adopts a modular design for MoE kernels, separating preparation/finalization (dispatch/combine) from expert computation. This allows for flexible integration of different communication mechanisms (like PPLX's AllToAll) with various fused MoE kernels (like the new Triton-based batched experts); see the toy sketch after this list.
  • New Triton-based Batched MoE Kernels: Adds new Triton-based kernels (fused_batched_moe.py) for efficient batched matrix multiplications within MoE layers, designed to work seamlessly with the PPLX dispatch/combine mechanism.
  • Deepseek Model Adaptation: Modifies the DeepseekV2 model's MoE layer to utilize the new PPLX-based MoE implementation when enabled, including handling shared experts and router logic.
  • Configuration Options: Introduces new server arguments (--enable-pplx-moe, --max-tokens-across-dp) to control and configure the PPLX MoE integration.
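
To make the modular split concrete, the following is a toy, self-contained sketch of the idea: one object owns dispatch/combine communication, another owns expert computation, and a thin wrapper composes the two. The class names and methods are simplified stand-ins for illustration only, not the PR's actual interfaces (which use names such as FusedMoEModularKernel and BatchedTritonExperts).

import torch


class ToyPrepareAndFinalize:
    """Owns communication: scatter tokens to expert ranks, gather results back.
    This toy version is single-rank, so dispatch is a no-op."""

    def prepare(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states  # real impl: all-to-all dispatch (e.g. PPLX AllToAll)

    def finalize(self, expert_out: torch.Tensor,
                 topk_weights: torch.Tensor) -> torch.Tensor:
        # real impl: all-to-all combine plus weighted reduction over the top-k experts
        return expert_out * topk_weights.sum(dim=-1, keepdim=True)


class ToyBatchedExperts:
    """Owns compute: apply a (single, shared) expert MLP to the dispatched tokens."""

    def __init__(self, w1: torch.Tensor, w2: torch.Tensor):
        self.w1, self.w2 = w1, w2

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x @ self.w1) @ self.w2


class ToyModularKernel:
    """Composes a communication backend with an expert kernel; swapping the
    communication object (e.g. for a PPLX-backed one) leaves the experts untouched."""

    def __init__(self, prepare_finalize, experts):
        self.prepare_finalize = prepare_finalize
        self.experts = experts

    def __call__(self, hidden_states: torch.Tensor,
                 topk_weights: torch.Tensor) -> torch.Tensor:
        x = self.prepare_finalize.prepare(hidden_states)
        y = self.experts(x)
        return self.prepare_finalize.finalize(y, topk_weights)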

@gemini-code-assist gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces support for PPLX kernels in Deepseek models, which involves significant additions to the MoE layers and distributed setup. The changes appear to be well-adapted from vLLM, but there are several areas that need attention:

  1. Error Handling and Edge Cases: Some error handling could be more specific, and edge cases (e.g., empty tensors with None router logits) in the MoE forward paths need careful review to ensure correctness, especially under CUDA graph capture or compilation.
  2. Clarity and Conventions: Minor points, such as moving imports to module level and raising NotImplementedError instead of assert False, "NYI", would improve clarity and adherence to Python conventions.
  3. Quantization Logic: The handling of in_dtype for MoEConfig and assumptions in hidden_dim_scale_bytes calculation need to be robust for current and future quantization schemes.
  4. Testing: The PR checklist indicates that unit tests, formatting, documentation, and benchmarks are pending. These are critical for validating such a substantial change, especially for new Triton kernels and distributed logic.

Overall, the PR makes good progress in integrating PPLX support. Addressing the highlighted points, particularly around correctness in edge cases and adding comprehensive tests, will be key to ensuring the stability and reliability of this feature.

Comment on lines 1547 to 1548
# TODO (bnell): this needs to be fixed for quantized types.
in_dtype=params_dtype,

high

The comment TODO (bnell): this needs to be fixed for quantized types. for in_dtype=params_dtype is important. params_dtype refers to the weight dtype, but in_dtype for MoEConfig is documented as "The activation type." If activations are quantized (e.g., to FP8), their dtype will differ from params_dtype (e.g., BF16 weights). This needs careful handling, especially for hidden_dim_bytes and hidden_dim_scale_bytes in _construct_prepare_finalize.

This is a critical point for correctness when quantization is active. Ensure that the in_dtype accurately reflects the activation's data type post-quantization if applicable, or clarify its meaning if it's intended to be the pre-quantization activation dtype.
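
As a minimal illustration of why the distinction matters (assuming BF16 weights with FP8-quantized activations, which is one plausible configuration rather than necessarily this PR's), the activation dtype and the weight dtype have different itemsizes, and the itemsize is exactly what the dispatch buffer sizing depends on:

import torch

params_dtype = torch.bfloat16          # dtype the MoE weights are created with
act_quant_dtype = torch.float8_e4m3fn  # dtype activations would have after FP8 quantization

# If MoEConfig.in_dtype is taken from params_dtype, dispatch buffers are sized
# for 2-byte elements even though 1-byte activations (plus scales) flow through.
print(params_dtype.itemsize, act_quant_dtype.itemsize)  # 2 1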

Comment on lines 1581 to 1586
topk_ids = torch.full(
    (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device,
)
topk_weights = torch.empty(
    (0, self.top_k), dtype=torch.float32, device=hidden_states.device,
)

high

When full_router_logits is None, topk_ids is initialized as an empty tensor with shape (0, self.top_k). If hidden_states is not empty in this case (e.g., hidden_states.shape[0] > 0), passing an empty topk_ids to self.fused_experts might lead to errors or unexpected behavior, as the number of tokens would mismatch.

Ensure that if full_router_logits is None (implying no routing or perhaps a special handling case), either hidden_states should also correspond to zero tokens for the MoE part, or topk_ids and topk_weights should be constructed appropriately for the number of tokens in hidden_states (e.g., if it's a passthrough or some default expert selection).
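
One way to make that precondition explicit is sketched below; the helper name and signature are illustrative, not the PR's code. It only allows the zero-row fallback when there really are zero tokens:

import torch


def empty_topk_fallback(hidden_states: torch.Tensor, top_k: int):
    # The (0, top_k) fallback is only valid when there are no tokens; otherwise
    # the caller must compute router logits and real top-k results.
    assert hidden_states.shape[0] == 0, \
        "router logits are required when hidden_states contains tokens"
    topk_ids = torch.full((0, top_k), -1, dtype=torch.int,
                          device=hidden_states.device)
    topk_weights = torch.empty((0, top_k), dtype=torch.float32,
                               device=hidden_states.device)
    return topk_ids, topk_weights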

Comment on lines 347 to 359
def forward_pplx(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch) -> torch.Tensor:
    shared_output = None
    router_logits = None
    if not forward_batch.forward_mode.is_idle() and hidden_states.shape[0] > 0:
        router_logits = self.gate(hidden_states)
        shared_output = self._forward_shared_experts(hidden_states)
    final_hidden_states = (
        self.experts(hidden_states, router_logits)
        * self.routed_scaling_factor
    )
    if shared_output is not None:
        final_hidden_states = final_hidden_states + shared_output
    return final_hidden_states

high

The forward_pplx method calculates router_logits and shared_output only if not forward_batch.forward_mode.is_idle() and hidden_states.shape[0] > 0. However, self.experts(hidden_states, router_logits) is called unconditionally. If the condition is false, router_logits will be None.

This relies on the PPlxMoE.forward method (and subsequently FusedMoEModularKernel and its components) to correctly handle router_logits=None. As noted in another comment for PPlxMoE.forward, if hidden_states is not empty while router_logits is None (leading to empty topk_ids), this could be problematic.

It would be safer to either ensure hidden_states is also empty/skipped for MoE in this case, or to pass appropriately shaped (even if dummy) router_logits if PPlxMoE expects them for non-empty hidden_states.
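
If the second option is taken (supplying shaped dummy logits rather than None), the construction could look like the sketch below. The helper name is illustrative, and whether uniform all-zero logits are semantically acceptable for idle batches is an assumption that would need to be confirmed:

import torch


def dummy_router_logits(hidden_states: torch.Tensor,
                        num_experts: int) -> torch.Tensor:
    # Uniform (all-zero) logits sized to the real token count, so downstream
    # top-k selection always sees shapes consistent with hidden_states.
    return torch.zeros((hidden_states.shape[0], num_experts),
                       dtype=torch.float32,
                       device=hidden_states.device)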

Comment on lines 1151 to 1171
def pplx_init(rank, world_size):
    has_pplx = importlib.util.find_spec("pplx_kernels") is not None

    if has_pplx and world_size > 1:
        from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                          nvshmem_get_unique_id, nvshmem_init)
        try:
            global PPLX_DID_INIT
            logger.debug(
                "Initialize NVSHMEM for PPLX kernels: rank=%d, "
                "world size=%d", rank, world_size)
            uid = nvshmem_get_unique_id(
            ) if rank == 0 else nvshmem_alloc_empty_unique_id()
            uid_gpu = uid.cuda()
            get_world_group().broadcast(uid_gpu, src=0)
            uid = uid_gpu.to(device='cpu')
            logger.debug("PPLX NVSHMEM UID = %s", uid)
            nvshmem_init(uid, rank, world_size)
            PPLX_DID_INIT = True
        except Exception as ex:
            logger.error("Failed to initialize NVSHMEM for PPLX: %s", ex)

medium

The pplx_init function correctly uses @run_once and checks for pplx_kernels. However, the error handling for NVSHMEM initialization could be more specific. Catching a generic Exception might hide specific issues that could be handled or logged differently.

Consider catching more specific exceptions if known (e.g., RuntimeError from pplx_kernels or torch.cuda.CudaError) or at least logging the type of exception in the error message for better diagnostics.
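
A minimal sketch of that narrower handling is below. Note that torch.cuda.CudaError subclasses RuntimeError, so catching RuntimeError already covers it; logging the exception type is the main diagnostic gain. The wrapper function is illustrative, not part of the PR:

import logging

logger = logging.getLogger(__name__)


def init_with_diagnostics(init_fn, *args) -> bool:
    try:
        init_fn(*args)
        return True
    except RuntimeError as ex:  # torch.cuda.CudaError is a RuntimeError subclass
        logger.error("PPLX NVSHMEM init failed (%s): %s",
                     type(ex).__name__, ex)
        return False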

Comment on lines 1180 to 1183
from sglang.srt.layers.moe.ep_moe.layer import (
    _all_to_all_cache)
_all_to_all_cache.destroy()
nvshmem_finalize()

medium

The import from sglang.srt.layers.moe.ep_moe.layer import _all_to_all_cache inside pplx_finalize is a local import. While this works, it's generally preferred to have imports at the top of the file for clarity and to avoid potential circular import issues, though in this specific @run_once context, it might be acceptable to delay the import until first use.

If _all_to_all_cache is lightweight to import or frequently used, consider moving the import to the top. If it's heavy or has specific initialization dependencies tied to PPLX, keeping it local might be justified, but it's worth a comment explaining why if so.

Comment on lines 570 to 573
assert (torch.compiler.is_compiling()
        or torch.cuda.is_current_stream_capturing()
        or torch.all(expert_num_tokens <= max_num_tokens * num_dp)), (
            f"{expert_num_tokens} <= {max_num_tokens * num_dp}")

medium

The assertion torch.all(expert_num_tokens <= max_num_tokens * num_dp) is only active when not compiling or capturing. If this condition is critical for correctness (e.g., to prevent out-of-bounds access in the subsequent loop), relying on it only outside of compilation/capture might be risky.

If expert_num_tokens can exceed max_num_tokens * num_dp during compilation/capture and this leads to issues, the logic within the loop or the way num is determined might need adjustment for the compiled/captured path as well.
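
If the bound must also hold on the compiled/captured path, one capture-safe alternative is to clamp the counts instead of asserting on them, as in the sketch below. Whether clamping (rather than failing) is semantically acceptable here is an assumption the kernel authors would need to confirm:

import torch


def clamp_expert_counts(expert_num_tokens: torch.Tensor,
                        max_num_tokens: int, num_dp: int) -> torch.Tensor:
    # Keeps the per-expert loop bound within the allocated buffer even when the
    # assert above is skipped during compilation / CUDA graph capture.
    return torch.clamp(expert_num_tokens, max=max_num_tokens * num_dp)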

Comment on lines 1282 to 1284
# TODO: can we do del self._cache?
for _, a2a in self._cache.items():
    a2a.destroy()

medium

The destroy method in AllToAllCache iterates through self._cache.items() to call a2a.destroy(). While WeakValueDictionary automatically removes items when their values are garbage collected, explicitly clearing the dictionary after destroying its contents might be cleaner and more predictable, especially if a2a.destroy() has side effects that might interact with garbage collection or if there's a desire to ensure the cache is fully emptied immediately.

Consider adding self._cache.clear() after the loop if immediate and full clearance is desired.
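
A minimal sketch of that cleanup is below; the class and attribute names are simplified, and the real cache may hold additional bookkeeping (e.g. a lock):

import weakref


class AllToAllCacheSketch:
    def __init__(self):
        self._cache = weakref.WeakValueDictionary()

    def destroy(self) -> None:
        # Snapshot the values: a WeakValueDictionary can change size during iteration.
        for a2a in list(self._cache.values()):
            a2a.destroy()
        # Drop remaining references immediately rather than waiting for GC.
        self._cache.clear()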

Comment on lines 1416 to 1418
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
                        ((moe.hidden_dim + moe.block_size - 1) //
                         moe.block_size * torch.float32.itemsize)))

medium

The calculation for hidden_dim_scale_bytes seems to assume that scales are only needed if moe.in_dtype.itemsize == 1 (i.e., for 8-bit types like FP8 or INT8). If other quantized types are introduced in the future that might also require scales (even if their itemsize is not 1), this logic might need to be generalized.

For now, this is likely correct given the current quantization schemes, but it's a point of potential brittleness if new quantization types are added. A comment explaining this assumption could be helpful.
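
To spell the assumption out numerically, the sketch below mirrors the calculation above with illustrative values (the helper, the hidden size of 7168, and the block size of 128 are examples, not taken from the PR): scales are one float32 per block of hidden elements and are only counted for 1-byte activation dtypes:

import torch


def per_token_dispatch_bytes(hidden_dim: int, in_dtype: torch.dtype,
                             block_size: int = 128) -> tuple[int, int]:
    hidden_dim_bytes = hidden_dim * in_dtype.itemsize
    if in_dtype.itemsize == 1:  # 8-bit activations (FP8/INT8) carry block scales
        scale_bytes = ((hidden_dim + block_size - 1) // block_size
                       ) * torch.float32.itemsize
    else:                       # wider activation dtypes: no scales under this scheme
        scale_bytes = 0
    return hidden_dim_bytes, scale_bytes


print(per_token_dispatch_bytes(7168, torch.bfloat16))       # -> (14336, 0)
print(per_token_dispatch_bytes(7168, torch.float8_e4m3fn))  # -> (7168, 224)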

Comment on lines 1554 to 1564
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
self.experts = BatchedTritonExperts(
    max_num_tokens=MOE_DP_CHUNK_SIZE,
    world_size=world_size,
    dp_size=dp_size,
    use_fp8_w8a8=False,
    use_int8_w8a8=False,
    use_int8_w8a16=False,
    use_int4_w4a16=False,
    block_shape=None,
)

medium

The line experts: Optional[FusedMoEPermuteExpertsUnpermute] = None is declared but not used. It seems like a leftover from a previous thought or refactoring. It can be safely removed.

Suggested change

Before:

experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
self.experts = BatchedTritonExperts(
    max_num_tokens=MOE_DP_CHUNK_SIZE,
    world_size=world_size,
    dp_size=dp_size,
    use_fp8_w8a8=False,
    use_int8_w8a8=False,
    use_int8_w8a16=False,
    use_int4_w4a16=False,
    block_shape=None,
)

After:

dp_size = int(moe.ep_size // moe.dp_size)
self.experts = BatchedTritonExperts(
    max_num_tokens=MOE_DP_CHUNK_SIZE,
    world_size=world_size,
    dp_size=dp_size,
    use_fp8_w8a8=False,
    use_int8_w8a8=False,
    use_int8_w8a16=False,
    use_int4_w4a16=False,
    block_shape=None,
)

Comment on lines +174 to +182
def run_once(f: Callable[P, None]) -> Callable[P, None]:

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        if not wrapper.has_run:  # type: ignore[attr-defined]
            wrapper.has_run = True  # type: ignore[attr-defined]
            return f(*args, **kwargs)

    wrapper.has_run = False  # type: ignore[attr-defined]
    return wrapper

medium

The run_once decorator is a common pattern. However, using type: ignore[attr-defined] for wrapper.has_run can sometimes mask actual issues if the function f itself has complex interactions with attributes. A slightly more robust way to handle state on a wrapper function, especially if you might have multiple instances of such decorated functions, is to use a closure or a class-based decorator.

For this simple case, it's likely fine, but for more complex scenarios or if this pattern proliferates, consider a more structured approach to state management in decorators. For instance, using functools.wraps(f) is also good practice for decorators to preserve metadata of the wrapped function.

import functools
import threading

# (Callable and P are assumed to already be imported in the target module, as
# they are used by the original run_once above.)


def run_once(f: Callable[P, None]) -> Callable[P, None]:
    has_run_lock = threading.Lock()
    _has_run = False  # Use a non-local variable for state

    @functools.wraps(f)  # Preserve function metadata
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        nonlocal _has_run
        with has_run_lock:
            if not _has_run:
                _has_run = True
                return f(*args, **kwargs)
    return wrapper

@ch-wan ch-wan self-assigned this Jun 17, 2025
support pplx moe
support attn_tp_size > 1
support moe_dense_tp_size = 1 (Incompatible with cuda graph)

rizar commented Aug 28, 2025

Are there plans to finish this PR?

@ruizhang1230 (Author)

Are there plans to finish this PR?

Yes, I will refactor the code for this PR according to the latest requirements by this weekend.

This was referenced Sep 13, 2025
@Blueblack319

When will this PR be ready to merge?
