From f9da4cf4ea031985b8c619a6c208b5053a8ae4aa Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:00:36 +0800 Subject: [PATCH 01/65] more --- python/sglang/srt/server_args.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 1a19bbea225..debcdf8bb79 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -20,7 +20,7 @@ import os import random import tempfile -from typing import List, Optional +from typing import List, Optional, Literal from sglang.srt.hf_transformers_utils import check_gguf_file from sglang.srt.reasoning_parser import ReasoningParser @@ -161,7 +161,7 @@ class ServerArgs: enable_dp_attention: bool = False enable_ep_moe: bool = False enable_deepep_moe: bool = False - deepep_mode: Optional[str] = "auto" + deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: Optional[int] = None From f51b1d22393cf6d1819bdfb612cdb0e848c1286b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:03:13 +0800 Subject: [PATCH 02/65] more --- python/sglang/srt/utils.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 498bc58ccd0..1208dae5363 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -37,6 +37,7 @@ import traceback import warnings from contextlib import contextmanager +from enum import Enum from functools import lru_cache from importlib.metadata import PackageNotFoundError, version from importlib.util import find_spec @@ -53,10 +54,10 @@ import torch.distributed as dist import triton import zmq +from PIL import Image from decord import VideoReader, cpu from fastapi.responses import ORJSONResponse from packaging import version as pkg_version -from PIL import Image from starlette.routing import Mount from torch import nn from torch.func import functional_call @@ -913,10 +914,10 @@ def get_zmq_socket( context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool ): mem = psutil.virtual_memory() - total_mem = mem.total / 1024**3 - available_mem = mem.available / 1024**3 + total_mem = mem.total / 1024 ** 3 + available_mem = mem.available / 1024 ** 3 if total_mem > 32 and available_mem > 16: - buf_size = int(0.5 * 1024**3) + buf_size = int(0.5 * 1024 ** 3) else: buf_size = -1 @@ -1477,10 +1478,10 @@ def dataclass_to_string_truncated( return ( "{" + ", ".join( - f"'{k}': {dataclass_to_string_truncated(v, max_length)}" - for k, v in data.items() - if k not in skip_names - ) + f"'{k}': {dataclass_to_string_truncated(v, max_length)}" + for k, v in data.items() + if k not in skip_names + ) + "}" ) elif dataclasses.is_dataclass(data): @@ -1488,10 +1489,10 @@ def dataclass_to_string_truncated( return ( f"{data.__class__.__name__}(" + ", ".join( - f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" - for f in fields - if f.name not in skip_names - ) + f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" + for f in fields + if f.name not in skip_names + ) + ")" ) else: @@ -1681,7 +1682,7 @@ def configure_ipv6(dist_init_addr): port_str = None if len(addr) > end + 1: if addr[end + 1] == ":": - port_str = addr[end + 2 :] + port_str = addr[end + 2:] else: raise ValueError("received IPv6 address format: expected ':' after ']'") @@ -1819,7 +1820,7 @@ def 
retry( if not should_retry(e): raise Exception(f"retry() observe errors that should not be retried.") - delay = min(initial_delay * (2**try_index), max_delay) * ( + delay = min(initial_delay * (2 ** try_index), max_delay) * ( 0.75 + 0.25 * random.random() ) @@ -1838,3 +1839,10 @@ def flatten_nested_list(nested_list): ] else: return [nested_list] + + +# TODO where should we put it? +class DeepEPMode(Enum): + NORMAL = "normal" + LOW_LATENCY = "low_latency" + AUTO = "auto" From 5662e6f75e07dc1d8123105824a29e6f1f4ff983 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:04:45 +0800 Subject: [PATCH 03/65] more --- python/sglang/srt/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 1208dae5363..2266dc08c48 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1843,6 +1843,6 @@ def flatten_nested_list(nested_list): # TODO where should we put it? class DeepEPMode(Enum): - NORMAL = "normal" - LOW_LATENCY = "low_latency" - AUTO = "auto" + normal = "normal" + low_latency = "low_latency" + auto = "auto" From 7571008b1282f5667f340d0342b7ec204fbb0724 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:05:14 +0800 Subject: [PATCH 04/65] more --- python/sglang/srt/models/deepseek_v2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 6aaa3744a86..7e17ee2693e 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -70,7 +70,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import add_prefix, is_cuda, is_hip +from sglang.srt.utils import add_prefix, is_cuda, is_hip, DeepEPMode _is_hip = is_hip() _is_cuda = is_cuda() @@ -215,7 +215,7 @@ def __init__( topk_group=config.topk_group, correction_bias=self.gate.e_score_correction_bias, prefix=add_prefix("experts", prefix), - deepep_mode=global_server_args_dict["deepep_mode"], + deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]], ) if config.n_shared_experts is not None: @@ -264,7 +264,7 @@ def __init__( num_local_experts=config.n_routed_experts // self.tp_size, hidden_size=config.hidden_size, params_dtype=config.torch_dtype, - deepep_mode=global_server_args_dict["deepep_mode"], + deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]], async_finish=True, # TODO return_recv_hook=True, ) From 1480474d9291e4b79be5da9e5994e1ac7d16b925 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:05:41 +0800 Subject: [PATCH 05/65] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 814dc469e93..877eb5d1626 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -38,7 +38,7 @@ ) from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod from sglang.srt.model_executor.forward_batch_info import ForwardMode -from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs +from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs, DeepEPMode _is_cuda = is_cuda() @@ -814,7 +814,7 @@ def __init__( correction_bias: 
Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, activation: str = "silu", - deepep_mode: str = "auto", + deepep_mode: DeepEPMode = DeepEPMode.auto, ): super().__init__( num_experts, From f08c09e628377341153033d3bf43339857360b01 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:05:50 +0800 Subject: [PATCH 06/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index f4e673535f3..5c0f1a5eb5d 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -1,3 +1,5 @@ +from sglang.srt.utils import DeepEPMode + try: from deep_ep import Buffer @@ -98,7 +100,7 @@ def __init__( num_local_experts: int = None, hidden_size: int = None, params_dtype: torch.dtype = None, - deepep_mode: str = "auto", + deepep_mode: DeepEPMode = DeepEPMode.auto, async_finish: bool = False, return_recv_hook: bool = False, ): From 1036ab9691bcdd960ce3b4f31e5cb5c668f830d5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:06:45 +0800 Subject: [PATCH 07/65] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 2 +- python/sglang/srt/utils.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 877eb5d1626..6274e42cda1 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -834,7 +834,7 @@ def __init__( activation, ) self.deepep_mode = deepep_mode - if self.deepep_mode in ["low_latency", "auto"]: + if self.deepep_mode.enable_low_latency(): assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm" self.w13_weight_fp8 = ( self.w13_weight, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 2266dc08c48..18ded2ccce9 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1846,3 +1846,9 @@ class DeepEPMode(Enum): normal = "normal" low_latency = "low_latency" auto = "auto" + + def enable_normal(self): + return self in [DeepEPMode.normal, DeepEPMode.auto] + + def enable_low_latency(self): + return self in [DeepEPMode.low_latency, DeepEPMode.auto] From ad1dfce0b174fbcf4079c401347b96851a5fb2b1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:09:14 +0800 Subject: [PATCH 08/65] more --- python/sglang/srt/utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 18ded2ccce9..c19ff2c7246 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1852,3 +1852,13 @@ def enable_normal(self): def enable_low_latency(self): return self in [DeepEPMode.low_latency, DeepEPMode.auto] + + def resolve(self, forward_mode): + if self != DeepEPMode.auto: + return self + + if forward_mode.is_decode(): + return DeepEPMode.low_latency + else: + return DeepEPMode.normal + From c6859e07900ade9e044dc6269c086376019eda38 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:09:54 +0800 Subject: [PATCH 09/65] more --- python/sglang/srt/layers/moe/ep_moe/layer.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 6274e42cda1..18c2c52f323 100644 --- 
a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -47,7 +47,6 @@ else: from vllm import _custom_ops as vllm_ops - logger = logging.getLogger(__name__) _is_hip = is_hip() @@ -259,7 +258,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): BLOCK_SIZE=512, ) - seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] + seg_indptr_cur_rank = seg_indptr[self.start_expert_id: self.end_expert_id + 2] weight_indices_cur_rank = torch.arange( 0, self.num_experts_per_partition, @@ -436,7 +435,7 @@ def weight_loader( elif shard_id == "w1": param.data[expert_id][: self.intermediate_size, :] = loaded_weight elif shard_id == "w3": - param.data[expert_id][self.intermediate_size :, :] = loaded_weight + param.data[expert_id][self.intermediate_size:, :] = loaded_weight else: raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}") @@ -468,11 +467,11 @@ def _load_fp8_scale( block_n, block_k = self.block_shape[0], self.block_shape[1] if shard_id == "w1": param_data[expert_id][ - : (self.intermediate_size + block_n - 1) // block_n, : + : (self.intermediate_size + block_n - 1) // block_n, : ] = loaded_weight elif shard_id == "w3": param_data[expert_id][ - (self.intermediate_size + block_n - 1) // block_n :, : + (self.intermediate_size + block_n - 1) // block_n:, : ] = loaded_weight else: # w2 param_data[expert_id] = loaded_weight @@ -858,13 +857,10 @@ def forward( expected_m: int, forward_mode: ForwardMode, ): - if self.deepep_mode == "normal" or ( - self.deepep_mode == "auto" and not forward_mode.is_decode() - ): + resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) + if resolved_deepep_mode == DeepEPMode.normal: return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr) - elif self.deepep_mode == "low_latency" or ( - self.deepep_mode == "auto" and forward_mode.is_decode() - ): + elif resolved_deepep_mode == DeepEPMode.low_latency: return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m) else: raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") From 94fbf1f4abba35470ba6a4de803c54482c11b3e0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:10:36 +0800 Subject: [PATCH 10/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 5c0f1a5eb5d..3824c8ad1a0 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -122,13 +122,13 @@ def __init__( self.deepep_mode = deepep_mode self.handle = None - if self.deepep_mode in ["normal", "auto"]: # for normal / auto mode + if self.deepep_mode.enable_normal(): self.buffer_normal = get_buffer_normal( self.group, self.hidden_size * self.params_bytes ) self.async_finish = async_finish self.src2dst = None - if self.deepep_mode in ["low_latency", "auto"]: # for low_latency / auto mode + if self.deepep_mode.enable_low_latency(): """ num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256 https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding @@ -216,11 +216,11 @@ def dispatch( self.deepep_mode == "auto" and forward_mode.is_decode() ): expected_m = ( - hidden_states.shape[0] - * self.buffer_low_latency.group_size - * topk_idx.shape[1] - + num_experts - ) 
// num_experts + hidden_states.shape[0] + * self.buffer_low_latency.group_size + * topk_idx.shape[1] + + num_experts + ) // num_experts hidden_states, masked_m, event, hook = self.dispatch_low_latency( hidden_states, topk_idx, From 0004ba713aa6ab0c03a8582939a9177ad221b293 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:11:03 +0800 Subject: [PATCH 11/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 3824c8ad1a0..4b42590755f 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -198,9 +198,8 @@ def dispatch( ) expected_m = 0 - if self.deepep_mode == "normal" or ( - self.deepep_mode == "auto" and not forward_mode.is_decode() - ): + resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) + if resolved_deepep_mode == DeepEPMode.normal: ( hidden_states, topk_idx, @@ -212,9 +211,7 @@ def dispatch( reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute( hidden_states, topk_idx, fp8_dtype=hidden_states.dtype ) - elif self.deepep_mode == "low_latency" or ( - self.deepep_mode == "auto" and forward_mode.is_decode() - ): + elif resolved_deepep_mode == DeepEPMode.low_latency: expected_m = ( hidden_states.shape[0] * self.buffer_low_latency.group_size From 09668abdc28c7c79af26d249b75db8957fde2d20 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:11:15 +0800 Subject: [PATCH 12/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 4b42590755f..619dda959e8 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -353,9 +353,8 @@ def combine( topk_weights: torch.Tensor, forward_mode: ForwardMode, ) -> torch.Tensor: - if self.deepep_mode == "normal" or ( - self.deepep_mode == "auto" and not forward_mode.is_decode() - ): + resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) + if resolved_deepep_mode == DeepEPMode.normal: if hidden_states.shape[0] > 0: num_tokens = self.src2dst.shape[0] // self.router_topk output = torch.empty( @@ -383,9 +382,7 @@ def combine( output, ) event.current_stream_wait() if self.async_finish else () - elif self.deepep_mode == "low_latency" or ( - self.deepep_mode == "auto" and forward_mode.is_decode() - ): + elif resolved_deepep_mode == DeepEPMode.low_latency: hidden_states, event, hook = self.combine_low_latency( hidden_states, topk_idx, From 41e7d9871ad4e408dbfbea28ebd25ac3500bb46e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:13:14 +0800 Subject: [PATCH 13/65] fmt --- python/sglang/srt/layers/moe/ep_moe/layer.py | 10 +++---- .../srt/layers/moe/ep_moe/token_dispatcher.py | 10 +++---- python/sglang/srt/models/deepseek_v2.py | 2 +- python/sglang/srt/server_args.py | 2 +- python/sglang/srt/utils.py | 29 +++++++++---------- 5 files changed, 26 insertions(+), 27 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 18c2c52f323..dfecb63d940 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -38,7 +38,7 @@ ) from 
sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod from sglang.srt.model_executor.forward_batch_info import ForwardMode -from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs, DeepEPMode +from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs _is_cuda = is_cuda() @@ -258,7 +258,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): BLOCK_SIZE=512, ) - seg_indptr_cur_rank = seg_indptr[self.start_expert_id: self.end_expert_id + 2] + seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] weight_indices_cur_rank = torch.arange( 0, self.num_experts_per_partition, @@ -435,7 +435,7 @@ def weight_loader( elif shard_id == "w1": param.data[expert_id][: self.intermediate_size, :] = loaded_weight elif shard_id == "w3": - param.data[expert_id][self.intermediate_size:, :] = loaded_weight + param.data[expert_id][self.intermediate_size :, :] = loaded_weight else: raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}") @@ -467,11 +467,11 @@ def _load_fp8_scale( block_n, block_k = self.block_shape[0], self.block_shape[1] if shard_id == "w1": param_data[expert_id][ - : (self.intermediate_size + block_n - 1) // block_n, : + : (self.intermediate_size + block_n - 1) // block_n, : ] = loaded_weight elif shard_id == "w3": param_data[expert_id][ - (self.intermediate_size + block_n - 1) // block_n:, : + (self.intermediate_size + block_n - 1) // block_n :, : ] = loaded_weight else: # w2 param_data[expert_id] = loaded_weight diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 619dda959e8..2a290981627 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -213,11 +213,11 @@ def dispatch( ) elif resolved_deepep_mode == DeepEPMode.low_latency: expected_m = ( - hidden_states.shape[0] - * self.buffer_low_latency.group_size - * topk_idx.shape[1] - + num_experts - ) // num_experts + hidden_states.shape[0] + * self.buffer_low_latency.group_size + * topk_idx.shape[1] + + num_experts + ) // num_experts hidden_states, masked_m, event, hook = self.dispatch_low_latency( hidden_states, topk_idx, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 7e17ee2693e..adbc76b9e95 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -70,7 +70,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import add_prefix, is_cuda, is_hip, DeepEPMode +from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip _is_hip = is_hip() _is_cuda = is_cuda() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index debcdf8bb79..f025a4b5506 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -20,7 +20,7 @@ import os import random import tempfile -from typing import List, Optional, Literal +from typing import List, Literal, Optional from sglang.srt.hf_transformers_utils import check_gguf_file from sglang.srt.reasoning_parser import ReasoningParser diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index c19ff2c7246..f78b953a642 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -54,10 
+54,10 @@ import torch.distributed as dist import triton import zmq -from PIL import Image from decord import VideoReader, cpu from fastapi.responses import ORJSONResponse from packaging import version as pkg_version +from PIL import Image from starlette.routing import Mount from torch import nn from torch.func import functional_call @@ -914,10 +914,10 @@ def get_zmq_socket( context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool ): mem = psutil.virtual_memory() - total_mem = mem.total / 1024 ** 3 - available_mem = mem.available / 1024 ** 3 + total_mem = mem.total / 1024**3 + available_mem = mem.available / 1024**3 if total_mem > 32 and available_mem > 16: - buf_size = int(0.5 * 1024 ** 3) + buf_size = int(0.5 * 1024**3) else: buf_size = -1 @@ -1478,10 +1478,10 @@ def dataclass_to_string_truncated( return ( "{" + ", ".join( - f"'{k}': {dataclass_to_string_truncated(v, max_length)}" - for k, v in data.items() - if k not in skip_names - ) + f"'{k}': {dataclass_to_string_truncated(v, max_length)}" + for k, v in data.items() + if k not in skip_names + ) + "}" ) elif dataclasses.is_dataclass(data): @@ -1489,10 +1489,10 @@ def dataclass_to_string_truncated( return ( f"{data.__class__.__name__}(" + ", ".join( - f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" - for f in fields - if f.name not in skip_names - ) + f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" + for f in fields + if f.name not in skip_names + ) + ")" ) else: @@ -1682,7 +1682,7 @@ def configure_ipv6(dist_init_addr): port_str = None if len(addr) > end + 1: if addr[end + 1] == ":": - port_str = addr[end + 2:] + port_str = addr[end + 2 :] else: raise ValueError("received IPv6 address format: expected ':' after ']'") @@ -1820,7 +1820,7 @@ def retry( if not should_retry(e): raise Exception(f"retry() observe errors that should not be retried.") - delay = min(initial_delay * (2 ** try_index), max_delay) * ( + delay = min(initial_delay * (2**try_index), max_delay) * ( 0.75 + 0.25 * random.random() ) @@ -1861,4 +1861,3 @@ def resolve(self, forward_mode): return DeepEPMode.low_latency else: return DeepEPMode.normal - From 32dd0571bcb2395c2bfcd8c758388adb51d1c655 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:14:21 +0800 Subject: [PATCH 14/65] more --- python/sglang/srt/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index f78b953a642..ba229b1ce9b 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1841,7 +1841,6 @@ def flatten_nested_list(nested_list): return [nested_list] -# TODO where should we put it? 
class DeepEPMode(Enum): normal = "normal" low_latency = "low_latency" From 1836a7945c012b82de2e8ad97c6c07bda71e86b8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:16:02 +0800 Subject: [PATCH 15/65] more --- python/sglang/srt/models/deepseek_v2.py | 44 +++++++++---------------- 1 file changed, 15 insertions(+), 29 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index adbc76b9e95..edc845162cf 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -188,35 +188,21 @@ def __init__( if global_server_args_dict["enable_deepep_moe"] else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE) ) - if not global_server_args_dict["enable_deepep_moe"]: - self.experts = MoEImpl( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - correction_bias=self.gate.e_score_correction_bias, - prefix=add_prefix("experts", prefix), - ) - else: - self.experts = MoEImpl( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - correction_bias=self.gate.e_score_correction_bias, - prefix=add_prefix("experts", prefix), - deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]], - ) + self.experts = MoEImpl( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + correction_bias=self.gate.e_score_correction_bias, + prefix=add_prefix("experts", prefix), + **(dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]]) + if global_server_args_dict["enable_deepep_moe"] else {}), + ) if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts From 7e276bdff88878264b511c914e113d17b3b3f71a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:16:38 +0800 Subject: [PATCH 16/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 2a290981627..242ac031fb0 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -85,6 +85,18 @@ def get_buffer_low_latency( return _buffer_low_latency +class _DeepEPDispatcherBase: + TODO + + +class _DeepEPDispatcherNormal(_DeepEPDispatcherBase): + TODO + + +class _DeepEPDispatcherLowLatency(_DeepEPDispatcherBase): + TODO + + class DeepEPDispatcher: """ Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher @@ -213,11 +225,11 @@ def dispatch( ) elif resolved_deepep_mode == DeepEPMode.low_latency: expected_m = ( - hidden_states.shape[0] - * self.buffer_low_latency.group_size - * topk_idx.shape[1] - + num_experts - ) // 
num_experts + hidden_states.shape[0] + * self.buffer_low_latency.group_size + * topk_idx.shape[1] + + num_experts + ) // num_experts hidden_states, masked_m, event, hook = self.dispatch_low_latency( hidden_states, topk_idx, From ef08c3247c26ca4057519ceab856294a5a684865 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:16:50 +0800 Subject: [PATCH 17/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 242ac031fb0..0189d8ea9e0 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -97,6 +97,10 @@ class _DeepEPDispatcherLowLatency(_DeepEPDispatcherBase): TODO +class DeepEPDispatcher: + TODO + + class DeepEPDispatcher: """ Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher From 9138b761f77e081547c042cb9d589f3bed71fd41 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:17:20 +0800 Subject: [PATCH 18/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 0189d8ea9e0..1936523a909 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -86,27 +86,6 @@ def get_buffer_low_latency( class _DeepEPDispatcherBase: - TODO - - -class _DeepEPDispatcherNormal(_DeepEPDispatcherBase): - TODO - - -class _DeepEPDispatcherLowLatency(_DeepEPDispatcherBase): - TODO - - -class DeepEPDispatcher: - TODO - - -class DeepEPDispatcher: - """ - Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher - https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py - """ - def __init__( self, group: torch.distributed.ProcessGroup, @@ -159,6 +138,27 @@ def __init__( ) self.return_recv_hook = return_recv_hook + +class _DeepEPDispatcherNormal(_DeepEPDispatcherBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class _DeepEPDispatcherLowLatency(_DeepEPDispatcherBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class DeepEPDispatcher: + TODO + + +class DeepEPDispatcher: + """ + Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher + https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py + """ + def deepep_permute( self, hidden_states: torch.Tensor, From 292eae9cd347fc4d9d856c93774eff896f8ec0bb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:17:41 +0800 Subject: [PATCH 19/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 1936523a909..1eb690d1194 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -90,14 +90,14 @@ def __init__( self, group: torch.distributed.ProcessGroup, router_topk: int, - permute_fusion: bool = False, - num_experts: int = None, - num_local_experts: int = None, - hidden_size: int = None, - params_dtype: torch.dtype = None, - deepep_mode: DeepEPMode = 
DeepEPMode.auto, - async_finish: bool = False, - return_recv_hook: bool = False, + permute_fusion: bool, + num_experts: int, + num_local_experts: int, + hidden_size: int, + params_dtype: torch.dtype, + deepep_mode: DeepEPMode, + async_finish: bool, + return_recv_hook: bool, ): if not use_deepep: raise ImportError( @@ -150,7 +150,20 @@ def __init__(self, *args, **kwargs): class DeepEPDispatcher: - TODO + def __init__( + self, + group: torch.distributed.ProcessGroup, + router_topk: int, + permute_fusion: bool = False, + num_experts: int = None, + num_local_experts: int = None, + hidden_size: int = None, + params_dtype: torch.dtype = None, + deepep_mode: DeepEPMode = DeepEPMode.auto, + async_finish: bool = False, + return_recv_hook: bool = False, + ): + TODO class DeepEPDispatcher: From b7b32de04ffcee73310f9ca918c71f7c59c45b97 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:18:22 +0800 Subject: [PATCH 20/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 49 ++++++++++--------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 1eb690d1194..2f92ba4493c 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -95,7 +95,6 @@ def __init__( num_local_experts: int, hidden_size: int, params_dtype: torch.dtype, - deepep_mode: DeepEPMode, async_finish: bool, return_recv_hook: bool, ): @@ -114,40 +113,38 @@ def __init__( self.params_dtype = params_dtype self.params_bytes = 2 - self.deepep_mode = deepep_mode self.handle = None - if self.deepep_mode.enable_normal(): - self.buffer_normal = get_buffer_normal( - self.group, self.hidden_size * self.params_bytes - ) - self.async_finish = async_finish - self.src2dst = None - if self.deepep_mode.enable_low_latency(): - """ - num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256 - https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding - """ - # TODO(ch-wan): allow users to set this value - self.num_max_dispatch_tokens_per_rank = 128 - self.buffer_low_latency = get_buffer_low_latency( - self.group, - self.num_max_dispatch_tokens_per_rank, - self.hidden_size, - self.num_experts, - ) - self.return_recv_hook = return_recv_hook - class _DeepEPDispatcherNormal(_DeepEPDispatcherBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.buffer_normal = get_buffer_normal( + self.group, self.hidden_size * self.params_bytes + ) + self.async_finish = async_finish + self.src2dst = None + class _DeepEPDispatcherLowLatency(_DeepEPDispatcherBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + """ + num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256 + https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding + """ + # TODO(ch-wan): allow users to set this value + self.num_max_dispatch_tokens_per_rank = 128 + self.buffer_low_latency = get_buffer_low_latency( + self.group, + self.num_max_dispatch_tokens_per_rank, + self.hidden_size, + self.num_experts, + ) + self.return_recv_hook = return_recv_hook + class DeepEPDispatcher: def __init__( @@ -163,7 +160,11 @@ def __init__( async_finish: bool = False, return_recv_hook: bool = False, ): - TODO + self.deepep_mode = deepep_mode + if self.deepep_mode.enable_normal(): + TODO + if 
self.deepep_mode.enable_low_latency(): + TODO class DeepEPDispatcher: From 7c95a37dabbc40b1082c6822a6b83f77e5b8b1d6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:19:14 +0800 Subject: [PATCH 21/65] more --- .../sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 2f92ba4493c..6fea7f6ef61 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -95,8 +95,6 @@ def __init__( num_local_experts: int, hidden_size: int, params_dtype: torch.dtype, - async_finish: bool, - return_recv_hook: bool, ): if not use_deepep: raise ImportError( @@ -117,8 +115,8 @@ def __init__( class _DeepEPDispatcherNormal(_DeepEPDispatcherBase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, async_finish: bool, **kwargs): + super().__init__(**kwargs) self.buffer_normal = get_buffer_normal( self.group, self.hidden_size * self.params_bytes @@ -128,8 +126,8 @@ def __init__(self, *args, **kwargs): class _DeepEPDispatcherLowLatency(_DeepEPDispatcherBase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, return_recv_hook: bool, **kwargs): + super().__init__(**kwargs) """ num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256 From 74b99e0c2a4eec886d0e03e09e29bac304ac5408 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:19:40 +0800 Subject: [PATCH 22/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 6fea7f6ef61..4b921dc0a0e 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -160,9 +160,9 @@ def __init__( ): self.deepep_mode = deepep_mode if self.deepep_mode.enable_normal(): - TODO + self._normal_dispatcher = _DeepEPDispatcherNormal() if self.deepep_mode.enable_low_latency(): - TODO + self._low_latency_dispatcher = _DeepEPDispatcherLowLatency() class DeepEPDispatcher: From 56ce827166ed54ef2d9355ed05565dc545253c2e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:20:20 +0800 Subject: [PATCH 23/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 4b921dc0a0e..7899bfe5249 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -159,10 +159,27 @@ def __init__( return_recv_hook: bool = False, ): self.deepep_mode = deepep_mode + + common_kwargs = dict( + group=group, + router_topk=router_topk, + permute_fusion=permute_fusion, + num_experts=num_experts, + num_local_experts=num_local_experts, + hidden_size=hidden_size, + params_dtype=params_dtype, + ) + if self.deepep_mode.enable_normal(): - self._normal_dispatcher = _DeepEPDispatcherNormal() + self._normal_dispatcher = _DeepEPDispatcherNormal( + async_finish=async_finish, + **common_kwargs, + ) if self.deepep_mode.enable_low_latency(): - 
self._low_latency_dispatcher = _DeepEPDispatcherLowLatency() + self._low_latency_dispatcher = _DeepEPDispatcherLowLatency( + return_recv_hook=return_recv_hook, + **common_kwargs, + ) class DeepEPDispatcher: From 011dfd7fafec250d94184481438b9c0fd9ca533d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:20:51 +0800 Subject: [PATCH 24/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 80 +++++++++---------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 7899bfe5249..2144ca2730d 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -124,6 +124,45 @@ def __init__(self, async_finish: bool, **kwargs): self.async_finish = async_finish self.src2dst = None + """ + Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher + https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py + """ + + def _deepep_permute( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + fp8_dtype: Optional[torch.dtype] = None, + use_fp8_w8a8: bool = False, + use_block_quant: bool = False, + ): + reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess( + topk_idx, self.num_experts + ) + num_total_tokens = reorder_topk_ids.numel() + gateup_input = torch.empty( + (int(num_total_tokens), hidden_states.shape[1]), + device=hidden_states.device, + dtype=( + fp8_dtype + if (use_fp8_w8a8 and not use_block_quant) + else hidden_states.dtype + ), + ) + # PreReorder + deepep_permute_triton_kernel[(hidden_states.shape[0],)]( + hidden_states, + gateup_input, + self.src2dst, + topk_idx, + None, + self.router_topk, + hidden_states.shape[1], + BLOCK_SIZE=512, + ) + return reorder_topk_ids, seg_indptr, gateup_input + class _DeepEPDispatcherLowLatency(_DeepEPDispatcherBase): def __init__(self, return_recv_hook: bool, **kwargs): @@ -183,45 +222,6 @@ def __init__( class DeepEPDispatcher: - """ - Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher - https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py - """ - - def deepep_permute( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - fp8_dtype: Optional[torch.dtype] = None, - use_fp8_w8a8: bool = False, - use_block_quant: bool = False, - ): - reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess( - topk_idx, self.num_experts - ) - num_total_tokens = reorder_topk_ids.numel() - gateup_input = torch.empty( - (int(num_total_tokens), hidden_states.shape[1]), - device=hidden_states.device, - dtype=( - fp8_dtype - if (use_fp8_w8a8 and not use_block_quant) - else hidden_states.dtype - ), - ) - # PreReorder - deepep_permute_triton_kernel[(hidden_states.shape[0],)]( - hidden_states, - gateup_input, - self.src2dst, - topk_idx, - None, - self.router_topk, - hidden_states.shape[1], - BLOCK_SIZE=512, - ) - return reorder_topk_ids, seg_indptr, gateup_input - def dispatch( self, hidden_states: torch.Tensor, @@ -253,7 +253,7 @@ def dispatch( ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts) event.current_stream_wait() if self.async_finish else () if hidden_states.shape[0] > 0: - reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute( + reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute( hidden_states, topk_idx, fp8_dtype=hidden_states.dtype ) elif 
resolved_deepep_mode == DeepEPMode.low_latency: From 7c715eb6c601b4c939571c17605111ea87f9814b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:21:20 +0800 Subject: [PATCH 25/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 218 +++++++++--------- 1 file changed, 109 insertions(+), 109 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 2144ca2730d..8152caa832d 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -163,6 +163,59 @@ def _deepep_permute( ) return reorder_topk_ids, seg_indptr, gateup_input + def _dispatch_normal( + self, + x: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: int, + ): + previous_event = Buffer.capture() if self.async_finish else None + + ( + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + previous_event, + ) = self.buffer_normal.get_dispatch_layout( + topk_idx, + num_experts, + previous_event=previous_event, + async_finish=self.async_finish, + allocate_on_comm_stream=previous_event is not None, + ) + + # FIXME: `handle` should be transmitted with tokens from dispatch to combine. + # However, doing this would incur an unknown synchronization error, but keeping + # `handle` as a member variable works. + ( + recv_x, + recv_topk_idx, + recv_topk_weights, + _, # num_recv_tokens_per_expert_list + self.handle, + event, + ) = self.buffer_normal.dispatch( + x, + topk_idx=topk_idx, + topk_weights=topk_weights, + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=num_tokens_per_expert, + previous_event=previous_event, + async_finish=self.async_finish, + allocate_on_comm_stream=(previous_event is not None) and self.async_finish, + ) + + return ( + recv_x, + recv_topk_idx, + recv_topk_weights, + event, + ) + class _DeepEPDispatcherLowLatency(_DeepEPDispatcherBase): def __init__(self, return_recv_hook: bool, **kwargs): @@ -182,6 +235,60 @@ def __init__(self, return_recv_hook: bool, **kwargs): ) self.return_recv_hook = return_recv_hook + def _dispatch_low_latency( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + num_max_dispatch_tokens_per_rank: int, + num_experts: int, + use_fp8: bool = False, + ): + """ + # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'. + # Please make sure to change DeepEP code in internode_ll.cu dispatch / combine as below first and then reinstall. 
+ # More details refer: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782 + + diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu + index 76ae2e2..8ecd08f 100644 + --- a/csrc/kernels/internode_ll.cu + +++ b/csrc/kernels/internode_ll.cu + @@ -310,8 +310,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, + int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8, + void* workspace, cudaStream_t stream, int phases) { + constexpr int kNumMaxTopK = 9; + - constexpr int kNumWarpsPerGroup = 10; + - constexpr int kNumWarpGroups = 3; + + constexpr int kNumWarpsPerGroup = 8; + + constexpr int kNumWarpGroups = 4; + EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections"); + + const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; + @@ -501,8 +501,8 @@ void combine(void* combined_x, + int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, + int num_topk, int num_experts, int rank, int num_ranks, + void* workspace, cudaStream_t stream, int phases) { + - constexpr int kNumWarpsPerGroup = 10; + - constexpr int kNumWarpGroups = 3; + + constexpr int kNumWarpsPerGroup = 8; + + constexpr int kNumWarpGroups = 4; + constexpr int kNumMaxTopk = 9; + + const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; + """ + + packed_recv_hidden, packed_recv_count, self.handle, event, hook = ( + self.buffer_low_latency.low_latency_dispatch( + hidden_states, + topk_idx, + num_max_dispatch_tokens_per_rank, + num_experts, + use_fp8=use_fp8, + async_finish=not self.return_recv_hook, + return_recv_hook=self.return_recv_hook, + ) + ) + return packed_recv_hidden, packed_recv_count, event, hook + class DeepEPDispatcher: def __init__( @@ -250,7 +357,7 @@ def dispatch( topk_idx, topk_weights, event, - ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts) + ) = self._dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts) event.current_stream_wait() if self.async_finish else () if hidden_states.shape[0] > 0: reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute( @@ -263,7 +370,7 @@ def dispatch( * topk_idx.shape[1] + num_experts ) // num_experts - hidden_states, masked_m, event, hook = self.dispatch_low_latency( + hidden_states, masked_m, event, hook = self._dispatch_low_latency( hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, @@ -284,113 +391,6 @@ def dispatch( expected_m, ) - def dispatch_normal( - self, - x: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - num_experts: int, - ): - previous_event = Buffer.capture() if self.async_finish else None - - ( - num_tokens_per_rank, - num_tokens_per_rdma_rank, - num_tokens_per_expert, - is_token_in_rank, - previous_event, - ) = self.buffer_normal.get_dispatch_layout( - topk_idx, - num_experts, - previous_event=previous_event, - async_finish=self.async_finish, - allocate_on_comm_stream=previous_event is not None, - ) - - # FIXME: `handle` should be transmitted with tokens from dispatch to combine. - # However, doing this would incur an unknown synchronization error, but keeping - # `handle` as a member variable works. 
- ( - recv_x, - recv_topk_idx, - recv_topk_weights, - _, # num_recv_tokens_per_expert_list - self.handle, - event, - ) = self.buffer_normal.dispatch( - x, - topk_idx=topk_idx, - topk_weights=topk_weights, - num_tokens_per_rank=num_tokens_per_rank, - num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, - is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=num_tokens_per_expert, - previous_event=previous_event, - async_finish=self.async_finish, - allocate_on_comm_stream=(previous_event is not None) and self.async_finish, - ) - - return ( - recv_x, - recv_topk_idx, - recv_topk_weights, - event, - ) - - def dispatch_low_latency( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - num_max_dispatch_tokens_per_rank: int, - num_experts: int, - use_fp8: bool = False, - ): - """ - # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'. - # Please make sure to change DeepEP code in internode_ll.cu dispatch / combine as below first and then reinstall. - # More details refer: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782 - - diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu - index 76ae2e2..8ecd08f 100644 - --- a/csrc/kernels/internode_ll.cu - +++ b/csrc/kernels/internode_ll.cu - @@ -310,8 +310,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, - int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8, - void* workspace, cudaStream_t stream, int phases) { - constexpr int kNumMaxTopK = 9; - - constexpr int kNumWarpsPerGroup = 10; - - constexpr int kNumWarpGroups = 3; - + constexpr int kNumWarpsPerGroup = 8; - + constexpr int kNumWarpGroups = 4; - EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections"); - - const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; - @@ -501,8 +501,8 @@ void combine(void* combined_x, - int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, - int num_topk, int num_experts, int rank, int num_ranks, - void* workspace, cudaStream_t stream, int phases) { - - constexpr int kNumWarpsPerGroup = 10; - - constexpr int kNumWarpGroups = 3; - + constexpr int kNumWarpsPerGroup = 8; - + constexpr int kNumWarpGroups = 4; - constexpr int kNumMaxTopk = 9; - - const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; - """ - - packed_recv_hidden, packed_recv_count, self.handle, event, hook = ( - self.buffer_low_latency.low_latency_dispatch( - hidden_states, - topk_idx, - num_max_dispatch_tokens_per_rank, - num_experts, - use_fp8=use_fp8, - async_finish=not self.return_recv_hook, - return_recv_hook=self.return_recv_hook, - ) - ) - return packed_recv_hidden, packed_recv_count, event, hook - def combine( self, hidden_states: torch.Tensor, From 4c37f2ab95b635aaf7896f83912e759ddac46227 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:21:44 +0800 Subject: [PATCH 26/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 64 +++++++++---------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 8152caa832d..99bfc99c42d 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -216,6 +216,18 @@ def _dispatch_normal( event, ) + def _combine_normal(self, x: torch.Tensor): + previous_event = Buffer.capture() if self.async_finish else None + + 
combined_x, _, event = self.buffer_normal.combine( + x, + self.handle, + async_finish=self.async_finish, + previous_event=previous_event, + allocate_on_comm_stream=previous_event is not None, + ) + return combined_x, event + class _DeepEPDispatcherLowLatency(_DeepEPDispatcherBase): def __init__(self, return_recv_hook: bool, **kwargs): @@ -289,6 +301,24 @@ def _dispatch_low_latency( ) return packed_recv_hidden, packed_recv_count, event, hook + def _combine_low_latency( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + ): + combined_hidden_states, event, hook = ( + self.buffer_low_latency.low_latency_combine( + hidden_states, + topk_idx, + topk_weights, + self.handle, + async_finish=not self.return_recv_hook, + return_recv_hook=self.return_recv_hook, + ) + ) + return combined_hidden_states, event, hook + class DeepEPDispatcher: def __init__( @@ -423,12 +453,12 @@ def combine( device=hidden_states.device, dtype=hidden_states.dtype, ) - hidden_states, event = self.combine_normal( + hidden_states, event = self._combine_normal( output, ) event.current_stream_wait() if self.async_finish else () elif resolved_deepep_mode == DeepEPMode.low_latency: - hidden_states, event, hook = self.combine_low_latency( + hidden_states, event, hook = self._combine_low_latency( hidden_states, topk_idx, topk_weights, @@ -438,33 +468,3 @@ def combine( raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") return hidden_states - - def combine_normal(self, x: torch.Tensor): - previous_event = Buffer.capture() if self.async_finish else None - - combined_x, _, event = self.buffer_normal.combine( - x, - self.handle, - async_finish=self.async_finish, - previous_event=previous_event, - allocate_on_comm_stream=previous_event is not None, - ) - return combined_x, event - - def combine_low_latency( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - ): - combined_hidden_states, event, hook = ( - self.buffer_low_latency.low_latency_combine( - hidden_states, - topk_idx, - topk_weights, - self.handle, - async_finish=not self.return_recv_hook, - return_recv_hook=self.return_recv_hook, - ) - ) - return combined_hidden_states, event, hook From 188d5a0d748469a1af10abb81cd1402b6622f0d8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:22:35 +0800 Subject: [PATCH 27/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 99bfc99c42d..381353489c2 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -369,12 +369,6 @@ def dispatch( forward_mode: ForwardMode = None, ) -> Tuple: topk_idx = topk_idx.to(torch.int64) - reorder_topk_ids = torch.empty( - (0,), device=hidden_states.device, dtype=torch.int64 - ) - seg_indptr = torch.zeros( - (num_experts + 1,), device=hidden_states.device, dtype=torch.int64 - ) masked_m = torch.empty( (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64 ) @@ -408,6 +402,14 @@ def dispatch( use_fp8=True, ) hook() if self.return_recv_hook else event.current_stream_wait() + + # TODO make it none + reorder_topk_ids = torch.empty( + (0,), device=hidden_states.device, dtype=torch.int64 + ) + seg_indptr = torch.zeros( + (num_experts + 1,), device=hidden_states.device, dtype=torch.int64 + ) else: raise 
ValueError(f"Invalid deepep_mode: {self.deepep_mode}") From 82a80922f7fd47e78d281a6f8279841495e80f49 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:22:48 +0800 Subject: [PATCH 28/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 381353489c2..4507fe084f1 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -369,10 +369,6 @@ def dispatch( forward_mode: ForwardMode = None, ) -> Tuple: topk_idx = topk_idx.to(torch.int64) - masked_m = torch.empty( - (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64 - ) - expected_m = 0 resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) if resolved_deepep_mode == DeepEPMode.normal: @@ -387,6 +383,11 @@ def dispatch( reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute( hidden_states, topk_idx, fp8_dtype=hidden_states.dtype ) + + masked_m = torch.empty( + (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64 + ) + expected_m = 0 elif resolved_deepep_mode == DeepEPMode.low_latency: expected_m = ( hidden_states.shape[0] From 39d9fb229ff0419d9c899ea8cc6c012a523fb6a1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:23:12 +0800 Subject: [PATCH 29/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 4507fe084f1..61523805a25 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -384,10 +384,12 @@ def dispatch( hidden_states, topk_idx, fp8_dtype=hidden_states.dtype ) - masked_m = torch.empty( - (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64 - ) - expected_m = 0 + # TODO + # masked_m = torch.empty( + # (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64 + # ) + # expected_m = 0 + masked_m = expected_m = None elif resolved_deepep_mode == DeepEPMode.low_latency: expected_m = ( hidden_states.shape[0] @@ -404,13 +406,14 @@ def dispatch( ) hook() if self.return_recv_hook else event.current_stream_wait() - # TODO make it none - reorder_topk_ids = torch.empty( - (0,), device=hidden_states.device, dtype=torch.int64 - ) - seg_indptr = torch.zeros( - (num_experts + 1,), device=hidden_states.device, dtype=torch.int64 - ) + # TODO + # reorder_topk_ids = torch.empty( + # (0,), device=hidden_states.device, dtype=torch.int64 + # ) + # seg_indptr = torch.zeros( + # (num_experts + 1,), device=hidden_states.device, dtype=torch.int64 + # ) + reorder_topk_ids = seg_indptr = None else: raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") From 73189d0bef90130bc9edc3c7006474d1645920bd Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:23:31 +0800 Subject: [PATCH 30/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 61523805a25..da4ffb4dfe4 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ 
b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -368,10 +368,9 @@ def dispatch( num_max_dispatch_tokens_per_rank: int = 128, forward_mode: ForwardMode = None, ) -> Tuple: - topk_idx = topk_idx.to(torch.int64) - resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) if resolved_deepep_mode == DeepEPMode.normal: + topk_idx = topk_idx.to(torch.int64) ( hidden_states, topk_idx, @@ -391,6 +390,7 @@ def dispatch( # expected_m = 0 masked_m = expected_m = None elif resolved_deepep_mode == DeepEPMode.low_latency: + topk_idx = topk_idx.to(torch.int64) expected_m = ( hidden_states.shape[0] * self.buffer_low_latency.group_size From fd39fdbb92579d61141bbd798b69e56bd39d83c7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:24:37 +0800 Subject: [PATCH 31/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index da4ffb4dfe4..1e614f8714e 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -357,8 +357,6 @@ def __init__( **common_kwargs, ) - -class DeepEPDispatcher: def dispatch( self, hidden_states: torch.Tensor, From f7e79281785f2974ac755298cc9b446c61b0b4f7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:25:55 +0800 Subject: [PATCH 32/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 125 ++++++++++-------- 1 file changed, 72 insertions(+), 53 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 1e614f8714e..4164127660a 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -163,6 +163,37 @@ def _deepep_permute( ) return reorder_topk_ids, seg_indptr, gateup_input + def dispatch(self): + topk_idx = topk_idx.to(torch.int64) + ( + hidden_states, + topk_idx, + topk_weights, + event, + ) = self._dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts) + event.current_stream_wait() if self.async_finish else () + if hidden_states.shape[0] > 0: + reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute( + hidden_states, topk_idx, fp8_dtype=hidden_states.dtype + ) + + # TODO + # masked_m = torch.empty( + # (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64 + # ) + # expected_m = 0 + masked_m = expected_m = None + + return ( + hidden_states, + topk_idx, + topk_weights, + reorder_topk_ids, + seg_indptr, + masked_m, + expected_m, + ) + def _dispatch_normal( self, x: torch.Tensor, @@ -247,6 +278,42 @@ def __init__(self, return_recv_hook: bool, **kwargs): ) self.return_recv_hook = return_recv_hook + def dispatch(self): + topk_idx = topk_idx.to(torch.int64) + expected_m = ( + hidden_states.shape[0] + * self.buffer_low_latency.group_size + * topk_idx.shape[1] + + num_experts + ) // num_experts + hidden_states, masked_m, event, hook = self._dispatch_low_latency( + hidden_states, + topk_idx, + num_max_dispatch_tokens_per_rank, + num_experts, + use_fp8=True, + ) + hook() if self.return_recv_hook else event.current_stream_wait() + + # TODO + # reorder_topk_ids = torch.empty( + # (0,), device=hidden_states.device, dtype=torch.int64 + # ) + # seg_indptr = torch.zeros( + # (num_experts + 1,), device=hidden_states.device, dtype=torch.int64 + # ) + reorder_topk_ids = seg_indptr = None + + return ( + 
hidden_states, + topk_idx, + topk_weights, + reorder_topk_ids, + seg_indptr, + masked_m, + expected_m, + ) + def _dispatch_low_latency( self, hidden_states: torch.Tensor, @@ -366,65 +433,17 @@ def dispatch( num_max_dispatch_tokens_per_rank: int = 128, forward_mode: ForwardMode = None, ) -> Tuple: + self._get_dispatcher(forward_mode).dispatch(TODO) + + def _get_dispatcher(self, forward_mode: ForwardMode): resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) if resolved_deepep_mode == DeepEPMode.normal: - topk_idx = topk_idx.to(torch.int64) - ( - hidden_states, - topk_idx, - topk_weights, - event, - ) = self._dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts) - event.current_stream_wait() if self.async_finish else () - if hidden_states.shape[0] > 0: - reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute( - hidden_states, topk_idx, fp8_dtype=hidden_states.dtype - ) - - # TODO - # masked_m = torch.empty( - # (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64 - # ) - # expected_m = 0 - masked_m = expected_m = None + return self._normal_dispatcher elif resolved_deepep_mode == DeepEPMode.low_latency: - topk_idx = topk_idx.to(torch.int64) - expected_m = ( - hidden_states.shape[0] - * self.buffer_low_latency.group_size - * topk_idx.shape[1] - + num_experts - ) // num_experts - hidden_states, masked_m, event, hook = self._dispatch_low_latency( - hidden_states, - topk_idx, - num_max_dispatch_tokens_per_rank, - num_experts, - use_fp8=True, - ) - hook() if self.return_recv_hook else event.current_stream_wait() - - # TODO - # reorder_topk_ids = torch.empty( - # (0,), device=hidden_states.device, dtype=torch.int64 - # ) - # seg_indptr = torch.zeros( - # (num_experts + 1,), device=hidden_states.device, dtype=torch.int64 - # ) - reorder_topk_ids = seg_indptr = None + return self._low_latency_dispatcher else: raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") - return ( - hidden_states, - topk_idx, - topk_weights, - reorder_topk_ids, - seg_indptr, - masked_m, - expected_m, - ) - def combine( self, hidden_states: torch.Tensor, From f9dd1e42735a150117c1f63b8020637db93c8af3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:26:45 +0800 Subject: [PATCH 33/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 4164127660a..06739060a2e 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -163,7 +163,14 @@ def _deepep_permute( ) return reorder_topk_ids, seg_indptr, gateup_input - def dispatch(self): + def dispatch( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: int, + num_max_dispatch_tokens_per_rank: int, + ): topk_idx = topk_idx.to(torch.int64) ( hidden_states, @@ -278,7 +285,14 @@ def __init__(self, return_recv_hook: bool, **kwargs): ) self.return_recv_hook = return_recv_hook - def dispatch(self): + def dispatch( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: int, + num_max_dispatch_tokens_per_rank: int, + ): topk_idx = topk_idx.to(torch.int64) expected_m = ( hidden_states.shape[0] From 8d5720d613dc48681c2562400637a2fbee065887 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:27:07 +0800 Subject: [PATCH 
34/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 06739060a2e..e9e3f178a0e 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -447,7 +447,7 @@ def dispatch( num_max_dispatch_tokens_per_rank: int = 128, forward_mode: ForwardMode = None, ) -> Tuple: - self._get_dispatcher(forward_mode).dispatch(TODO) + return self._get_dispatcher(forward_mode).dispatch(TODO) def _get_dispatcher(self, forward_mode: ForwardMode): resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) From 5bb5dc6695d1d6169d9dfd3fc3946a90759f0590 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:27:38 +0800 Subject: [PATCH 35/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 49 ++----------------- 1 file changed, 5 insertions(+), 44 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index e9e3f178a0e..9263541dc23 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -449,15 +449,6 @@ def dispatch( ) -> Tuple: return self._get_dispatcher(forward_mode).dispatch(TODO) - def _get_dispatcher(self, forward_mode: ForwardMode): - resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) - if resolved_deepep_mode == DeepEPMode.normal: - return self._normal_dispatcher - elif resolved_deepep_mode == DeepEPMode.low_latency: - return self._low_latency_dispatcher - else: - raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") - def combine( self, hidden_states: torch.Tensor, @@ -465,43 +456,13 @@ def combine( topk_weights: torch.Tensor, forward_mode: ForwardMode, ) -> torch.Tensor: + return self._get_dispatcher(forward_mode).combine(TODO) + + def _get_dispatcher(self, forward_mode: ForwardMode): resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) if resolved_deepep_mode == DeepEPMode.normal: - if hidden_states.shape[0] > 0: - num_tokens = self.src2dst.shape[0] // self.router_topk - output = torch.empty( - (num_tokens, hidden_states.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - deepep_post_reorder_triton_kernel[(num_tokens,)]( - hidden_states, - output, - self.src2dst, - topk_idx, - topk_weights, - self.router_topk, - hidden_states.shape[1], - BLOCK_SIZE=512, - ) - else: - output = torch.zeros( - (0, hidden_states.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - hidden_states, event = self._combine_normal( - output, - ) - event.current_stream_wait() if self.async_finish else () + return self._normal_dispatcher elif resolved_deepep_mode == DeepEPMode.low_latency: - hidden_states, event, hook = self._combine_low_latency( - hidden_states, - topk_idx, - topk_weights, - ) - hook() if self.return_recv_hook else event.current_stream_wait() + return self._low_latency_dispatcher else: raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") - - return hidden_states From 0c32670791c89b3c605be94cf7f02a55dc64ec75 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:28:06 +0800 Subject: [PATCH 36/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py 
b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 9263541dc23..634f27d1fa3 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -254,6 +254,43 @@ def _dispatch_normal( event, ) + def combine( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + forward_mode: ForwardMode, + ) -> torch.Tensor: + if hidden_states.shape[0] > 0: + num_tokens = self.src2dst.shape[0] // self.router_topk + output = torch.empty( + (num_tokens, hidden_states.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + deepep_post_reorder_triton_kernel[(num_tokens,)]( + hidden_states, + output, + self.src2dst, + topk_idx, + topk_weights, + self.router_topk, + hidden_states.shape[1], + BLOCK_SIZE=512, + ) + else: + output = torch.zeros( + (0, hidden_states.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + hidden_states, event = self._combine_normal( + output, + ) + event.current_stream_wait() if self.async_finish else () + + return hidden_states + def _combine_normal(self, x: torch.Tensor): previous_event = Buffer.capture() if self.async_finish else None @@ -382,6 +419,22 @@ def _dispatch_low_latency( ) return packed_recv_hidden, packed_recv_count, event, hook + def combine( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + forward_mode: ForwardMode, + ) -> torch.Tensor: + hidden_states, event, hook = self._combine_low_latency( + hidden_states, + topk_idx, + topk_weights, + ) + hook() if self.return_recv_hook else event.current_stream_wait() + + return hidden_states + def _combine_low_latency( self, hidden_states: torch.Tensor, From b1123115109af77a67e431fb2f07aec1b9d9bcbf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:28:17 +0800 Subject: [PATCH 37/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 634f27d1fa3..500a2317fc0 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -259,7 +259,6 @@ def combine( hidden_states: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, - forward_mode: ForwardMode, ) -> torch.Tensor: if hidden_states.shape[0] > 0: num_tokens = self.src2dst.shape[0] // self.router_topk @@ -424,7 +423,6 @@ def combine( hidden_states: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, - forward_mode: ForwardMode, ) -> torch.Tensor: hidden_states, event, hook = self._combine_low_latency( hidden_states, From 063a936517ac9858f216120a9175bac1361402f6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:28:52 +0800 Subject: [PATCH 38/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 78 +++++++++---------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 500a2317fc0..ffa580dfbfa 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -124,45 +124,6 @@ def __init__(self, async_finish: bool, **kwargs): self.async_finish = async_finish self.src2dst = None - """ - Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher - 
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py - """ - - def _deepep_permute( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - fp8_dtype: Optional[torch.dtype] = None, - use_fp8_w8a8: bool = False, - use_block_quant: bool = False, - ): - reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess( - topk_idx, self.num_experts - ) - num_total_tokens = reorder_topk_ids.numel() - gateup_input = torch.empty( - (int(num_total_tokens), hidden_states.shape[1]), - device=hidden_states.device, - dtype=( - fp8_dtype - if (use_fp8_w8a8 and not use_block_quant) - else hidden_states.dtype - ), - ) - # PreReorder - deepep_permute_triton_kernel[(hidden_states.shape[0],)]( - hidden_states, - gateup_input, - self.src2dst, - topk_idx, - None, - self.router_topk, - hidden_states.shape[1], - BLOCK_SIZE=512, - ) - return reorder_topk_ids, seg_indptr, gateup_input - def dispatch( self, hidden_states: torch.Tensor, @@ -254,6 +215,45 @@ def _dispatch_normal( event, ) + """ + Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher + https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py + """ + + def _deepep_permute( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + fp8_dtype: Optional[torch.dtype] = None, + use_fp8_w8a8: bool = False, + use_block_quant: bool = False, + ): + reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess( + topk_idx, self.num_experts + ) + num_total_tokens = reorder_topk_ids.numel() + gateup_input = torch.empty( + (int(num_total_tokens), hidden_states.shape[1]), + device=hidden_states.device, + dtype=( + fp8_dtype + if (use_fp8_w8a8 and not use_block_quant) + else hidden_states.dtype + ), + ) + # PreReorder + deepep_permute_triton_kernel[(hidden_states.shape[0],)]( + hidden_states, + gateup_input, + self.src2dst, + topk_idx, + None, + self.router_topk, + hidden_states.shape[1], + BLOCK_SIZE=512, + ) + return reorder_topk_ids, seg_indptr, gateup_input + def combine( self, hidden_states: torch.Tensor, From fbedeece717ad7079058d1973a11a872723e09bb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:29:27 +0800 Subject: [PATCH 39/65] more --- .../sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index ffa580dfbfa..7ec7dc7cf54 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -215,11 +215,6 @@ def _dispatch_normal( event, ) - """ - Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher - https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py - """ - def _deepep_permute( self, hidden_states: torch.Tensor, @@ -228,6 +223,11 @@ def _deepep_permute( use_fp8_w8a8: bool = False, use_block_quant: bool = False, ): + """ + Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher + https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py + """ + reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess( topk_idx, self.num_experts ) From 6b1b5630ad9d1718bf36fc4a7e1ec84b1de4cdf5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:30:32 +0800 Subject: [PATCH 40/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 14 
++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 7ec7dc7cf54..e0695f650f7 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -498,7 +498,13 @@ def dispatch( num_max_dispatch_tokens_per_rank: int = 128, forward_mode: ForwardMode = None, ) -> Tuple: - return self._get_dispatcher(forward_mode).dispatch(TODO) + return self._get_dispatcher(forward_mode).dispatch( + hidden_states=hidden_states, + topk_idx=topk_idx, + topk_weights=topk_weights, + num_experts=num_experts, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, + ) def combine( self, @@ -507,7 +513,11 @@ def combine( topk_weights: torch.Tensor, forward_mode: ForwardMode, ) -> torch.Tensor: - return self._get_dispatcher(forward_mode).combine(TODO) + return self._get_dispatcher(forward_mode).combine( + hidden_states=hidden_states, + topk_idx=topk_idx, + topk_weights=topk_weights, + ) def _get_dispatcher(self, forward_mode: ForwardMode): resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) From acedeb91401130a25a1628a335849699a404a7be Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:30:59 +0800 Subject: [PATCH 41/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index e0695f650f7..d3c49eaf89e 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -113,6 +113,24 @@ def __init__( self.handle = None + def dispatch( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: int, + num_max_dispatch_tokens_per_rank: int, + ): + raise NotImplementedError + + def combine( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + class _DeepEPDispatcherNormal(_DeepEPDispatcherBase): def __init__(self, async_finish: bool, **kwargs): From ad2fe7a92e1d7e80f5ce4c586f38b2dec5a23f0d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:31:09 +0800 Subject: [PATCH 42/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index d3c49eaf89e..bd2ec8d09fb 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -537,7 +537,7 @@ def combine( topk_weights=topk_weights, ) - def _get_dispatcher(self, forward_mode: ForwardMode): + def _get_dispatcher(self, forward_mode: ForwardMode) -> '_DeepEPDispatcherBase': resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) if resolved_deepep_mode == DeepEPMode.normal: return self._normal_dispatcher From 2383c01c167c3a92a071785c30c8efae25fca4f0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:31:28 +0800 Subject: [PATCH 43/65] more --- .../sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py 
b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index bd2ec8d09fb..9fe9c88e999 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -85,7 +85,7 @@ def get_buffer_low_latency( return _buffer_low_latency -class _DeepEPDispatcherBase: +class _DeepEPDispatcherImplBase: def __init__( self, group: torch.distributed.ProcessGroup, @@ -132,7 +132,7 @@ def combine( raise NotImplementedError -class _DeepEPDispatcherNormal(_DeepEPDispatcherBase): +class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): def __init__(self, async_finish: bool, **kwargs): super().__init__(**kwargs) @@ -321,7 +321,7 @@ def _combine_normal(self, x: torch.Tensor): return combined_x, event -class _DeepEPDispatcherLowLatency(_DeepEPDispatcherBase): +class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): def __init__(self, return_recv_hook: bool, **kwargs): super().__init__(**kwargs) @@ -497,12 +497,12 @@ def __init__( ) if self.deepep_mode.enable_normal(): - self._normal_dispatcher = _DeepEPDispatcherNormal( + self._normal_dispatcher = _DeepEPDispatcherImplNormal( async_finish=async_finish, **common_kwargs, ) if self.deepep_mode.enable_low_latency(): - self._low_latency_dispatcher = _DeepEPDispatcherLowLatency( + self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency( return_recv_hook=return_recv_hook, **common_kwargs, ) @@ -537,7 +537,7 @@ def combine( topk_weights=topk_weights, ) - def _get_dispatcher(self, forward_mode: ForwardMode) -> '_DeepEPDispatcherBase': + def _get_dispatcher(self, forward_mode: ForwardMode) -> '_DeepEPDispatcherImplBase': resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) if resolved_deepep_mode == DeepEPMode.normal: return self._normal_dispatcher From 0050d66ad0f564548267e356836c42428b3bc92d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:32:07 +0800 Subject: [PATCH 44/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 9fe9c88e999..6ff853fd22a 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -23,7 +23,7 @@ _buffer_low_latency = None -def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int): +def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int): """ Copy from DeepEP example usage in model inference prefilling. 
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling @@ -53,7 +53,7 @@ def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int): return _buffer_normal -def get_buffer_low_latency( +def _get_buffer_low_latency( group: dist.ProcessGroup, num_max_dispatch_tokens_per_rank: int, hidden: int, @@ -136,7 +136,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): def __init__(self, async_finish: bool, **kwargs): super().__init__(**kwargs) - self.buffer_normal = get_buffer_normal( + self.buffer_normal = _get_buffer_normal( self.group, self.hidden_size * self.params_bytes ) self.async_finish = async_finish @@ -331,7 +331,7 @@ def __init__(self, return_recv_hook: bool, **kwargs): """ # TODO(ch-wan): allow users to set this value self.num_max_dispatch_tokens_per_rank = 128 - self.buffer_low_latency = get_buffer_low_latency( + self.buffer_low_latency = _get_buffer_low_latency( self.group, self.num_max_dispatch_tokens_per_rank, self.hidden_size, From 7b8c5a24d837c26bfae8dd4a07ea3213eb09d7d1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:33:06 +0800 Subject: [PATCH 45/65] fmt --- .../sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 12 ++++++------ python/sglang/srt/models/deepseek_v2.py | 7 +++++-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 6ff853fd22a..f7a2ec1d9a9 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -349,11 +349,11 @@ def dispatch( ): topk_idx = topk_idx.to(torch.int64) expected_m = ( - hidden_states.shape[0] - * self.buffer_low_latency.group_size - * topk_idx.shape[1] - + num_experts - ) // num_experts + hidden_states.shape[0] + * self.buffer_low_latency.group_size + * topk_idx.shape[1] + + num_experts + ) // num_experts hidden_states, masked_m, event, hook = self._dispatch_low_latency( hidden_states, topk_idx, @@ -537,7 +537,7 @@ def combine( topk_weights=topk_weights, ) - def _get_dispatcher(self, forward_mode: ForwardMode) -> '_DeepEPDispatcherImplBase': + def _get_dispatcher(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase": resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) if resolved_deepep_mode == DeepEPMode.normal: return self._normal_dispatcher diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index edc845162cf..2f78de4925a 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -200,8 +200,11 @@ def __init__( topk_group=config.topk_group, correction_bias=self.gate.e_score_correction_bias, prefix=add_prefix("experts", prefix), - **(dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]]) - if global_server_args_dict["enable_deepep_moe"] else {}), + **( + dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]]) + if global_server_args_dict["enable_deepep_moe"] + else {} + ), ) if config.n_shared_experts is not None: From 449a78b87a7351de3dff6c013bf809a96e5aff2b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:38:11 +0800 Subject: [PATCH 46/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 33 ++++++++++++++++--- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 
f7a2ec1d9a9..5a92026804e 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -349,11 +349,11 @@ def dispatch( ): topk_idx = topk_idx.to(torch.int64) expected_m = ( - hidden_states.shape[0] - * self.buffer_low_latency.group_size - * topk_idx.shape[1] - + num_experts - ) // num_experts + hidden_states.shape[0] + * self.buffer_low_latency.group_size + * topk_idx.shape[1] + + num_experts + ) // num_experts hidden_states, masked_m, event, hook = self._dispatch_low_latency( hidden_states, topk_idx, @@ -524,6 +524,29 @@ def dispatch( num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, ) + def dispatch_a( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: int, + num_max_dispatch_tokens_per_rank: int = 128, + forward_mode: ForwardMode = None, + ): + inner_state = self._get_dispatcher(forward_mode).dispatch_a( + hidden_states=hidden_states, + topk_idx=topk_idx, + topk_weights=topk_weights, + num_experts=num_experts, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, + ) + self._dispatch_intermediate_state = forward_mode, inner_state + + def dispatch_b(self): + forward_mode, inner_state = self._dispatch_intermediate_state + del self._dispatch_intermediate_state + return self._get_dispatcher(forward_mode).dispatch_b() + def combine( self, hidden_states: torch.Tensor, From aefcfc9b7529bebb3068df63a876cc9a9431db92 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:38:25 +0800 Subject: [PATCH 47/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 5a92026804e..e01b3f0ec8d 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -545,7 +545,7 @@ def dispatch_a( def dispatch_b(self): forward_mode, inner_state = self._dispatch_intermediate_state del self._dispatch_intermediate_state - return self._get_dispatcher(forward_mode).dispatch_b() + return self._get_dispatcher(forward_mode).dispatch_b(*inner_state) def combine( self, From 797008f253988549a0abee4863bc04d30107bc56 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:38:52 +0800 Subject: [PATCH 48/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index e01b3f0ec8d..86a242f8adb 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -516,13 +516,15 @@ def dispatch( num_max_dispatch_tokens_per_rank: int = 128, forward_mode: ForwardMode = None, ) -> Tuple: - return self._get_dispatcher(forward_mode).dispatch( + self.dispatch_a( hidden_states=hidden_states, topk_idx=topk_idx, topk_weights=topk_weights, num_experts=num_experts, num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, + forward_mode=forward_mode, ) + return self.dispatch_b() def dispatch_a( self, From d722607415d746ae522bdb5a142189733080d231 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:39:48 +0800 Subject: [PATCH 49/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 19 +++++++++++++++++++ 1 file changed, 19 
insertions(+) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 86a242f8adb..01653252b2f 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -562,6 +562,25 @@ def combine( topk_weights=topk_weights, ) + def combine_a( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + forward_mode: ForwardMode, + ) -> torch.Tensor: + inner_state = self._get_dispatcher(forward_mode).combine_a( + hidden_states=hidden_states, + topk_idx=topk_idx, + topk_weights=topk_weights, + ) + self._combine_intermediate_state = forward_mode, inner_state + + def combine_b(self): + forward_mode, inner_state = self._combine_intermediate_state + del self._combine_intermediate_state + return self._get_dispatcher(forward_mode).combine_b(*inner_state) + def _get_dispatcher(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase": resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) if resolved_deepep_mode == DeepEPMode.normal: From 67ef8709c820e32713c9cba84c209da951afd653 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:40:22 +0800 Subject: [PATCH 50/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 34 +++---------------- 1 file changed, 5 insertions(+), 29 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 01653252b2f..873d7e377ee 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -507,23 +507,8 @@ def __init__( **common_kwargs, ) - def dispatch( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - num_experts: int, - num_max_dispatch_tokens_per_rank: int = 128, - forward_mode: ForwardMode = None, - ) -> Tuple: - self.dispatch_a( - hidden_states=hidden_states, - topk_idx=topk_idx, - topk_weights=topk_weights, - num_experts=num_experts, - num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, - forward_mode=forward_mode, - ) + def dispatch(self, *args, **kwargs) -> Tuple: + self.dispatch_a(*args, **kwargs) return self.dispatch_b() def dispatch_a( @@ -549,18 +534,9 @@ def dispatch_b(self): del self._dispatch_intermediate_state return self._get_dispatcher(forward_mode).dispatch_b(*inner_state) - def combine( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - forward_mode: ForwardMode, - ) -> torch.Tensor: - return self._get_dispatcher(forward_mode).combine( - hidden_states=hidden_states, - topk_idx=topk_idx, - topk_weights=topk_weights, - ) + def combine(self, *args, **kwargs) -> Tuple: + self.combine_a(*args, **kwargs) + return self.combine_b() def combine_a( self, From 58952e6950f2dfabb11b222c46309cea464806fd Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:41:06 +0800 Subject: [PATCH 51/65] more --- .../sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 873d7e377ee..fac842f7fce 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -113,7 +113,7 @@ def __init__( self.handle = None - def dispatch( + def dispatch_a( self, 
hidden_states: torch.Tensor, topk_idx: torch.Tensor, @@ -123,12 +123,18 @@ def dispatch( ): raise NotImplementedError - def combine( + def dispatch_b(self, *args, **kwargs): + raise NotImplementedError + + def combine_a( self, hidden_states: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, - ) -> torch.Tensor: + ): + raise NotImplementedError + + def combine_b(self, *args, **kwargs): raise NotImplementedError From 967c95fda09c5aeb7547035f7726bb04c2e82d09 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:41:28 +0800 Subject: [PATCH 52/65] mv --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index fac842f7fce..b45b3dbd137 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -157,12 +157,13 @@ def dispatch( num_max_dispatch_tokens_per_rank: int, ): topk_idx = topk_idx.to(torch.int64) + previous_event = Buffer.capture() if self.async_finish else None ( hidden_states, topk_idx, topk_weights, event, - ) = self._dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts) + ) = self._dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts, previous_event) event.current_stream_wait() if self.async_finish else () if hidden_states.shape[0] > 0: reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute( @@ -192,9 +193,8 @@ def _dispatch_normal( topk_idx: torch.Tensor, topk_weights: torch.Tensor, num_experts: int, + previous_event, ): - previous_event = Buffer.capture() if self.async_finish else None - ( num_tokens_per_rank, num_tokens_per_rdma_rank, From 3e022e679dcb42abbb4597f176204f18a10cda3a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:42:05 +0800 Subject: [PATCH 53/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index b45b3dbd137..56fc546ba11 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -148,7 +148,7 @@ def __init__(self, async_finish: bool, **kwargs): self.async_finish = async_finish self.src2dst = None - def dispatch( + def dispatch_a( self, hidden_states: torch.Tensor, topk_idx: torch.Tensor, @@ -158,6 +158,9 @@ def dispatch( ): topk_idx = topk_idx.to(torch.int64) previous_event = Buffer.capture() if self.async_finish else None + return hidden_states, topk_idx, topk_weights, num_experts, previous_event + + def dispatch_b(self, hidden_states, topk_idx, topk_weights, num_experts, previous_event): ( hidden_states, topk_idx, From fe78450e97f962fef380dc598695e30615fdf587 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:42:53 +0800 Subject: [PATCH 54/65] mv --- .../sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 56fc546ba11..29cc5fc36e6 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -310,16 +310,14 @@ def combine( device=hidden_states.device, dtype=hidden_states.dtype, 
) - hidden_states, event = self._combine_normal( - output, - ) + previous_event = Buffer.capture() if self.async_finish else None + + hidden_states, event = self._combine_normal(output, previous_event) event.current_stream_wait() if self.async_finish else () return hidden_states - def _combine_normal(self, x: torch.Tensor): - previous_event = Buffer.capture() if self.async_finish else None - + def _combine_normal(self, x: torch.Tensor, previous_event): combined_x, _, event = self.buffer_normal.combine( x, self.handle, From d58ffc80bee8ba6781341b87569ec27f52e61c3a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:43:32 +0800 Subject: [PATCH 55/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 29cc5fc36e6..434e30db9aa 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -281,12 +281,12 @@ def _deepep_permute( ) return reorder_topk_ids, seg_indptr, gateup_input - def combine( + def combine_a( self, hidden_states: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, - ) -> torch.Tensor: + ): if hidden_states.shape[0] > 0: num_tokens = self.src2dst.shape[0] // self.router_topk output = torch.empty( @@ -311,10 +311,11 @@ def combine( dtype=hidden_states.dtype, ) previous_event = Buffer.capture() if self.async_finish else None + return output, previous_event + def combine_b(self, output, previous_event): hidden_states, event = self._combine_normal(output, previous_event) event.current_stream_wait() if self.async_finish else () - return hidden_states def _combine_normal(self, x: torch.Tensor, previous_event): From 3ed173276bd4ff32cc41ebe264420d07d6c39ca6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:44:29 +0800 Subject: [PATCH 56/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 434e30db9aa..c93d9ee1ed0 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -552,7 +552,7 @@ def combine_a( topk_idx: torch.Tensor, topk_weights: torch.Tensor, forward_mode: ForwardMode, - ) -> torch.Tensor: + ): inner_state = self._get_dispatcher(forward_mode).combine_a( hidden_states=hidden_states, topk_idx=topk_idx, From e26c86043587c4e6db4a7da956255788883a3edf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:44:49 +0800 Subject: [PATCH 57/65] more --- .../sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index c93d9ee1ed0..a7cccd305c9 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -528,7 +528,7 @@ def dispatch_a( num_max_dispatch_tokens_per_rank: int = 128, forward_mode: ForwardMode = None, ): - inner_state = self._get_dispatcher(forward_mode).dispatch_a( + inner_state = self._get_impl(forward_mode).dispatch_a( hidden_states=hidden_states, topk_idx=topk_idx, topk_weights=topk_weights, @@ -540,7 +540,7 @@ def 
dispatch_a( def dispatch_b(self): forward_mode, inner_state = self._dispatch_intermediate_state del self._dispatch_intermediate_state - return self._get_dispatcher(forward_mode).dispatch_b(*inner_state) + return self._get_impl(forward_mode).dispatch_b(*inner_state) def combine(self, *args, **kwargs) -> Tuple: self.combine_a(*args, **kwargs) @@ -553,7 +553,7 @@ def combine_a( topk_weights: torch.Tensor, forward_mode: ForwardMode, ): - inner_state = self._get_dispatcher(forward_mode).combine_a( + inner_state = self._get_impl(forward_mode).combine_a( hidden_states=hidden_states, topk_idx=topk_idx, topk_weights=topk_weights, @@ -563,9 +563,9 @@ def combine_a( def combine_b(self): forward_mode, inner_state = self._combine_intermediate_state del self._combine_intermediate_state - return self._get_dispatcher(forward_mode).combine_b(*inner_state) + return self._get_impl(forward_mode).combine_b(*inner_state) - def _get_dispatcher(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase": + def _get_impl(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase": resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) if resolved_deepep_mode == DeepEPMode.normal: return self._normal_dispatcher From 26296739724e528d682b704496e36496e91e5f30 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:46:15 +0800 Subject: [PATCH 58/65] rename --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index a7cccd305c9..bdb63d89845 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -166,7 +166,7 @@ def dispatch_b(self, hidden_states, topk_idx, topk_weights, num_experts, previou topk_idx, topk_weights, event, - ) = self._dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts, previous_event) + ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, num_experts, previous_event) event.current_stream_wait() if self.async_finish else () if hidden_states.shape[0] > 0: reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute( @@ -190,7 +190,7 @@ def dispatch_b(self, hidden_states, topk_idx, topk_weights, num_experts, previou expected_m, ) - def _dispatch_normal( + def _dispatch_core( self, x: torch.Tensor, topk_idx: torch.Tensor, @@ -314,11 +314,11 @@ def combine_a( return output, previous_event def combine_b(self, output, previous_event): - hidden_states, event = self._combine_normal(output, previous_event) + hidden_states, event = self._combine_core(output, previous_event) event.current_stream_wait() if self.async_finish else () return hidden_states - def _combine_normal(self, x: torch.Tensor, previous_event): + def _combine_core(self, x: torch.Tensor, previous_event): combined_x, _, event = self.buffer_normal.combine( x, self.handle, @@ -362,7 +362,7 @@ def dispatch( * topk_idx.shape[1] + num_experts ) // num_experts - hidden_states, masked_m, event, hook = self._dispatch_low_latency( + hidden_states, masked_m, event, hook = self._dispatch_core( hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, @@ -390,7 +390,7 @@ def dispatch( expected_m, ) - def _dispatch_low_latency( + def _dispatch_core( self, hidden_states: torch.Tensor, topk_idx: torch.Tensor, @@ -450,7 +450,7 @@ def combine( topk_idx: torch.Tensor, topk_weights: torch.Tensor, ) -> torch.Tensor: - hidden_states, event, hook = 
self._combine_low_latency( + hidden_states, event, hook = self._combine_core( hidden_states, topk_idx, topk_weights, @@ -459,7 +459,7 @@ def combine( return hidden_states - def _combine_low_latency( + def _combine_core( self, hidden_states: torch.Tensor, topk_idx: torch.Tensor, From f3b4209d0962459ff65a3676110ec21aaa72547a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:47:00 +0800 Subject: [PATCH 59/65] more --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index bdb63d89845..76d9f5a92eb 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -347,7 +347,7 @@ def __init__(self, return_recv_hook: bool, **kwargs): ) self.return_recv_hook = return_recv_hook - def dispatch( + def dispatch_a( self, hidden_states: torch.Tensor, topk_idx: torch.Tensor, @@ -369,6 +369,26 @@ def dispatch( num_experts, use_fp8=True, ) + return ( + hidden_states, + topk_idx, + topk_weights, + masked_m, + expected_m, + hook, + event, + ) + + def dispatch_b( + self, + hidden_states, + topk_idx, + topk_weights, + masked_m, + expected_m, + hook, + event, + ): hook() if self.return_recv_hook else event.current_stream_wait() # TODO From 5afd1b3f58de0bd98b3211e7b57fe0032fc7397c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:47:53 +0800 Subject: [PATCH 60/65] more --- .../sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 76d9f5a92eb..b49821a89e2 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -375,8 +375,8 @@ def dispatch_a( topk_weights, masked_m, expected_m, - hook, event, + hook, ) def dispatch_b( @@ -386,8 +386,8 @@ def dispatch_b( topk_weights, masked_m, expected_m, - hook, event, + hook, ): hook() if self.return_recv_hook else event.current_stream_wait() @@ -464,19 +464,21 @@ def _dispatch_core( ) return packed_recv_hidden, packed_recv_count, event, hook - def combine( + def combine_a( self, hidden_states: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, - ) -> torch.Tensor: + ): hidden_states, event, hook = self._combine_core( hidden_states, topk_idx, topk_weights, ) - hook() if self.return_recv_hook else event.current_stream_wait() + return hidden_states, event, hook + def combine_b(self, hidden_states, event, hook): + hook() if self.return_recv_hook else event.current_stream_wait() return hidden_states def _combine_core( From eeb7c592e8810da1fb325b15b1697f21bd092955 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:51:24 +0800 Subject: [PATCH 61/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index b49821a89e2..95b0e935cdb 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -316,6 +316,8 @@ def combine_a( def combine_b(self, output, previous_event): hidden_states, event = self._combine_core(output, previous_event) 
event.current_stream_wait() if self.async_finish else () + del self.handle + del self.src2dst return hidden_states def _combine_core(self, x: torch.Tensor, previous_event): @@ -497,6 +499,7 @@ def _combine_core( return_recv_hook=self.return_recv_hook, ) ) + del self.handle return combined_hidden_states, event, hook From 3c63e7cbf4d43494d297c58cda5743173d5c2d6b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 2 Apr 2025 15:56:37 +0800 Subject: [PATCH 62/65] fmt --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index b49821a89e2..bd4ad819ac6 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -160,13 +160,17 @@ def dispatch_a( previous_event = Buffer.capture() if self.async_finish else None return hidden_states, topk_idx, topk_weights, num_experts, previous_event - def dispatch_b(self, hidden_states, topk_idx, topk_weights, num_experts, previous_event): + def dispatch_b( + self, hidden_states, topk_idx, topk_weights, num_experts, previous_event + ): ( hidden_states, topk_idx, topk_weights, event, - ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, num_experts, previous_event) + ) = self._dispatch_core( + hidden_states, topk_idx, topk_weights, num_experts, previous_event + ) event.current_stream_wait() if self.async_finish else () if hidden_states.shape[0] > 0: reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute( @@ -357,11 +361,11 @@ def dispatch_a( ): topk_idx = topk_idx.to(torch.int64) expected_m = ( - hidden_states.shape[0] - * self.buffer_low_latency.group_size - * topk_idx.shape[1] - + num_experts - ) // num_experts + hidden_states.shape[0] + * self.buffer_low_latency.group_size + * topk_idx.shape[1] + + num_experts + ) // num_experts hidden_states, masked_m, event, hook = self._dispatch_core( hidden_states, topk_idx, From f92ee0ab5f883b61303856c4e4e9eb01663f8916 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 3 Apr 2025 10:51:03 +0800 Subject: [PATCH 63/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 5594cb787f6..d5b49b55a94 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -320,8 +320,8 @@ def combine_a( def combine_b(self, output, previous_event): hidden_states, event = self._combine_core(output, previous_event) event.current_stream_wait() if self.async_finish else () - del self.handle - del self.src2dst + self.handle = None + self.src2dst = None return hidden_states def _combine_core(self, x: torch.Tensor, previous_event): @@ -503,7 +503,7 @@ def _combine_core( return_recv_hook=self.return_recv_hook, ) ) - del self.handle + self.handle = None return combined_hidden_states, event, hook From b63d0ef62d684168003d848d2a6eeea6b1ab1e58 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 3 Apr 2025 10:51:49 +0800 Subject: [PATCH 64/65] more --- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 
f7a2ec1d9a9..04d8d7d2979 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -162,6 +162,13 @@ def dispatch( reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute( hidden_states, topk_idx, fp8_dtype=hidden_states.dtype ) + else: + reorder_topk_ids = torch.empty( + (0,), device=hidden_states.device, dtype=torch.int64 + ) + seg_indptr = torch.zeros( + (num_experts + 1,), device=hidden_states.device, dtype=torch.int64 + ) # TODO # masked_m = torch.empty( From af182ec7fcd9313e02522ea7cf6ef8179d8cc9d6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 3 Apr 2025 13:50:09 +0800 Subject: [PATCH 65/65] rm --- .../sglang/srt/layers/moe/ep_moe/token_dispatcher.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 41af0d7934e..100fa57fb2b 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -184,11 +184,6 @@ def dispatch_b( (num_experts + 1,), device=hidden_states.device, dtype=torch.int64 ) - # TODO - # masked_m = torch.empty( - # (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64 - # ) - # expected_m = 0 masked_m = expected_m = None return ( @@ -404,13 +399,6 @@ def dispatch_b( ): hook() if self.return_recv_hook else event.current_stream_wait() - # TODO - # reorder_topk_ids = torch.empty( - # (0,), device=hidden_states.device, dtype=torch.int64 - # ) - # seg_indptr = torch.zeros( - # (num_experts + 1,), device=hidden_states.device, dtype=torch.int64 - # ) reorder_topk_ids = seg_indptr = None return (
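
Note on the API this series converges to: after [PATCH 65/65], `DeepEPDispatcher.dispatch` and `DeepEPDispatcher.combine` are thin wrappers over split `dispatch_a`/`dispatch_b` and `combine_a`/`combine_b` phases, with the normal or low-latency implementation picked per call through `deepep_mode.resolve(forward_mode)`. A minimal sketch of the intended call order follows; `moe_forward` and `run_experts` are illustrative placeholders rather than code from this series, and the sketch assumes an sglang environment with DeepEP buffers already initialized:

    import torch

    from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
    from sglang.srt.model_executor.forward_batch_info import ForwardMode


    def moe_forward(
        dispatcher: DeepEPDispatcher,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        num_experts: int,
        forward_mode: ForwardMode,
        run_experts,  # assumed callable: the grouped-GEMM expert computation
    ) -> torch.Tensor:
        # Phase A: launch the dispatch communication; the dispatcher keeps
        # the per-impl state in _dispatch_intermediate_state.
        dispatcher.dispatch_a(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            num_experts=num_experts,
            forward_mode=forward_mode,
        )
        # Unrelated work can be overlapped here before blocking.

        # Phase B: wait on the event (normal) or recv hook (low latency)
        # and collect the permuted tokens.
        (
            hidden_states,
            topk_idx,
            topk_weights,
            reorder_topk_ids,
            seg_indptr,
            masked_m,
            expected_m,
        ) = dispatcher.dispatch_b()

        expert_output = run_experts(
            hidden_states, reorder_topk_ids, seg_indptr, masked_m, expected_m
        )

        # The combine path follows the same two-phase pattern.
        dispatcher.combine_a(
            hidden_states=expert_output,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            forward_mode=forward_mode,
        )
        return dispatcher.combine_b()

In the final state, `dispatch_b` still returns `None` for `masked_m`/`expected_m` on the normal path and for `reorder_topk_ids`/`seg_indptr` on the low-latency path, so the expert computation has to tolerate either combination.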