Support PPLX kernels for Deepseek models #7272
Conversation
Summary of Changes
Hello @ruizhang1230, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the Mixture of Experts (MoE) capabilities by integrating PPLX kernels, specifically targeting Deepseek models. It introduces a new, optimized distributed MoE backend that leverages NVSHMEM for efficient inter-GPU communication. The changes involve a modular refactoring of the MoE architecture, allowing for a clear separation of concerns between data preparation/finalization and expert computation, and includes new Triton-based kernels for batched MoE operations. This integration provides a high-performance alternative for MoE inference in distributed settings.
Highlights
- Integration of PPLX Kernels for MoE: Introduces support for PPLX kernels to accelerate Mixture of Experts (MoE) computations, particularly for Deepseek models. This provides a new, optimized distributed MoE backend.
- Distributed MoE Infrastructure: Implements `pplx_init` and `pplx_finalize` functions for initializing and cleaning up NVSHMEM, a critical component for efficient inter-GPU communication required by PPLX kernels in distributed environments.
- Modular MoE Architecture: Adopts a modular design for MoE kernels, separating preparation/finalization (dispatch/combine) from expert computation. This allows for flexible integration of different communication mechanisms (like PPLX's AllToAll) with various fused MoE kernels (like the new Triton-based batched experts).
- New Triton-based Batched MoE Kernels: Adds new Triton-based kernels (`fused_batched_moe.py`) for efficient batched matrix multiplications within MoE layers, designed to work seamlessly with the PPLX dispatch/combine mechanism.
- Deepseek Model Adaptation: Modifies the DeepseekV2 model's MoE layer to utilize the new PPLX-based MoE implementation when enabled, including handling shared experts and router logic.
- Configuration Options: Introduces new server arguments (`--enable-pplx-moe`, `--max-tokens-across-dp`) to control and configure the PPLX MoE integration.
Code Review
This pull request introduces support for PPLX kernels in Deepseek models, which involves significant additions to the MoE layers and distributed setup. The changes appear to be well-adapted from vLLM, but there are several areas that need attention:
- Error Handling and Edge Cases: Some error handling could be more specific, and edge cases (e.g., empty tensors with `None` router logits) in the MoE forward paths need careful review to ensure correctness, especially under CUDA graph capture or compilation.
- Clarity and Conventions: Minor points on import locations and the use of `NotImplementedError` over `assert False, "NYI"` would improve clarity and adherence to Python conventions.
- Quantization Logic: The handling of `in_dtype` for `MoEConfig` and the assumptions in the `hidden_dim_scale_bytes` calculation need to be robust for current and future quantization schemes.
- Testing: The PR checklist indicates that unit tests, formatting, documentation, and benchmarks are pending. These are critical for validating such a substantial change, especially for new Triton kernels and distributed logic.
Overall, the PR makes good progress in integrating PPLX support. Addressing the highlighted points, particularly around correctness in edge cases and adding comprehensive tests, will be key to ensuring the stability and reliability of this feature.
    # TODO (bnell): this needs to be fixed for quantized types.
    in_dtype=params_dtype,
The comment TODO (bnell): this needs to be fixed for quantized types. for in_dtype=params_dtype is important. params_dtype refers to the weight dtype, but in_dtype for MoEConfig is documented as "The activation type." If activations are quantized (e.g., to FP8), their dtype will differ from params_dtype (e.g., BF16 weights). This needs careful handling, especially for hidden_dim_bytes and hidden_dim_scale_bytes in _construct_prepare_finalize.
This is a critical point for correctness when quantization is active. Ensure that the in_dtype accurately reflects the activation's data type post-quantization if applicable, or clarify its meaning if it's intended to be the pre-quantization activation dtype.
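A minimal sketch of one way to make the activation dtype explicit; the `quant_config` object and its `activation_dtype` attribute are hypothetical names used for illustration, not the actual sglang API:

```python
import torch


def resolve_in_dtype(params_dtype: torch.dtype, quant_config=None) -> torch.dtype:
    """Return the dtype of activations as they enter MoE dispatch.

    When activations are quantized (e.g. to FP8), this differs from the weight
    dtype, and the dispatch buffer sizing (hidden_dim_bytes /
    hidden_dim_scale_bytes) must be computed from the quantized dtype.
    """
    # `activation_dtype` is an illustrative attribute, not a real config field.
    act_dtype = getattr(quant_config, "activation_dtype", None)
    if act_dtype is not None:
        return act_dtype  # e.g. torch.float8_e4m3fn for FP8 activations
    return params_dtype  # unquantized path: activations share the weight dtype
```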
    topk_ids = torch.full(
        (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device,
    )
    topk_weights = torch.empty(
        (0, self.top_k), dtype=torch.float32, device=hidden_states.device,
    )
When full_router_logits is None, topk_ids is initialized as an empty tensor with shape (0, self.top_k). If hidden_states is not empty in this case (e.g., hidden_states.shape[0] > 0), passing an empty topk_ids to self.fused_experts might lead to errors or unexpected behavior, as the number of tokens would mismatch.
Ensure that if full_router_logits is None (implying no routing or perhaps a special handling case), either hidden_states should also correspond to zero tokens for the MoE part, or topk_ids and topk_weights should be constructed appropriately for the number of tokens in hidden_states (e.g., if it's a passthrough or some default expert selection).
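If a `None` `full_router_logits` is only ever expected alongside zero tokens, a small guard (a sketch reusing the names from the hunk above, not code from this PR) would turn the silent shape mismatch into an explicit error:

```python
# Illustrative guard: fail loudly instead of routing a non-empty batch with
# 0-row topk tensors.
if full_router_logits is None and hidden_states.shape[0] > 0:
    raise RuntimeError(
        "full_router_logits is None but hidden_states carries "
        f"{hidden_states.shape[0]} tokens; expected 0 on the PPLX MoE path"
    )
```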
    def forward_pplx(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch) -> torch.Tensor:
        shared_output = None
        router_logits = None
        if not forward_batch.forward_mode.is_idle() and hidden_states.shape[0] > 0:
            router_logits = self.gate(hidden_states)
            shared_output = self._forward_shared_experts(hidden_states)
        final_hidden_states = (
            self.experts(hidden_states, router_logits)
            * self.routed_scaling_factor
        )
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        return final_hidden_states
The forward_pplx method calculates router_logits and shared_output only if not forward_batch.forward_mode.is_idle() and hidden_states.shape[0] > 0. However, self.experts(hidden_states, router_logits) is called unconditionally. If the condition is false, router_logits will be None.
This relies on the PPlxMoE.forward method (and subsequently FusedMoEModularKernel and its components) to correctly handle router_logits=None. As noted in another comment for PPlxMoE.forward, if hidden_states is not empty while router_logits is None (leading to empty topk_ids), this could be problematic.
It would be safer to either ensure hidden_states is also empty/skipped for MoE in this case, or to pass appropriately shaped (even if dummy) router_logits if PPlxMoE expects them for non-empty hidden_states.
    def pplx_init(rank, world_size):
        has_pplx = importlib.util.find_spec("pplx_kernels") is not None

        if has_pplx and world_size > 1:
            from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                              nvshmem_get_unique_id, nvshmem_init)
            try:
                global PPLX_DID_INIT
                logger.debug(
                    "Initialize NVSHMEM for PPLX kernels: rank=%d, "
                    "world size=%d", rank, world_size)
                uid = nvshmem_get_unique_id(
                ) if rank == 0 else nvshmem_alloc_empty_unique_id()
                uid_gpu = uid.cuda()
                get_world_group().broadcast(uid_gpu, src=0)
                uid = uid_gpu.to(device='cpu')
                logger.debug("PPLX NVSHMEM UID = %s", uid)
                nvshmem_init(uid, rank, world_size)
                PPLX_DID_INIT = True
            except Exception as ex:
                logger.error("Failed to initialize NVSHMEM for PPLX: %s", ex)
The pplx_init function correctly uses @run_once and checks for pplx_kernels. However, the error handling for NVSHMEM initialization could be more specific. Catching a generic Exception might hide specific issues that could be handled or logged differently.
Consider catching more specific exceptions if known (e.g., RuntimeError from pplx_kernels or torch.cuda.CudaError) or at least logging the type of exception in the error message for better diagnostics.
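A sketch of the suggested narrowing; exactly which exception types `pplx_kernels` raises here is an assumption, and the surrounding names (`uid`, `rank`, `world_size`, `logger`, `PPLX_DID_INIT`) come from the hunk above:

```python
try:
    nvshmem_init(uid, rank, world_size)
    PPLX_DID_INIT = True
except (RuntimeError, torch.cuda.CudaError) as ex:
    # Log the concrete exception type so NVSHMEM vs. CUDA failures are
    # distinguishable; unrelated exception types propagate instead of being
    # silently swallowed.
    logger.error("Failed to initialize NVSHMEM for PPLX (%s): %s",
                 type(ex).__name__, ex)
```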
    from sglang.srt.layers.moe.ep_moe.layer import (
        _all_to_all_cache)
    _all_to_all_cache.destroy()
    nvshmem_finalize()
The import from sglang.srt.layers.moe.ep_moe.layer import _all_to_all_cache inside pplx_finalize is a local import. While this works, it's generally preferred to have imports at the top of the file for clarity and to avoid potential circular import issues, though in this specific @run_once context, it might be acceptable to delay the import until first use.
If _all_to_all_cache is lightweight to import or frequently used, consider moving the import to the top. If it's heavy or has specific initialization dependencies tied to PPLX, keeping it local might be justified, but it's worth a comment explaining why if so.
    assert (torch.compiler.is_compiling()
            or torch.cuda.is_current_stream_capturing()
            or torch.all(expert_num_tokens <= max_num_tokens * num_dp)), (
        f"{expert_num_tokens} <= {max_num_tokens * num_dp}")
The assertion torch.all(expert_num_tokens <= max_num_tokens * num_dp) is only active when not compiling or capturing. If this condition is critical for correctness (e.g., to prevent out-of-bounds access in the subsequent loop), relying on it only outside of compilation/capture might be risky.
If expert_num_tokens can exceed max_num_tokens * num_dp during compilation/capture and this leads to issues, the logic within the loop or the way num is determined might need adjustment for the compiled/captured path as well.
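If the bound must also hold on the compiled/captured path, one option (a sketch only; variable names follow the hunk above) is to enforce it on the data instead of relying on the skipped assert:

```python
# Illustrative: clamp the per-expert token counts so the bound also holds under
# torch.compile and CUDA graph capture, where the assertion above is bypassed.
expert_num_tokens = torch.clamp(expert_num_tokens, max=max_num_tokens * num_dp)
```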
    # TODO: can we do del self._cache?
    for _, a2a in self._cache.items():
        a2a.destroy()
The destroy method in AllToAllCache iterates through self._cache.items() to call a2a.destroy(). While WeakValueDictionary automatically removes items when their values are garbage collected, explicitly clearing the dictionary after destroying its contents might be cleaner and more predictable, especially if a2a.destroy() has side effects that might interact with garbage collection or if there's a desire to ensure the cache is fully emptied immediately.
Consider adding self._cache.clear() after the loop if immediate and full clearance is desired.
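A minimal sketch of that suggestion, assuming `destroy` is the `AllToAllCache` method shown above and `self._cache` is its `WeakValueDictionary`:

```python
def destroy(self):
    # Tear down every cached all-to-all handle, then empty the cache explicitly
    # rather than waiting for WeakValueDictionary entries to be collected.
    for _, a2a in self._cache.items():
        a2a.destroy()
    self._cache.clear()
```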
    hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
                            ((moe.hidden_dim + moe.block_size - 1) //
                             moe.block_size * torch.float32.itemsize)))
The calculation for hidden_dim_scale_bytes seems to assume that scales are only needed if moe.in_dtype.itemsize == 1 (i.e., for 8-bit types like FP8 or INT8). If other quantized types are introduced in the future that might also require scales (even if their itemsize is not 1), this logic might need to be generalized.
For now, this is likely correct given the current quantization schemes, but it's a point of potential brittleness if new quantization types are added. A comment explaining this assumption could be helpful.
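For instance, the assumption could be recorded next to the expression; this is only a sketch of such a comment around the hunk quoted above:

```python
# Scale bytes are only reserved for 8-bit activation types (FP8/INT8): one
# float32 scale per `block_size` chunk of the hidden dim. If a future quantized
# type with itemsize != 1 also needs scales, this condition must be generalized.
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
                        ((moe.hidden_dim + moe.block_size - 1) //
                         moe.block_size * torch.float32.itemsize))
```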
    experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
    self.experts = BatchedTritonExperts(
        max_num_tokens=MOE_DP_CHUNK_SIZE,
        world_size=world_size,
        dp_size=dp_size,
        use_fp8_w8a8=False,
        use_int8_w8a8=False,
        use_int8_w8a16=False,
        use_int4_w4a16=False,
        block_shape=None,
    )
The line experts: Optional[FusedMoEPermuteExpertsUnpermute] = None is declared but not used. It seems like a leftover from a previous thought or refactoring. It can be safely removed.
Suggested change:

    dp_size = int(moe.ep_size // moe.dp_size)
    self.experts = BatchedTritonExperts(
        max_num_tokens=MOE_DP_CHUNK_SIZE,
        world_size=world_size,
        dp_size=dp_size,
        use_fp8_w8a8=False,
        use_int8_w8a8=False,
        use_int8_w8a16=False,
        use_int4_w4a16=False,
        block_shape=None,
    )
    def run_once(f: Callable[P, None]) -> Callable[P, None]:

        def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
            if not wrapper.has_run:  # type: ignore[attr-defined]
                wrapper.has_run = True  # type: ignore[attr-defined]
                return f(*args, **kwargs)

        wrapper.has_run = False  # type: ignore[attr-defined]
        return wrapper
The run_once decorator is a common pattern. However, using type: ignore[attr-defined] for wrapper.has_run can sometimes mask actual issues if the function f itself has complex interactions with attributes. A slightly more robust way to handle state on a wrapper function, especially if you might have multiple instances of such decorated functions, is to use a closure or a class-based decorator.
For this simple case, it's likely fine, but for more complex scenarios or if this pattern proliferates, consider a more structured approach to state management in decorators. For instance, using functools.wraps(f) is also good practice for decorators to preserve metadata of the wrapped function.
    import functools
    import threading
    from typing import Callable, ParamSpec  # Python 3.10+; use typing_extensions.ParamSpec on older versions

    P = ParamSpec("P")

    def run_once(f: Callable[P, None]) -> Callable[P, None]:
        has_run_lock = threading.Lock()
        _has_run = False  # Use a non-local variable for state

        @functools.wraps(f)  # Preserve function metadata
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
            nonlocal _has_run
            with has_run_lock:
                if not _has_run:
                    _has_run = True
                    return f(*args, **kwargs)

        return wrapper

- support pplx moe
- support attn_tp_size > 1
- support moe_dense_tp_size = 1 (Incompatible with cuda graph)
Are these plans to finish this PR?

Yes, I will refactor the code for this PR according to the latest requirements by this weekend.

When will this PR be ready to merge?
Motivation
#5010
Adapted from vllm-project/vllm#15956
Modifications
Checklist